batching
¶
Batching support for multi-invoke traces.
This module provides the batching infrastructure that enables multiple invokes to share a single forward pass. Each invoke's input is combined into one batch, and during interleaving each invoke sees only its slice of the batch.
There are two types of invokes:
- Input invokes: tracer.invoke(input) — provides input data that contributes to the batch. Each input invoke gets a batch_group = [start, size] that specifies its slice of the batch dimension.
- Empty invokes: tracer.invoke() (no arguments) — operates on the entire batch from all previous input invokes. Empty invokes get batch_group = None, so narrow() returns the full batch and swap() replaces the full batch. This is useful for running different intervention logic on the combined batch, or for breaking up interventions across multiple invokes to avoid execution-order conflicts within a single invoke.
To support multiple input invokes, model classes must subclass :class:Batchable
and implement :meth:_prepare_input and :meth:_batch. See
nnsight.modeling.language.LanguageModel for a reference implementation.
Without these methods, you can still use one input invoke and any number of
empty invokes.
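The slice bookkeeping above can be sketched in a few lines. This is an illustrative stand-in, not nnsight's actual internals; `assign_batch_groups` is a hypothetical helper name:

```python
# Hypothetical sketch: how [start, size] batch groups accumulate as
# input invokes are added to one combined batch.

def assign_batch_groups(invoke_sizes):
    """Given each input invoke's batch size, return its [start, size] group."""
    groups, start = [], 0
    for size in invoke_sizes:
        groups.append([start, size])
        start += size
    return groups

# Three input invokes with batch sizes 2, 1, and 3 share one forward
# pass over a combined batch of 6 rows:
groups = assign_batch_groups([2, 1, 3])
# groups == [[0, 2], [2, 1], [3, 3]]
```

Each invoke then sees only `batch[start:start + size]` during interleaving, while an empty invoke (group `None`) sees all 6 rows.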
Batchable
¶
Abstract mixin that defines how a model's inputs are prepared and batched.
Subclasses should override :meth:_prepare_input and :meth:_batch to
enable multiple input invokes in a single trace. The base Envoy class
inherits from Batchable but does not override these methods, so only
a single input invoke is supported by default.
See LanguageModel for a full implementation that handles tokenization,
padding, and attention mask construction.
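A minimal sketch of what a Batchable subclass might look like. The `Batchable` stub, the `ListModel` toy, and the exact `_prepare_input`/`_batch` signatures here are illustrative assumptions, not the real nnsight classes (see LanguageModel for the actual implementation):

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

class Batchable(ABC):
    """Stub of the abstract mixin, for illustration only."""

    @abstractmethod
    def _prepare_input(self, *args, **kwargs) -> Tuple[Tuple[Any, ...], Dict]:
        """Normalize one invoke's raw input into (args, kwargs) form."""

    @abstractmethod
    def _batch(self, batched, prepared):
        """Merge a newly prepared input into the accumulated batch."""

class ListModel(Batchable):
    """Toy model whose 'batch' is just a list of rows."""

    def _prepare_input(self, rows):
        return ((list(rows),), {})

    def _batch(self, batched, prepared):
        # Concatenate the new invoke's rows onto the accumulated batch.
        (existing,), _ = batched
        (new,), _ = prepared
        return ((existing + new,), {})
```

A real implementation (e.g. for language models) would pad sequences and build attention masks in `_batch` rather than simply concatenating.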
Batcher
¶
Manages input batching and per-invoke slicing for a single trace.
One Batcher is created per trace. As invokes are defined, :meth:batch
accumulates their inputs into a single batch. During interleaving,
:meth:narrow extracts each invoke's slice and :meth:swap replaces it.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `batched_args` | Combined positional arguments from all input invokes. |
| `batched_kwargs` | Combined keyword arguments from all input invokes. |
| `last_batch_group` | The batch group of the most recent input invoke. TYPE: `Optional[List[int]]` |
| `needs_batching` | True once there are 2+ input invokes (narrowing needed). |
| `current_value` | The current activation value being narrowed/swapped. TYPE: `Any` |
batch
¶
batch(batchable: Batchable, *args, **kwargs) -> Tuple[Tuple[Any, Any], Optional[List[int]]]
Register an invoke's input and return its batch group.
For input invokes (args/kwargs provided), this calls
batchable._prepare_input() and optionally batchable._batch()
to merge with previous inputs. For empty invokes (no args), returns
batch_group=None which tells :meth:narrow to return the full batch.
| PARAMETER | DESCRIPTION |
|---|---|
| `batchable` | The model instance (implements :class:`Batchable`). TYPE: `Batchable` |
| `*args` | Positional arguments from the invoke. DEFAULT: `()` |
| `**kwargs` | Keyword arguments from the invoke. DEFAULT: `{}` |

| RETURNS | DESCRIPTION |
|---|---|
| `Tuple[Tuple[Any, Any], Optional[List[int]]]` | A 2-tuple of `(batched_args, batched_kwargs)`, paired with the invoke's batch group `[start, size]`, or `None` for empty invokes. |
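The input-vs-empty branch can be illustrated with a toy batcher that accumulates plain lists instead of prepared model inputs. `ToyBatcher` is a hypothetical stand-in, not the real Batcher class:

```python
class ToyBatcher:
    """Toy stand-in for Batcher: batches plain lists instead of tensors."""

    def __init__(self):
        self.batched = []  # combined batch from all input invokes so far

    def batch(self, rows=None):
        if rows is None:
            # Empty invoke: batch_group=None, so narrow() will
            # return the full batch.
            return self.batched, None
        # Input invoke: this invoke owns the slice [start, len(rows)].
        start = len(self.batched)
        self.batched.extend(rows)
        return self.batched, [start, len(rows)]

b = ToyBatcher()
_, g1 = b.batch([10, 11])  # g1 == [0, 2]
_, g2 = b.batch([12])      # g2 == [2, 1]
_, g3 = b.batch()          # g3 is None: sees the whole batch [10, 11, 12]
```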
narrow
¶
Extract an invoke's slice from the current activation value.
For input invokes, narrows each tensor along dimension 0 using the
invoke's batch_group = [start, size]. For empty invokes
(batch_group=None), returns the entire batch unmodified.
| PARAMETER | DESCRIPTION |
|---|---|
| `batch_group` | The invoke's `[start, size]` slice, or `None` for the full batch. TYPE: `Optional[List[int]]` |

| RETURNS | DESCRIPTION |
|---|---|
| `Any` | The narrowed (or full) activation data. |
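The real method narrows tensors along dimension 0; the indexing it performs can be sketched on plain lists (an illustrative simplification, not the actual implementation):

```python
def narrow(current_value, batch_group):
    """Toy narrow(): slice a sequence along the batch dimension."""
    if batch_group is None:
        # Empty invoke: return the entire batch unmodified.
        return current_value
    start, size = batch_group
    return current_value[start:start + size]

full = ["a", "b", "c", "d"]
narrow(full, [1, 2])  # ["b", "c"]
narrow(full, None)    # ["a", "b", "c", "d"]
```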
swap
¶
Replace an invoke's slice in the current activation value.
For input invokes, splices swap_value into the correct batch
slice. For empty invokes (batch_group=None), replaces the
entire current_value.
Handles two cases for tensor replacement:
- If the tensor is a leaf with requires_grad or has a base tensor
(view), uses torch.cat to avoid in-place modification issues.
- Otherwise, uses direct index assignment for efficiency.
| PARAMETER | DESCRIPTION |
|---|---|
| `batch_group` | The invoke's `[start, size]` slice, or `None` for the full batch. TYPE: `Optional[List[int]]` |
| `swap_value` | The new value to splice in. TYPE: `Any` |
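The splice can likewise be sketched on plain lists. This toy version always rebuilds the sequence, mirroring only the `torch.cat` path described above (the real method also has a faster in-place index-assignment path for ordinary tensors):

```python
def swap(current_value, batch_group, swap_value):
    """Toy swap(): splice swap_value into an invoke's batch slice."""
    if batch_group is None:
        # Empty invoke: replace the entire current value.
        return swap_value
    start, size = batch_group
    # Build a new sequence rather than mutating in place, analogous to
    # the torch.cat path used for autograd leaves and tensor views.
    return current_value[:start] + swap_value + current_value[start + size:]

swap([0, 1, 2, 3], [1, 2], [9, 9])  # [0, 9, 9, 3]
swap([0, 1, 2, 3], None, [7])       # [7]
```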
DiffusionBatcher
¶
Bases: Batcher
A specialized batcher for diffusion models that handles multiple images per prompt and guided diffusion.
This class extends the base Batcher to support diffusion-specific scenarios: generating multiple images per prompt, and guided diffusion with conditional/unconditional guidance.
The DiffusionBatcher handles three main tensor batch size scenarios:
1. Regular batch size (total_batch_size)
2. Image batch size (total_batch_size * num_images_per_prompt)
3. Guided diffusion batch size (total_batch_size * num_images_per_prompt * 2)
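The three scenarios differ only in the size of a tensor's batch dimension, so they can be told apart by simple arithmetic. `classify_batch_dim` is a hypothetical illustration of that arithmetic, not a method of DiffusionBatcher:

```python
# Hypothetical sketch of the three batch-size scenarios, assuming
# total_batch_size prompts, num_images_per_prompt images per prompt,
# and guidance doubling the batch into conditional + unconditional halves.

def classify_batch_dim(dim0, total_batch_size, num_images_per_prompt):
    image_batch_size = total_batch_size * num_images_per_prompt
    if dim0 == total_batch_size:
        return "regular"            # scenario 1
    if dim0 == image_batch_size:
        return "image"              # scenario 2
    if dim0 == image_batch_size * 2:
        return "guided"             # scenario 3
    raise ValueError(f"unexpected batch dim {dim0}")

# 2 prompts, 3 images per prompt:
classify_batch_dim(2, 2, 3)   # "regular"
classify_batch_dim(6, 2, 3)   # "image"
classify_batch_dim(12, 2, 3)  # "guided"
```

A batcher can use this kind of check to decide how much to scale each invoke's batch group before narrowing a given tensor.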
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `num_images_per_prompt` | Number of images to generate per prompt. Defaults to 1. TYPE: `int` |
| `image_batch_groups` | Batch groups scaled for multiple images per prompt, where each tuple contains `(batch_start, batch_size)`. TYPE: `List[Tuple[int, int]]` |
num_images_per_prompt
instance-attribute
¶
last_image_batch_group
instance-attribute
¶
batch
¶
batch(batchable: Batchable, *args, **kwargs) -> Tuple[Tuple[Any, Any], Optional[List[int]]]
Batch inputs for diffusion models, accounting for multiple images per prompt.
This method extends the base batcher functionality to handle diffusion models that can generate multiple images per prompt. It creates image-specific batch groups by scaling the regular batch groups by the number of images per prompt.
| PARAMETER | DESCRIPTION |
|---|---|
| `batchable` | The batchable object that implements the batching interface. TYPE: `Batchable` |
| `*args` | Variable length argument list to be batched. DEFAULT: `()` |
| `**kwargs` | Arbitrary keyword arguments to be batched. DEFAULT: `{}` |

| RETURNS | DESCRIPTION |
|---|---|
| `Tuple[Tuple[Any, Any], Optional[List[int]]]` | A tuple containing a tuple of `(batched_args, batched_kwargs)` and the batch group, or `None` if no batching was needed. |