batching

Batching support for multi-invoke traces.

This module provides the batching infrastructure that enables multiple invokes to share a single forward pass. Each invoke's input is combined into one batch, and during interleaving each invoke sees only its slice of the batch.

There are two types of invokes:

  • Input invokes: tracer.invoke(input) — provides input data that contributes to the batch. Each input invoke gets a batch_group = [start, size] that specifies its slice of the batch dimension.

  • Empty invokes: tracer.invoke() (no arguments) — operates on the entire batch from all previous input invokes. Empty invokes get batch_group = None, so narrow() returns the full batch and swap() replaces the full batch. This is useful for running different intervention logic on the combined batch, or for breaking up interventions across multiple invokes to avoid execution-order conflicts within a single invoke.
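The partitioning described above can be sketched in plain Python (this is an illustrative model of the bookkeeping, not the nnsight implementation):

```python
# Illustrative sketch: how batch_group values partition the batch
# dimension across input invokes. Each input invoke gets a contiguous
# [start, size] slice; an empty invoke gets None and sees everything.
def assign_batch_groups(invoke_sizes):
    """Given the batch size contributed by each input invoke,
    return the [start, size] slice each invoke sees."""
    groups, start = [], 0
    for size in invoke_sizes:
        groups.append([start, size])
        start += size
    return groups

# Two input invokes contributing 2 and 3 samples:
print(assign_batch_groups([2, 3]))  # [[0, 2], [2, 3]]
# A subsequent empty invoke gets batch_group = None and sees all 5 samples.
```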

To support multiple input invokes, model classes must subclass :class:Batchable and implement :meth:_prepare_input and :meth:_batch. See nnsight.modeling.language.LanguageModel for a reference implementation. Without these methods, you can still use one input invoke and any number of empty invokes.

Batchable

Abstract mixin that defines how a model's inputs are prepared and batched.

Subclasses should override :meth:_prepare_input and :meth:_batch to enable multiple input invokes in a single trace. The base Envoy class inherits from Batchable but does not override these methods, so only a single input invoke is supported by default.

See LanguageModel for a full implementation that handles tokenization, padding, and attention mask construction.
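As a rough sketch of the contract, a Batchable-style class normalizes each invoke's input and merges it into the running batch. The class name and method signatures below are simplified illustrations, not nnsight's exact API (see LanguageModel for the real signatures):

```python
# Hypothetical sketch of the Batchable contract using plain lists in
# place of tokenized tensors. Signatures are illustrative only.
class ListBatchable:
    def _prepare_input(self, samples):
        # Normalize one invoke's input into a canonical form.
        return list(samples)

    def _batch(self, batched, prepared):
        # Merge a newly prepared input into the accumulated batch.
        return batched + prepared

m = ListBatchable()
batch = m._prepare_input(["a", "b"])          # first input invoke
batch = m._batch(batch, m._prepare_input(["c"]))  # second input invoke
print(batch)  # ['a', 'b', 'c']
```

A real implementation such as LanguageModel must also reconcile inputs of different shapes (e.g. padding sequences to a common length) when merging.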

Batcher

Batcher(*args, **kwargs)

Manages input batching and per-invoke slicing for a single trace.

One Batcher is created per trace. As invokes are defined, :meth:batch accumulates their inputs into a single batch. During interleaving, :meth:narrow extracts each invoke's slice and :meth:swap replaces it.

ATTRIBUTE DESCRIPTION
batched_args

Combined positional arguments from all input invokes.

batched_kwargs

Combined keyword arguments from all input invokes.

last_batch_group

The [start, size] for the most recent input invoke.

TYPE: Optional[List[int]]

needs_batching

True once there are 2+ input invokes (narrowing needed).

current_value

The current activation value being narrowed/swapped.

TYPE: Optional[Any]

batched_args instance-attribute

batched_args = None

batched_kwargs instance-attribute

batched_kwargs = None

last_batch_group instance-attribute

last_batch_group: Optional[List[int]] = None

needs_batching instance-attribute

needs_batching = False

current_value instance-attribute

current_value: Optional[Any] = None

total_batch_size property

total_batch_size

Total number of samples across all input invokes.

batch

batch(batchable: Batchable, *args, **kwargs) -> Tuple[Tuple[Any, Any], Optional[List[int]]]

Register an invoke's input and return its batch group.

For input invokes (args/kwargs provided), this calls batchable._prepare_input() and optionally batchable._batch() to merge with previous inputs. For empty invokes (no args), it returns batch_group=None, which tells :meth:narrow to return the full batch.

PARAMETER DESCRIPTION
batchable

The model instance (implements :class:Batchable).

TYPE: Batchable

*args

Positional arguments from the invoke.

DEFAULT: ()

**kwargs

Keyword arguments from the invoke.

DEFAULT: {}

RETURNS DESCRIPTION
Tuple[Tuple[Any, Any], Optional[List[int]]]

A 2-tuple of ((args, kwargs), batch_group), where batch_group is [start_idx, batch_size] for input invokes or None for empty invokes.

narrow

narrow(batch_group: Optional[List[int]]) -> Any

Extract an invoke's slice from the current activation value.

For input invokes, narrows each tensor along dimension 0 using the invoke's batch_group = [start, size]. For empty invokes (batch_group=None), returns the entire batch unmodified.

PARAMETER DESCRIPTION
batch_group

[start_idx, batch_size] for input invokes, or None for empty invokes (returns full batch).

TYPE: Optional[List[int]]

RETURNS DESCRIPTION
Any

The narrowed (or full) activation data.
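The slicing semantics can be modeled with plain lists standing in for tensors (the real Batcher narrows tensors along dimension 0):

```python
# Sketch of narrow()'s semantics. batch_group is [start, size] for an
# input invoke, or None for an empty invoke (full batch).
def narrow(value, batch_group):
    if batch_group is None:
        return value                      # empty invoke: whole batch
    start, size = batch_group
    return value[start:start + size]      # input invoke: its slice

batch = ["p0", "p1", "p2", "p3", "p4"]
print(narrow(batch, [2, 3]))  # ['p2', 'p3', 'p4']
print(narrow(batch, None))    # ['p0', 'p1', 'p2', 'p3', 'p4']
```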

swap

swap(batch_group: Optional[List[int]], swap_value: Any) -> None

Replace an invoke's slice in the current activation value.

For input invokes, splices swap_value into the correct batch slice. For empty invokes (batch_group=None), replaces the entire current_value.

Handles two cases for tensor replacement:

  • If the tensor is a leaf with requires_grad or has a base tensor (view), uses torch.cat to avoid in-place modification issues.

  • Otherwise, uses direct index assignment for efficiency.

PARAMETER DESCRIPTION
batch_group

[start_idx, batch_size] for input invokes, or None for empty invokes (replaces full batch).

TYPE: Optional[List[int]]

swap_value

The new value to splice in.

TYPE: Any
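The splicing semantics can likewise be modeled with plain lists; the torch.cat branch behaves analogously, rebuilding the batch around the new slice rather than mutating it in place:

```python
# Sketch of swap()'s semantics. The slice covered by batch_group is
# replaced by swap_value; an empty invoke replaces the whole value.
def swap(current, batch_group, swap_value):
    if batch_group is None:
        return swap_value                 # empty invoke: replace all
    start, size = batch_group
    # Rebuild the batch around the new slice (the torch.cat case).
    return current[:start] + swap_value + current[start + size:]

batch = ["a", "b", "c", "d"]
print(swap(batch, [1, 2], ["X", "Y"]))  # ['a', 'X', 'Y', 'd']
```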

DiffusionBatcher

DiffusionBatcher(*args, **kwargs)

Bases: Batcher

A specialized batcher for diffusion models that handles multiple images per prompt and guided diffusion.

This class extends the base Batcher to support diffusion model-specific batching scenarios, including multiple images per prompt and guided diffusion with conditional/unconditional guidance.

The DiffusionBatcher handles three main tensor batch size scenarios:

  1. Regular batch size (total_batch_size)

  2. Image batch size (total_batch_size * num_images_per_prompt)

  3. Guided diffusion batch size (total_batch_size * num_images_per_prompt * 2)
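The arithmetic behind these scenarios, and the scaling of a batch group for image-sized tensors, can be worked through for a hypothetical trace with 2 prompts and num_images_per_prompt = 3:

```python
# Worked example of the three batch-size scenarios (values assumed for
# illustration: 2 prompts, 3 images per prompt).
total_batch_size = 2
num_images_per_prompt = 3

regular = total_batch_size                           # 2
image = total_batch_size * num_images_per_prompt     # 6
guided = image * 2                                   # 12 (cond + uncond)

# A regular batch_group [start, size] scales the same way for tensors
# with the image batch size:
start, size = 1, 1                                   # second prompt
image_batch_group = (start * num_images_per_prompt,
                     size * num_images_per_prompt)   # (3, 3)
print(regular, image, guided, image_batch_group)  # 2 6 12 (3, 3)
```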

ATTRIBUTE DESCRIPTION
num_images_per_prompt

Number of images to generate per prompt. Defaults to 1.

TYPE: int

image_batch_groups

Batch groups scaled for multiple images per prompt, where each tuple contains (batch_start, batch_size).

TYPE: Dict[int, Tuple[int, int]]

num_images_per_prompt instance-attribute

num_images_per_prompt: int = get('num_images_per_prompt', 1)

last_image_batch_group instance-attribute

last_image_batch_group: Optional[Tuple[int, int]] = None

image_batch_groups property

image_batch_groups: Dict[int, Tuple[int, int]]

batch

batch(batchable: Batchable, *args, **kwargs) -> Tuple[Tuple[Any, Any], Optional[List[int]]]

Batch inputs for diffusion models, accounting for multiple images per prompt.

This method extends the base batcher functionality to handle diffusion models that can generate multiple images per prompt. It creates image-specific batch groups by scaling the regular batch groups by the number of images per prompt.

PARAMETER DESCRIPTION
batchable

The batchable object that implements the batching interface.

TYPE: Batchable

*args

Variable length argument list to be batched.

DEFAULT: ()

**kwargs

Arbitrary keyword arguments to be batched.

DEFAULT: {}

RETURNS DESCRIPTION
Tuple[Tuple[Any, Any], Optional[List[int]]]

A 2-tuple of ((batched_args, batched_kwargs), batch_group), where batch_group is [start_idx, batch_size] for input invokes or None for empty invokes.