Primitives

A primitive is a named operation that the IR records instead of executing inline during tracing. Examples include format, concat, lm_call, switch, checkpoint, and factor.

The name matters because transforms dispatch on primitive identity. For example, pullback can route feedback through lm_call only because a pullback rule is registered for the lm_call primitive. Plain Python operations have no such rules, so they either run at trace time or fail when they need a concrete runtime value.
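The sketch below illustrates what "dispatch on primitive identity" means. It assumes a hypothetical Primitive class and a plain dict of pullback rules keyed by the primitive object. It is not this library's API, and the lm_call rule here is a toy that just hands feedback back to the prompt.

```python
from typing import Any, Callable, Dict


class Primitive:
    """Hypothetical primitive: a named token. Equality and hashing are by identity."""

    def __init__(self, name: str) -> None:
        self.name = name


lm_call_p = Primitive("lm_call")
len_p = Primitive("len")  # hypothetical primitive with no pullback rule

# A transform keeps a rule table keyed by the primitive object itself.
pullback_rules: Dict[Primitive, Callable[..., Any]] = {}
# Toy rule (assumption): pass the output feedback back to the prompt unchanged.
pullback_rules[lm_call_p] = lambda feedback, prompt: (feedback,)


def lookup_pullback(prim: Primitive) -> Callable[..., Any]:
    # Dispatch on identity: with no registered rule, the transform cannot proceed.
    rule = pullback_rules.get(prim)
    if rule is None:
        raise NotImplementedError(f"no pullback rule for {prim.name!r}")
    return rule
```

Because lookup is by identity, a second Primitive("lm_call") object would not find the registered rule; only the exact primitive that was registered does.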

Rule Registries

Every primitive can have rules for different phases and transforms:

  • impl_rules: synchronous execution.

  • abstract_rules: output-shape and output-type inference while tracing.

  • batch_rules: vectorized behavior for batch.

  • push_rules: forward-mode behavior for pushforward.

  • pull_fwd_rules: the forward sweep for pullback.

  • pull_bwd_rules: the backward sweep for pullback.
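A minimal sketch of how the first two tables might be used, assuming one dict per phase keyed by primitive and a hypothetical bind function: with concrete inputs the impl rule runs immediately, while a traced input triggers the abstract rule and records one equation instead. Tracer, Trace, and bind are illustrative names, not the library's.

```python
from dataclasses import dataclass, field


class Primitive:
    """Hypothetical primitive: a named token compared by identity."""

    def __init__(self, name: str) -> None:
        self.name = name


# One table per phase or transform, keyed by primitive (hypothetical layout).
impl_rules: dict = {}
abstract_rules: dict = {}
batch_rules: dict = {}
push_rules: dict = {}
pull_fwd_rules: dict = {}
pull_bwd_rules: dict = {}


@dataclass
class Tracer:
    """Placeholder for a value that is unknown at trace time."""
    aval: type  # abstract value: here just the Python type of the eventual value


@dataclass
class Trace:
    eqns: list = field(default_factory=list)


concat_p = Primitive("concat")
impl_rules[concat_p] = lambda *parts: "".join(parts)
abstract_rules[concat_p] = lambda *avals: str  # concatenating strings yields a string


def bind(trace: Trace, prim: Primitive, *args):
    if any(isinstance(a, Tracer) for a in args):
        # Tracing: infer the output type and record one equation instead of running.
        avals = [a.aval if isinstance(a, Tracer) else type(a) for a in args]
        out = Tracer(abstract_rules[prim](*avals))
        trace.eqns.append((prim.name, args, out))
        return out
    # All inputs concrete: execute synchronously.
    return impl_rules[prim](*args)
```

The remaining tables would be consulted the same way by batch, pushforward, and pullback, each looking up the rule for the primitive named in an equation.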

Pullback needs two rules because its two sweeps run at different times: the forward sweep records the values (residuals) that will be needed later, and the backward sweep combines those residuals with the incoming cotangent to produce cotangents for the inputs.
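To make the forward/backward split concrete, here is a sketch using ordinary numeric multiplication rather than any primitive from this library. The function names are hypothetical; the point is that the forward rule saves residuals and the backward rule needs nothing except those residuals and the cotangent.

```python
def mul_fwd(x: float, y: float):
    # Forward sweep: compute the output and save the residuals the backward sweep needs.
    out = x * y
    residuals = (x, y)
    return out, residuals


def mul_bwd(residuals, ct_out: float):
    # Backward sweep: combine saved residuals with the output cotangent.
    x, y = residuals
    return ct_out * y, ct_out * x  # d(xy)/dx = y, d(xy)/dy = x


def pullback_mul(x: float, y: float):
    # Run the forward sweep now; defer the backward sweep until a cotangent arrives.
    out, res = mul_fwd(x, y)
    return out, lambda ct: mul_bwd(res, ct)
```

For example, pullback_mul(3.0, 4.0) returns 12.0 and a function that maps a cotangent of 1.0 to (4.0, 3.0).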

Public Primitive Groups

String

  • format: traceable string formatting.

  • concat: traceable string concatenation.

  • match: traceable string equality.
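The following sketch shows why these string operations are primitives rather than plain Python. It assumes a hypothetical Recorder that emits one equation per format, concat, or match call, and a toy interpreter that later runs the recorded equations. A plain f-string, by contrast, runs at trace time and bakes the placeholder's representation into the text.

```python
from dataclasses import dataclass, field


@dataclass
class Var:
    name: str  # placeholder for a value known only at run time


@dataclass
class Recorder:
    eqns: list = field(default_factory=list)
    counter: int = 0

    def fresh(self) -> Var:
        self.counter += 1
        return Var(f"v{self.counter}")

    def emit(self, prim: str, *args) -> Var:
        # Record the operation instead of executing it.
        out = self.fresh()
        self.eqns.append((out.name, prim, args))
        return out


# Toy implementations used only when the recorded equations are run.
IMPLS = {
    "format": lambda template, *xs: template.format(*xs),
    "concat": lambda *parts: "".join(parts),
    "match": lambda a, b: a == b,
}


def run(eqns, env):
    read = lambda a: env[a.name] if isinstance(a, Var) else a
    for out, prim, args in eqns:
        env[out] = IMPLS[prim](*map(read, args))
    return env


rec = Recorder()
question = rec.fresh()                          # traced input, v1
prompt = rec.emit("format", "Q: {}", question)  # v2, recorded
full = rec.emit("concat", prompt, "\nA:")       # v3, recorded
same = rec.emit("match", full, "Q: hi\nA:")     # v4, recorded

# A plain f-string runs immediately and embeds the placeholder's repr.
baked = f"Q: {question}"
```

Running run(rec.eqns, {"v1": "hi"}) evaluates the three recorded equations once a concrete question is available, whereas baked already contains "Var(name='v1')".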

LM

Control Flow

  • switch: choose one traced branch at execution time.

  • while_loop: run a traced loop with an explicit iteration cap.

  • stop_gradient: pass its input x forward unchanged, but block cotangents from flowing back through it in pullback.

  • depends: make a returned result wait for extra dependencies without changing its value.
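The eager semantics of these operations can be sketched in plain Python. This is an illustration under assumptions, not the library's implementation: argument order is hypothetical, the while_loop sketch simply stops when it hits the cap (the real primitive might instead error or report), and stop_gradient uses None as a hypothetical "no cotangent" marker.

```python
from typing import Any, Callable, Sequence


def switch(index: int, branches: Sequence[Callable[..., Any]], *operands: Any) -> Any:
    # Only the branch selected by the runtime index is executed.
    return branches[index](*operands)


def while_loop(cond: Callable[[Any], bool], body: Callable[[Any], Any],
               init: Any, max_iters: int) -> Any:
    # Run body while cond holds, but never more than max_iters times.
    state = init
    for _ in range(max_iters):
        if not cond(state):
            break
        state = body(state)
    return state


def stop_gradient_fwd(x: Any):
    # Forward: the value passes through unchanged; no residuals are needed.
    return x, None


def stop_gradient_bwd(_residuals: Any, _ct: Any):
    # Backward: discard the incoming cotangent (None = "no cotangent", an assumption).
    return (None,)


def depends(value: Any, *deps: Any) -> Any:
    # Eagerly, deps have already been computed; the value is returned unchanged.
    # In a traced program the point is ordering: value waits for deps.
    return value
```

The explicit cap on while_loop keeps the traced loop bounded even when the condition depends on runtime values such as model output.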

Intercepts

Trace Weight

  • factor: multiply the current path weight. Ordinary execution treats it as a no-output effect; weighted returns the accumulated path weight.

Primitive Definitions

Defining a primitive means specifying its behavior under execution, tracing, batching, pushforward, pullback, and sometimes DCE (dead-code elimination). Most user code should not need to do this.

Follow the Write a Primitive guide when an operation cannot run on traced values but must still appear as a single IR equation.
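As a rough picture of what a definition has to cover, the sketch bundles one rule per phase into a hypothetical PrimitiveDef record and reports which optional rules are still missing. The strip primitive, the field names, and the pass-through pullback are all illustrative assumptions; the Write a Primitive guide describes the library's actual mechanism.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class PrimitiveDef:
    name: str
    impl: Callable                        # synchronous execution
    abstract: Callable                    # output-type inference while tracing
    batch: Optional[Callable] = None      # vectorized behavior
    push: Optional[Callable] = None       # pushforward
    pull_fwd: Optional[Callable] = None   # pullback forward sweep
    pull_bwd: Optional[Callable] = None   # pullback backward sweep
    dce: Optional[Callable] = None        # only some primitives need a DCE rule

    def missing(self) -> List[str]:
        # Transform rules not yet supplied (DCE is optional, so it is not listed).
        return [n for n in ("batch", "push", "pull_fwd", "pull_bwd")
                if getattr(self, n) is None]


# Hypothetical primitive wrapping an operation that needs concrete text.
strip_p = PrimitiveDef(
    name="strip",
    impl=lambda s: s.strip(),
    abstract=lambda aval: aval,                 # output has the input's type
    batch=lambda xs: [x.strip() for x in xs],   # apply elementwise over a batch
    pull_fwd=lambda s: (s.strip(), None),       # no residuals saved
    pull_bwd=lambda _res, ct: (ct,),            # pass feedback straight through
)
```

Here strip_p.missing() returns ["push"], signalling that pushforward through this primitive would not yet be supported.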