Custom Boundaries

custom marks a traceable Python function as a transform boundary, and the wrapper it returns registers per-transform rules for that boundary.

autoform.custom(func, /)

Mark a traceable Python function as a custom Autoform transform boundary.

custom is a decorator for traceable functions. The decorated function keeps its ordinary call behavior, and you can optionally override how Autoform transforms treat it. Without any registered rules, pushforward, pullback, and batch produce the same results as they would when transforming the function body directly.

The returned wrapper supports these rule registration decorators:

  • set_pushforward(rule) for a synchronous pushforward rule.

  • aset_pushforward(rule) for an asynchronous pushforward rule.

  • set_pullback(rule) for a synchronous pullback backward rule.

  • aset_pullback(rule) for an asynchronous pullback backward rule.

  • set_batch(rule) for a synchronous batch rule.

  • aset_batch(rule) for an asynchronous batch rule.

Rules receive one positional in_tree argument. The original behavior is available as the keyword-only call argument, so a rule can use call(*primals) when it wants to reuse the normal primal behavior. The rule signatures are:

  • Pushforward: rule((primals, tangents), /, *, call) -> (p_out, t_out).

  • Pullback backward: rule(((primals, output), cotangent), /, *, call) -> cotangents.

  • Batch: rule((batch_size, axes, values), /, *, call) -> (outputs, output_axes).
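
The three signatures differ only in the shape of in_tree. As a minimal sketch that does not use Autoform at all, the plain-Python rules below are invoked by hand with a hypothetical bracket function passed as call, which makes each in_tree layout and return shape explicit. In real use, Autoform builds in_tree and supplies call for you.

```python
def bracket(x: str) -> str:
    """Hypothetical primal function standing in for the wrapped body."""
    return f"[{x}]"


def push_rule(in_tree, /, *, call):
    # in_tree = (primals, tangents); both are tuples with one entry per input.
    primals, tangents = in_tree
    (dx,) = tangents
    return call(*primals), f"delta: {dx}"  # (p_out, t_out)


def pull_rule(in_tree, /, *, call):
    del call  # this rule does not need the primal behavior
    # in_tree = ((primals, output), cotangent)
    (primals, output), cotangent = in_tree
    (x,) = primals
    return (f"{cotangent} via {output} from {x}",)  # one cotangent per input


def batch_rule(in_tree, /, *, call):
    # in_tree = (batch_size, axes, values)
    batch_size, axes, values = in_tree
    (xs,) = values
    assert len(xs) == batch_size
    return [call(x) for x in xs], True  # (outputs, output_axes)


print(push_rule((("hello",), ("change",)), call=bracket))
# ('[hello]', 'delta: change')
print(pull_rule(((("hello",), "[hello]"), "feedback"), call=bracket))
# ('feedback via [hello] from hello',)
print(batch_rule((2, (True,), (["a", "b"],)), call=bracket))
# (['[a]', '[b]'], True)
```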

Synchronous and asynchronous registrations are independent. Use both set_* and aset_* when both execution modes need custom behavior.
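
To illustrate what independent registrations mean, here is a toy model in plain Python. ToyBoundary, its pushforward/apushforward methods, and the "apply the function to the tangents" default are all hypothetical and are not Autoform's CustomFunc; the default only mimics the string-template behavior shown in the examples below. In this toy, call is a plain synchronous function even inside the async rule, which is also an assumption. Registering a sync rule leaves the async path on its default until an async rule is registered too.

```python
import asyncio
from typing import Any, Callable


class ToyBoundary:
    """Hypothetical wrapper with separate sync and async rule slots (not CustomFunc)."""

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func
        self.sync_rules: dict[str, Callable[..., Any]] = {}
        self.async_rules: dict[str, Callable[..., Any]] = {}

    def __call__(self, *args: Any) -> Any:
        # Direct calls ignore all registered rules.
        return self.func(*args)

    def set_pushforward(self, rule: Callable[..., Any]) -> Callable[..., Any]:
        self.sync_rules["pushforward"] = rule
        return rule

    def aset_pushforward(self, rule: Callable[..., Any]) -> Callable[..., Any]:
        self.async_rules["pushforward"] = rule
        return rule

    def _default_pushforward(self, primals: tuple, tangents: tuple) -> tuple:
        # Toy default: treat func as linear in its inputs, so it also maps tangents.
        return self.func(*primals), self.func(*tangents)

    def pushforward(self, primals: tuple, tangents: tuple) -> tuple:
        rule = self.sync_rules.get("pushforward")
        if rule is None:
            return self._default_pushforward(primals, tangents)
        return rule((primals, tangents), call=self.func)

    async def apushforward(self, primals: tuple, tangents: tuple) -> tuple:
        # Only the async slot is consulted; a sync rule never affects this path.
        rule = self.async_rules.get("pushforward")
        if rule is None:
            return self._default_pushforward(primals, tangents)
        return await rule((primals, tangents), call=self.func)


bracket = ToyBoundary(lambda x: f"[{x}]")


@bracket.set_pushforward
def bracket_push(in_tree, /, *, call):
    primals, (dx,) = in_tree
    return call(*primals), f"delta: {dx}"


print(bracket.pushforward(("hello",), ("change",)))               # ('[hello]', 'delta: change')
print(asyncio.run(bracket.apushforward(("hello",), ("change",))))  # ('[hello]', '[change]')


@bracket.aset_pushforward
async def bracket_apush(in_tree, /, *, call):
    primals, (dx,) = in_tree
    return call(*primals), f"async delta: {dx}"


print(asyncio.run(bracket.apushforward(("hello",), ("change",))))  # ('[hello]', 'async delta: change')
```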

Parameters:

func (Callable[..., Any]) – Function to wrap. The function body may use Autoform primitives and any normal Python structure that is valid while tracing.

Returns:

A callable wrapper with the same call behavior as func and rule registration methods. The concrete wrapper class is an implementation detail; use only the returned callable and its set_*/aset_* methods.

Return type:

CustomFunc

Example

Direct calls behave like calls to the original function.

>>> import autoform as af
>>> @af.custom
... def bracket(x):
...     return af.format("[{}]", x)
>>> bracket("hello")
'[hello]'

Without custom rules, transforms behave as if they were applied to the function body.

>>> base = af.trace(lambda x: bracket(x))("seed")
>>> af.pushforward(base).call(("hello",), ("change",))
('[hello]', '[change]')
>>> af.pullback(base).call(("hello",), "feedback")
('[hello]', ('feedback',))
>>> af.batch(base).call(["a", "b"])
['[a]', '[b]']

Registering a pushforward rule replaces only the pushforward behavior; pullback and batch keep their default behavior.

>>> @bracket.set_pushforward
... def bracket_push(in_tree, /, *, call):
...     primals, tangents = in_tree
...     (dx,) = tangents
...     p_out = call(*primals)
...     t_out = af.format("delta: {}", af.ad.materialize(dx))
...     return p_out, t_out
>>> af.pushforward(base).call(("hello",), ("change",))
('[hello]', 'delta: change')

Pullback and batch rules follow the same single-in_tree convention.

>>> @bracket.set_pullback
... def bracket_pull(in_tree, /, *, call):
...     del call
...     (primals, output), cotangent = in_tree
...     (x,) = primals
...     return (af.format("{} via {} from {}", cotangent, output, x),)
>>> af.pullback(base).call(("hello",), "feedback")
('[hello]', ('feedback via [hello] from hello',))
>>> @bracket.set_batch
... def bracket_batch(in_tree, /, *, call):
...     del call
...     batch_size, axes, values = in_tree
...     assert batch_size == 2
...     (xs,) = values
...     (x_axis,) = axes
...     assert x_axis is True
...     return [af.format("<{}>", x) for x in xs], True
>>> af.batch(base).call(["a", "b"])
['<a>', '<b>']
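
A batch rule does not have to reimplement the function body; it can reuse call once per example. The plain-Python sketch below does that without Autoform. It assumes each entry of axes is True for an input that carries the batch and False for an input shared by every example; that reading of axes, and the label function, are assumptions for illustration rather than documented Autoform behavior.

```python
from typing import Any, Callable


def loop_batch_rule(in_tree, /, *, call: Callable[..., Any]):
    """Toy batch rule: run call once per example (assumes True marks a batched input)."""
    batch_size, axes, values = in_tree
    outputs = []
    for i in range(batch_size):
        # Batched inputs contribute their i-th element; unbatched inputs are shared.
        args = [value[i] if axis else value for axis, value in zip(axes, values)]
        outputs.append(call(*args))
    return outputs, True  # every output is batched


def label(x: str, tag: str) -> str:
    """Hypothetical two-input function used to exercise the rule."""
    return f"{tag}:{x}"


print(loop_batch_rule((2, (True, False), (["a", "b"], "item")), call=label))
# (['item:a', 'item:b'], True)
```

Looping is the simplest correct choice; a custom rule is most useful when the batch can be handled more cheaply as a whole, for example with a single combined prompt.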

A common use is wrapping an LM call so that direct calls behave normally, while pushforward and pullback use prompts written for the specific application.

>>> from autoform.ad import materialize
>>> @af.custom
... def summarize(text, model):
...     message = af.format("Summarize this in one sentence: {}", text)
...     return af.lm_call([{"role": "user", "content": message}], model=model)

The custom pushforward rule can ask the model how the output should change under an input edit.

>>> @summarize.set_pushforward
... def summarize_push(in_tree, /, *, call):
...     primals, tangents = in_tree
...     text, model = primals
...     text_tangent, _ = tangents
...     p_out = call(*primals)
...     prompt = af.format(
...         "Original input:\n{}\n\nInput edit:\n{}\n\n"
...         "Describe how the summary should change.",
...         text,
...         materialize(text_tangent),
...     )
...     t_out = af.lm_call([{"role": "user", "content": prompt}], model=model)
...     return p_out, t_out

The custom pullback rule can replace the default backward LM prompt with a domain-specific feedback prompt.

>>> @summarize.set_pullback
... def summarize_pull(in_tree, /, *, call):
...     del call
...     (primals, output), cotangent = in_tree
...     text, model = primals
...     prompt = af.format(
...         "Original input:\n{}\n\nLM output:\n{}\n\n"
...         "Downstream feedback:\n{}\n\n"
...         "Return feedback for improving the original input.",
...         text,
...         output,
...         materialize(cotangent),
...     )
...     text_cotangent = af.lm_call([{"role": "user", "content": prompt}], model=model)
...     return text_cotangent, ""
>>> lm_ir = af.trace(lambda text, model: summarize(text, model))("topic", "model")
>>> af.pushforward(lm_ir).call(  
...     ("recursion", "gpt-5.5"),
...     ("focus on the recursive step", ""),
... )
>>> af.pullback(lm_ir).call(  
...     ("recursion", "gpt-5.5"),
...     "make the answer more concrete",
... )
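
Because rule bodies are ordinary Python, their data flow can be checked offline before a real model is involved. The sketch below mirrors the pullback rule above in plain Python without Autoform; fake_lm_call is a hypothetical stand-in for af.lm_call that echoes the last line of its prompt, and "stub-model" is a placeholder model name.

```python
def fake_lm_call(prompt: str, model: str) -> str:
    """Hypothetical offline stand-in for a model call: echoes the final prompt line."""
    return f"{model}: {prompt.splitlines()[-1]}"


def summarize(text: str, model: str) -> str:
    return fake_lm_call(f"Summarize this in one sentence: {text}", model)


def summarize_pull(in_tree, /, *, call):
    del call
    (primals, output), cotangent = in_tree
    text, model = primals
    prompt = (
        f"Original input:\n{text}\n\nLM output:\n{output}\n\n"
        f"Downstream feedback:\n{cotangent}\n\n"
        "Return feedback for improving the original input."
    )
    # One cotangent per primal input; the model name gets an empty-string cotangent.
    return fake_lm_call(prompt, model), ""


primals = ("recursion", "stub-model")
output = summarize(*primals)
grads = summarize_pull(((primals, output), "make the answer more concrete"), call=summarize)
print(output)  # stub-model: Summarize this in one sentence: recursion
print(grads)   # ('stub-model: Return feedback for improving the original input.', '')
```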