Batch Helpers
- autoform.extend.batch_index(in_tree, in_batched, b, /)
Extract one batch item from a tree.
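Batched leaves are indexed at b. Non-batched leaves are broadcast by returning the original leaf unchanged.
- Parameters:
in_tree (Tree) – Tree containing batched and non-batched leaves.
in_batched (Tree[bool]) – Tree of booleans with the same outer structure as in_tree. True marks a batched leaf.
b – Index of the batch item to extract.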
- Returns:
A tree with the same structure as in_batched, containing one batch item.
- Return type:
Tree
Example
>>> batch_index(([1, 2], "x"), (True, False), 0)
(1, 'x')
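The indexing rule above can be illustrated with a small standalone sketch. It does not call autoform: batch_index_sketch is a hypothetical name, and the sketch assumes that plain tuples, lists and dicts are the only containers and that in_batched mirrors in_tree down to the batched subtrees. The documented function works on general trees and may differ in details.

```python
from typing import Any


def batch_index_sketch(in_tree: Any, in_batched: Any, b: int) -> Any:
    """Hypothetical re-implementation of the documented batch_index behavior.

    Assumption: tuples, lists and dicts are the only containers, and
    in_batched is a prefix of in_tree whose bool leaves mark batched subtrees.
    """
    if isinstance(in_batched, bool):
        # A batched value is a per-example container: pick item b.
        # A non-batched value is shared by every example: return it unchanged.
        return in_tree[b] if in_batched else in_tree
    if isinstance(in_batched, dict):
        return {k: batch_index_sketch(in_tree[k], m, b) for k, m in in_batched.items()}
    if isinstance(in_batched, (list, tuple)):
        assert len(in_tree) == len(in_batched), "in_batched must mirror in_tree"
        items = [batch_index_sketch(t, m, b) for t, m in zip(in_tree, in_batched)]
        # Rebuild the same container type that in_batched uses.
        return tuple(items) if isinstance(in_batched, tuple) else items
    raise TypeError(f"unsupported in_batched node: {in_batched!r}")


tree, mask = ([1, 2], "x"), (True, False)
print(batch_index_sketch(tree, mask, 0))  # (1, 'x')
print([batch_index_sketch(tree, mask, b) for b in range(2)])  # [(1, 'x'), (2, 'x')]
```

Calling the helper once for each b in range(batch size) yields one input tree per example, with the non-batched "x" repeated in every item. This is one plausible way a batch rule could loop over a batch.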
- autoform.extend.batch_spec(in_tree, in_batched, /)
Return the common output container spec for batched leaves.
Batch rules use this to decide whether an operation has any batched inputs and how per-example results should be repacked. If no input leaf is batched, None is returned.
- Parameters:
in_tree (Tree) – Tree containing batched and non-batched leaves.
in_batched (Tree[bool]) – Tree of booleans with the same outer structure as in_tree. True marks a batched leaf.
- Returns:
The common pytree spec of all batched leaves, or None if there are no batched leaves.
- Raises:
AssertionError – If batched leaves do not share the same container structure.
- Return type:
PyTreeSpec | None
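The batch_spec contract can be sketched the same way. batch_spec_sketch and container_spec are hypothetical names, not autoform API. In particular, container_spec is an assumed stand-in that records only the outer container type and its length or keys, whereas the documented function returns a PyTreeSpec. Mismatched containers raise AssertionError explicitly, so the check also survives python -O.

```python
from collections.abc import Hashable
from typing import Any, Optional


def container_spec(value: Any) -> Hashable:
    """Hypothetical one-level stand-in for a pytree spec (assumption):
    only the outer container type and its length or keys are recorded."""
    if isinstance(value, (list, tuple)):
        return (type(value).__name__, len(value))
    if isinstance(value, dict):
        return ("dict", tuple(sorted(value)))
    raise TypeError(f"batched leaf is not a container: {value!r}")


def _batched_values(in_tree: Any, in_batched: Any) -> list:
    """Collect every subtree of in_tree whose in_batched entry is True."""
    if isinstance(in_batched, bool):
        return [in_tree] if in_batched else []
    if isinstance(in_batched, dict):
        children = [(in_tree[k], in_batched[k]) for k in in_batched]
    else:
        assert len(in_tree) == len(in_batched), "in_batched must mirror in_tree"
        children = list(zip(in_tree, in_batched))
    out: list = []
    for t, m in children:
        out.extend(_batched_values(t, m))
    return out


def batch_spec_sketch(in_tree: Any, in_batched: Any) -> Optional[Hashable]:
    """Hypothetical re-implementation of the documented batch_spec contract."""
    batched = _batched_values(in_tree, in_batched)
    if not batched:
        return None  # no batched inputs: the operation can run once, unbatched
    specs = [container_spec(v) for v in batched]
    if any(s != specs[0] for s in specs):
        raise AssertionError(f"batched leaves disagree on container structure: {specs}")
    return specs[0]


print(batch_spec_sketch(([1, 2], "x"), (True, False)))   # ('list', 2)
print(batch_spec_sketch(([1, 2], "x"), (False, False)))  # None
```

In this sketch a batch rule could branch on the result: None means the operation can run once on the unbatched inputs, while a spec such as ('list', 2) supplies both the batch size and the container type for repacking the per-example results.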