Custom Boundaries
custom marks a traceable Python function as a boundary and registers per-transform rules for that boundary.
- autoform.custom(func, /)
Mark a traceable Python function as a custom Autoform transform boundary.
custom is a decorator for traceable functions that should keep their ordinary call behavior while optionally overriding how Autoform transforms treat them. Without any registered rules, pushforward, pullback, and batch produce the same results as transforming the function body directly.
The returned wrapper supports these rule registration decorators:
- set_pushforward(rule) for a synchronous pushforward rule.
- aset_pushforward(rule) for an asynchronous pushforward rule.
- set_pullback(rule) for a synchronous pullback backward rule.
- aset_pullback(rule) for an asynchronous pullback backward rule.
- set_batch(rule) for a synchronous batch rule.
- aset_batch(rule) for an asynchronous batch rule.
Rules receive one positional in_tree argument. The original behavior is available as the keyword-only call argument, so a rule can use call(*primals) when it wants to reuse the normal primal behavior. The rule signatures are:
- Pushforward: rule((primals, tangents), /, *, call) -> (p_out, t_out).
- Pullback backward: rule(((primals, output), cotangent), /, *, call) -> cotangents.
- Batch: rule((batch_size, axes, values), /, *, call) -> (outputs, output_axes).
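The three in_tree shapes can be seen in isolation with a library-free sketch. Nothing below imports Autoform: plain strings stand in for traced values, the rules are invoked by hand, and the stand-in bracket function and the names push_rule, pull_rule, and batch_rule are hypothetical. It only illustrates how each rule unpacks its single positional argument.

```python
# Library-free sketch: plain strings stand in for Autoform values, and
# `bracket` stands in for the wrapped function's ordinary call behavior.


def bracket(x):
    return f"[{x}]"


def push_rule(in_tree, /, *, call):
    # Pushforward: in_tree is (primals, tangents).
    primals, tangents = in_tree
    (dx,) = tangents
    return call(*primals), f"delta: {dx}"


def pull_rule(in_tree, /, *, call):
    # Pullback backward: in_tree is ((primals, output), cotangent).
    del call  # this rule does not need the primal behavior
    (primals, output), cotangent = in_tree
    (x,) = primals
    return (f"{cotangent} via {output} from {x}",)


def batch_rule(in_tree, /, *, call):
    # Batch: in_tree is (batch_size, axes, values).
    del call
    batch_size, axes, values = in_tree
    (xs,) = values
    assert len(xs) == batch_size
    return [f"<{x}>" for x in xs], True


# Hand-built in_tree values matching the three documented signatures.
push_result = push_rule((("hello",), ("change",)), call=bracket)
pull_result = pull_rule(((("hello",), "[hello]"), "feedback"), call=bracket)
batch_result = batch_rule((2, (True,), (["a", "b"],)), call=bracket)
```

In real use the rules are registered with set_pushforward, set_pullback, and set_batch, and Autoform builds in_tree and supplies call itself.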
Synchronous and asynchronous registrations are independent. Use both set_* and aset_* when both execution modes need custom behavior.
- Parameters:
func (Callable[[...], Any]) – Function to wrap. The function body may use Autoform primitives and any normal Python structure that is valid while tracing.
- Returns:
A callable wrapper that has the same call behavior as func and adds the rule registration methods. The concrete wrapper class is an implementation detail; use only the returned callable and its set_*/aset_* methods.
- Return type:
CustomFunc
Example
Direct calls behave like calls to the original function.
>>> import autoform as af
>>> @af.custom
... def bracket(x):
...     return af.format("[{}]", x)
>>> bracket("hello")
'[hello]'
Without custom rules, transforms behave as if they were applied to the function body.
>>> base = af.trace(lambda x: bracket(x))("seed")
>>> af.pushforward(base).call(("hello",), ("change",))
('[hello]', '[change]')
>>> af.pullback(base).call(("hello",), "feedback")
('[hello]', ('feedback',))
>>> af.batch(base).call(["a", "b"])
['[a]', '[b]']
A pushforward rule can replace only the pushforward behavior.
>>> @bracket.set_pushforward
... def bracket_push(in_tree, /, *, call):
...     primals, tangents = in_tree
...     (dx,) = tangents
...     p_out = call(*primals)
...     t_out = af.format("delta: {}", af.ad.materialize(dx))
...     return p_out, t_out
>>> af.pushforward(base).call(("hello",), ("change",))
('[hello]', 'delta: change')
Pullback and batch rules use the same one-in_tree convention.
>>> @bracket.set_pullback
... def bracket_pull(in_tree, /, *, call):
...     del call
...     (primals, output), cotangent = in_tree
...     (x,) = primals
...     return (af.format("{} via {} from {}", cotangent, output, x),)
>>> af.pullback(base).call(("hello",), "feedback")
('[hello]', ('feedback via [hello] from hello',))
>>> @bracket.set_batch
... def bracket_batch(in_tree, /, *, call):
...     del call
...     batch_size, axes, values = in_tree
...     assert batch_size == 2
...     (xs,) = values
...     (x_axis,) = axes
...     assert x_axis is True
...     return [af.format("<{}>", x) for x in xs], True
>>> af.batch(base).call(["a", "b"])
['<a>', '<b>']
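The note that synchronous and asynchronous registrations are independent can be pictured with a toy model. This is a sketch under stated assumptions, not Autoform's implementation: ToyCustom, its pushforward and apushforward methods, and its default behavior are all hypothetical. It only shows that registering a rule through set_pushforward leaves the asynchronous path on its default.

```python
import asyncio


class ToyCustom:
    """Toy stand-in for a custom boundary; NOT Autoform's implementation."""

    def __init__(self, func):
        self.func = func
        self._push = None   # synchronous pushforward rule, if registered
        self._apush = None  # asynchronous pushforward rule, if registered

    def set_pushforward(self, rule):
        self._push = rule
        return rule

    def aset_pushforward(self, rule):
        self._apush = rule
        return rule

    def _default(self, primals, tangents):
        # Toy default: apply the function body to primals and tangents alike.
        return self.func(*primals), self.func(*tangents)

    def pushforward(self, primals, tangents):
        # Hypothetical sync entry point: only consults the sync slot.
        if self._push is None:
            return self._default(primals, tangents)
        return self._push((primals, tangents), call=self.func)

    async def apushforward(self, primals, tangents):
        # Hypothetical async entry point: only consults the async slot.
        if self._apush is None:
            return self._default(primals, tangents)
        return await self._apush((primals, tangents), call=self.func)


bracket = ToyCustom(lambda x: f"[{x}]")


@bracket.set_pushforward
def bracket_push(in_tree, /, *, call):
    primals, (dx,) = in_tree
    return call(*primals), f"delta: {dx}"


# Only the synchronous rule was registered, so the two modes differ.
sync_out = bracket.pushforward(("hello",), ("change",))
async_out = asyncio.run(bracket.apushforward(("hello",), ("change",)))
```

In this toy, sync_out uses the custom rule while async_out falls back to the default; registering a matching coroutine through aset_pushforward would bring the two modes in line.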
A common use is wrapping an LM call so the forward call remains normal, while pushforward and pullback use prompts written for that application.
>>> from autoform.ad import materialize
>>> @af.custom
... def summarize(text, model):
...     message = af.format("Summarize this in one sentence: {}", text)
...     return af.lm_call([{"role": "user", "content": message}], model=model)
The custom pushforward rule can ask the model how the output should change under an input edit.
>>> @summarize.set_pushforward
... def summarize_push(in_tree, /, *, call):
...     primals, tangents = in_tree
...     text, model = primals
...     text_tangent, _ = tangents
...     p_out = call(*primals)
...     prompt = af.format(
...         "Original input:\n{}\n\nInput edit:\n{}\n\n"
...         "Describe how the summary should change.",
...         text,
...         materialize(text_tangent),
...     )
...     t_out = af.lm_call([{"role": "user", "content": prompt}], model=model)
...     return p_out, t_out
The custom pullback rule can replace the default backward LM prompt with a domain-specific feedback prompt.
>>> @summarize.set_pullback
... def summarize_pull(in_tree, /, *, call):
...     del call
...     (primals, output), cotangent = in_tree
...     text, model = primals
...     prompt = af.format(
...         "Original input:\n{}\n\nLM output:\n{}\n\n"
...         "Downstream feedback:\n{}\n\n"
...         "Return feedback for improving the original input.",
...         text,
...         output,
...         materialize(cotangent),
...     )
...     text_cotangent = af.lm_call([{"role": "user", "content": prompt}], model=model)
...     return text_cotangent, ""
>>> lm_ir = af.trace(lambda text, model: summarize(text, model))("topic", "model")
>>> af.pushforward(lm_ir).call(
...     ("recursion", "gpt-5.5"),
...     ("focus on the recursive step", ""),
... )
>>> af.pullback(lm_ir).call(
...     ("recursion", "gpt-5.5"),
...     "make the answer more concrete",
... )