IR Transforms

These functions take an IR and return another IR.

autoform.batch(ir, /, *, in_axes=True)

Transform an IR to process batched inputs.

Creates a batched version of the IR that processes multiple inputs simultaneously. Use in_axes to specify which inputs are batched (True) vs broadcast (False).

Parameters:
  • ir (IR) – The IR to transform.

  • in_axes (Tree[bool]) – Axis specification tree matching the input structure:
      - True: this input is batched (a collection of values).
      - False: this input is broadcast (the same value for every batch item).

Returns:

A new IR that takes batched inputs and returns batched outputs.

Return type:

IR

Example

>>> import autoform as af
>>> def greet(greeting, name):
...     return af.concat(greeting, name)
>>> ir = af.trace(greet)("Hi", "World")
>>> # Batch over names, broadcast greeting
>>> batched = af.batch(ir, in_axes=(False, True))
>>> batched.call("Hello, ", ["x0", "x1", "x2"])
['Hello, x0', 'Hello, x1', 'Hello, x2']
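To make the in_axes semantics concrete, here is a minimal pure-Python sketch of what batching does to an ordinary function. It is not how autoform transforms an IR: batch_sketch is a hypothetical helper, and it assumes a flat tuple of bools (one per positional input) rather than a general Tree[bool].

```python
from typing import Any, Callable, List, Sequence


def batch_sketch(fn: Callable[..., Any], in_axes: Sequence[bool]) -> Callable[..., List[Any]]:
    """Map `fn` over batched (True) inputs; pass broadcast (False) inputs unchanged.

    Hypothetical illustration: assumes a flat tuple of bools, one per positional input.
    """

    def batched(*args: Any) -> List[Any]:
        if len(args) != len(in_axes):
            raise ValueError("in_axes needs exactly one entry per positional input")
        # Every batched input must have the same batch size.
        sizes = {len(arg) for arg, axis in zip(args, in_axes) if axis}
        if len(sizes) != 1:
            raise ValueError("need at least one batched input, all of equal length")
        (n,) = sizes
        # Call `fn` once per batch item, indexing only the batched inputs.
        return [
            fn(*(arg[i] if axis else arg for arg, axis in zip(args, in_axes)))
            for i in range(n)
        ]

    return batched


def greet(greeting: str, name: str) -> str:
    return greeting + name


batched = batch_sketch(greet, in_axes=(False, True))
print(batched("Hello, ", ["x0", "x1", "x2"]))  # ['Hello, x0', 'Hello, x1', 'Hello, x2']
```

Broadcast inputs are passed through as-is on every call, which is why the single greeting string is reused for all three names.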

autoform.pullback(ir, /)

Transform an IR to compute outputs and input cotangents (reverse-mode AD).

Creates a new IR that computes gradients by backpropagating cotangents (adjoints) from the outputs to the inputs.

Parameters:
  • ir (IR) – The IR to transform.

Returns:

A new IR with signature (inputs, output_cotangents) -> (outputs, input_cotangents).

Return type:

IR

Example

>>> import autoform as af
>>> def program(x, y):
...     return af.concat(x, y)
>>> ir = af.trace(program)("a", "b")
>>> pb_ir = af.pullback(ir)
>>> outputs, cotangents = pb_ir.call(("Hello", " World"), "feedback")
>>> outputs
'Hello World'
>>> cotangents  # Gradient flows back to both inputs
('feedback', 'feedback')
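For a single concat, the example above suggests a simple reverse-mode rule: the output cotangent is handed back unchanged to every input. The sketch below writes that rule as a plain Python function with the (inputs, output_cotangents) -> (outputs, input_cotangents) shape. concat_pullback is hypothetical, and how autoform combines cotangents when one input feeds several equations is not covered here.

```python
from typing import Tuple


def concat_pullback(
    inputs: Tuple[str, ...], out_cotangent: str
) -> Tuple[str, Tuple[str, ...]]:
    """Hypothetical pullback rule for string concatenation."""
    # Forward (primal) pass: compute the output as usual.
    output = "".join(inputs)
    # Backward pass: every input contributed to the one output,
    # so each receives the output's cotangent unchanged.
    input_cotangents = tuple(out_cotangent for _ in inputs)
    return output, input_cotangents


outputs, cotangents = concat_pullback(("Hello", " World"), "feedback")
print(outputs)     # Hello World
print(cotangents)  # ('feedback', 'feedback')
```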

autoform.pushforward(ir, /)

Transform an IR to compute primals and tangents (forward-mode AD).

Creates a new IR that propagates tangents (perturbations) alongside the primal values.

Parameters:
  • ir (IR) – The IR to transform.

Returns:

A new IR with signature (p_in, t_in) -> (p_out, t_out), mapping primal and tangent inputs to primal and tangent outputs.

Return type:

IR

Example

>>> import autoform as af
>>> def program(x, y):
...     return af.concat(x, y)
>>> ir = af.trace(program)("a", "b")
>>> pf_ir = af.pushforward(ir)
>>> p_out, t_out = pf_ir.call(("Hello", " World"), ("dx", "dy"))
>>> p_out
'Hello World'
>>> t_out
'dxdy'
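The t_out value above is consistent with a forward-mode rule in which concat's tangent is the concatenation of its input tangents. The hypothetical sketch below applies that assumed rule, equation by equation, to a small chain of concats to show how primals and tangents travel together. The program encoding and pushforward_sketch are illustrative only and are not autoform's implementation.

```python
from typing import Dict, List, Tuple

# A toy program: each equation is (output_name, input_names) and means
# output = concat(*inputs). This encoding is a hypothetical illustration.
Program = List[Tuple[str, Tuple[str, ...]]]


def pushforward_sketch(
    program: Program,
    primals: Dict[str, str],
    tangents: Dict[str, str],
    out: str,
) -> Tuple[str, str]:
    p = dict(primals)
    t = dict(tangents)
    for name, ins in program:
        # Primal rule: concatenate the input primals.
        p[name] = "".join(p[v] for v in ins)
        # Tangent rule (assumed): concatenate the input tangents.
        t[name] = "".join(t[v] for v in ins)
    return p[out], t[out]


# out = concat(concat(x, y), z)
program = [("xy", ("x", "y")), ("out", ("xy", "z"))]
p_out, t_out = pushforward_sketch(
    program,
    {"x": "Hello", "y": " World", "z": "!"},
    {"x": "dx", "y": "dy", "z": "dz"},
    "out",
)
print(p_out)  # Hello World!
print(t_out)  # dxdydz
```

Because each tangent is computed in the same pass as its primal, a single forward sweep yields both outputs, which is the defining property of forward mode.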

autoform.sched(ir, /, *, cond=None)

Schedule independent operations for parallel execution.

Parameters:
  • ir (IR[Unpack[A], R]) – The IR to schedule.

  • cond (Callable[[IREqn], bool] | None) – Predicate that takes an IR equation and returns True if that equation should be parallelized. If None, all equations are candidates for parallelization.

Returns:

A new IR with independent operations grouped together for parallel execution.

Return type:

IR[Unpack[A], R]

Example

>>> import autoform as af
>>> import asyncio
>>>
>>> def parallel_calls(x):
...     msg1 = [dict(role="user", content=af.format("Q1: {}", x))]
...     msg2 = [dict(role="user", content=af.format("Q2: {}", x))]
...     a = af.lm_call(msg1, model="gpt-5.5")
...     b = af.lm_call(msg2, model="gpt-5.5")
...     return af.concat(a, b)
>>>
>>> ir = af.trace(parallel_calls)("input")
>>> scheduled = af.sched(ir)
>>>
>>> # sync execution (sequential)
>>> result = scheduled.call("hello") 
>>>
>>> # async execution (concurrent via asyncio.gather)
>>> result = asyncio.run(scheduled.acall("hello")) 
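One plausible way to implement this kind of scheduling is to group adjacent equations that do not depend on each other into stages, then run each stage's equations concurrently with asyncio.gather. The sketch below does that on a toy IR. Eqn, schedule, and run_stages are hypothetical names, fake_lm stands in for a real model call, and autoform's actual grouping strategy may differ (for example, it may reorder equations more aggressively). In this sketch, equations rejected by cond get a stage to themselves.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, List, Optional, Tuple


@dataclass
class Eqn:
    out: str                           # variable this equation defines
    ins: Tuple[str, ...]               # variables it reads
    fn: Callable[..., Awaitable[str]]  # async primitive to run


def schedule(eqns: List[Eqn], cond: Optional[Callable[[Eqn], bool]] = None) -> List[List[Eqn]]:
    """Greedily group adjacent, mutually independent equations into stages."""
    stages: List[List[Eqn]] = []
    can_join = False  # may the next equation join the last stage?
    for e in eqns:
        candidate = cond is None or cond(e)
        last_outs = {x.out for x in stages[-1]} if stages else set()
        if candidate and can_join and not (set(e.ins) & last_outs):
            stages[-1].append(e)  # independent of everything already in the stage
        else:
            stages.append([e])    # dependency found, or not a candidate
        can_join = candidate
    return stages


async def run_stages(stages: List[List[Eqn]], env: Dict[str, str]) -> Dict[str, str]:
    for stage in stages:
        # Equations in one stage are independent, so run them concurrently.
        results = await asyncio.gather(*(e.fn(*(env[v] for v in e.ins)) for e in stage))
        for e, r in zip(stage, results):
            env[e.out] = r
    return env


async def fake_lm(prompt: str) -> str:
    await asyncio.sleep(0.01)  # stand-in for network latency
    return f"<answer to {prompt}>"


async def fmt1(x: str) -> str:
    return f"Q1: {x}"


async def fmt2(x: str) -> str:
    return f"Q2: {x}"


async def concat(a: str, b: str) -> str:
    return a + b


eqns = [
    Eqn("m1", ("x",), fmt1),
    Eqn("m2", ("x",), fmt2),
    Eqn("a", ("m1",), fake_lm),
    Eqn("b", ("m2",), fake_lm),
    Eqn("out", ("a", "b"), concat),
]
stages = schedule(eqns)
print([[e.out for e in s] for s in stages])  # [['m1', 'm2'], ['a', 'b'], ['out']]
env = asyncio.run(run_stages(stages, {"x": "hello"}))
print(env["out"])  # <answer to Q1: hello><answer to Q2: hello>
```

The two model calls land in the same stage, so under asyncio their latencies overlap instead of adding up.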

autoform.dce(ir, /, *, out_used=None)

Remove dead code from an IR.

Performs a backward pass over the equations to identify which ones contribute to the outputs; all other equations are removed.

Parameters:
  • ir (IR[Unpack[A], R]) – The IR to optimize.

  • out_used (UsedTree | None) – A pytree of bools matching the IR's output pytree, denoting which outputs are used.

Returns:

A new IR with the dead equations removed.

Return type:

IR[Unpack[A], R]

Example

>>> import autoform as af
>>> def program(x):
...     dead = af.concat(x, " dead")  # unused
...     live = af.concat(x, " live")  # returned
...     return live
>>> ir = af.trace(program)("test")
>>> len(ir.ir_eqns)
2
>>> dced = af.dce(ir)
>>> len(dced.ir_eqns)
1
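The backward pass described above can be sketched as a liveness analysis: walk the equations from last to first, keep an equation only if it defines a variable that is still needed, and mark its inputs as needed. Eqn and dce_sketch below are hypothetical; each equation is assumed to have one output and no side effects, and treating out_used=None as "every output is used" is an assumption.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


@dataclass(frozen=True)
class Eqn:
    out: str              # variable defined by this equation
    ins: Tuple[str, ...]  # variables it reads


def dce_sketch(
    eqns: List[Eqn],
    outputs: Sequence[str],
    out_used: Optional[Sequence[bool]] = None,
) -> List[Eqn]:
    if out_used is None:
        out_used = [True] * len(outputs)  # assumption: default keeps every output
    # Variables whose values are still needed, seeded from the used outputs.
    live = {name for name, used in zip(outputs, out_used) if used}
    kept: List[Eqn] = []
    for e in reversed(eqns):  # backward pass: consumers before producers
        if e.out in live:
            kept.append(e)
            live.update(e.ins)  # its inputs are now needed too
    kept.reverse()  # restore the original program order
    return kept


# Mirrors the example above: `dead` is never read by the returned value.
eqns = [Eqn("dead", ("x",)), Eqn("live", ("x",))]
print(len(eqns), len(dce_sketch(eqns, ["live"])))  # 2 1

# With two outputs, out_used drops whatever only an unused output needs.
eqns2 = [Eqn("a", ("x",)), Eqn("b", ("x",)), Eqn("c", ("a", "b"))]
print([e.out for e in dce_sketch(eqns2, ["a", "c"], out_used=[True, False])])  # ['a']
```

A single reverse sweep suffices because, in a traced IR, every variable is defined before it is read.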

autoform.weighted(ir, /)

Transform an IR to return (output, path_weight) for one path.

Parameters:
  • ir (IR) – The IR to transform.

Return type:

IR