Batch Helpers

autoform.extend.batch_index(in_tree, in_batched, b, /)

Extract one batch item from a tree.

Batched leaves are indexed at b. Non-batched leaves are broadcast by returning the original leaf unchanged.

Parameters:
  • in_tree (Tree) – Tree containing batched and non-batched leaves.

  • in_batched (Tree[bool]) – Tree of booleans with the same outer structure as in_tree. True marks a batched leaf.

  • b (int) – Batch index to extract.

Returns:

A tree with the same structure as in_batched containing one batch item.

Return type:

Tree

Example

>>> batch_index(([1, 2], "x"), (True, False), 0)
(1, 'x')
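
The indexing rule described above can be sketched in plain Python. This is only an illustration for a flat tuple of leaves, not the library's implementation: the name `batch_index_sketch` is hypothetical, and the real function presumably handles arbitrarily nested trees.

```python
def batch_index_sketch(in_tree, in_batched, b):
    """Illustrative stand-in for batch_index on a flat tuple of leaves.

    Assumption: in_tree and in_batched are tuples of equal length.
    """
    return tuple(
        # Batched leaves are indexed at b; non-batched leaves pass through.
        leaf[b] if is_batched else leaf
        for leaf, is_batched in zip(in_tree, in_batched)
    )


# Same inputs as the example above, extracting the second batch item.
print(batch_index_sketch(([1, 2], "x"), (True, False), 1))  # (2, 'x')
```

The non-batched leaf `"x"` appears unchanged in every extracted item, which is the broadcast behaviour the description refers to.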

autoform.extend.batch_spec(in_tree, in_batched, /)

Return the common output container spec for batched leaves.

Batch rules use this to decide whether an operation has any batched inputs and how per-example results should be repacked. If no input leaf is batched, None is returned.

Parameters:
  • in_tree (Tree) – Tree containing batched and non-batched leaves.

  • in_batched (Tree[bool]) – Tree of booleans with the same outer structure as in_tree. True marks a batched leaf.

Returns:

The common pytree spec of all batched leaves, or None if there are no batched leaves.

Raises:

AssertionError – If batched leaves do not share the same container structure.

Return type:

PyTreeSpec | None
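
The contract above (a common spec, `None` when nothing is batched, an `AssertionError` on mismatch) can be sketched as follows. This is a rough illustration under stated assumptions, not the library's code: `batch_spec_sketch` is a hypothetical name, the inputs are flat tuples, and a `(container type, length)` pair stands in for the real `PyTreeSpec`.

```python
def batch_spec_sketch(in_tree, in_batched):
    """Illustrative stand-in for batch_spec on a flat tuple of leaves.

    Assumption: a (type, length) pair approximates a container spec.
    """
    # Collect a stand-in spec for every leaf marked as batched.
    specs = [
        (type(leaf), len(leaf))
        for leaf, is_batched in zip(in_tree, in_batched)
        if is_batched
    ]
    if not specs:
        # No batched inputs: the caller can skip batching entirely.
        return None
    first = specs[0]
    assert all(s == first for s in specs), (
        "batched leaves must share the same container structure"
    )
    return first


print(batch_spec_sketch(([1, 2], "x"), (True, False)) == (list, 2))  # True
print(batch_spec_sketch(("a", "b"), (False, False)))                 # None
```

A batch rule could use the `None` result as a fast path and otherwise use the returned spec to repack per-example results into the same container shape as the inputs.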

autoform.extend.batch_transpose(batch_size, in_batched, in_tree, /)

Parameters:
  • batch_size (int)

  • in_batched (Tree[bool])

  • in_tree (Tree)

Return type:

Tree