tracer

Cache

Cache(modules: Optional[List[Union[Envoy, str]]] = None, device: Optional[device] = device('cpu'), dtype: Optional[dtype] = None, detach: Optional[bool] = True, include_output: bool = True, include_inputs: bool = False, rename: Optional[Dict[str, str]] = None, alias: Optional[Dict[str, str]] = None)

A cache for storing module activations during tracing.

Persistent cache hooks (registered via hooks.cache_output_hook and hooks.cache_input_hook) fire on every forward pass and call add to record values. Hooks are registered by InterleavingTracer.cache when the user creates a cache, and are removed automatically when the interleaver exits.

The cache applies optional transformations (detach, device, dtype) to values before storing them. The hook handles are not stored on the cache itself: they live on the owning Mediator (in mediator.hooks) and are drained by Mediator.remove_hooks.
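The ownership split described above (hooks call the cache, but their handles belong to the mediator) can be illustrated with a toy model. This is a minimal sketch, not nnsight's implementation: ToyModule, ToyCache, ToyMediator and HookHandle are hypothetical stand-ins, and the "module" just adds 1 to its input.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


class HookHandle:
    """Hypothetical handle: calling remove() unregisters the hook."""

    def __init__(self, hooks: List[Callable], hook: Callable) -> None:
        self._hooks = hooks
        self._hook = hook

    def remove(self) -> None:
        if self._hook in self._hooks:
            self._hooks.remove(self._hook)


class ToyModule:
    def __init__(self, path: str) -> None:
        self.path = path
        self.forward_hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> HookHandle:
        self.forward_hooks.append(hook)
        return HookHandle(self.forward_hooks, hook)

    def __call__(self, x: Any) -> Any:
        out = x + 1  # stand-in for the real forward computation
        for hook in list(self.forward_hooks):
            hook(self, x, out)  # every registered hook fires on every call
        return out


class ToyCache:
    def __init__(self) -> None:
        self.store: Dict[str, Any] = {}

    def add(self, module_path: str, key: str, value: Any) -> None:
        # This toy keeps only the latest value per module/key.
        self.store[f"{module_path}.{key}"] = value


@dataclass
class ToyMediator:
    # The handles live here, not on the cache.
    hooks: List[HookHandle] = field(default_factory=list)

    def remove_hooks(self) -> None:
        while self.hooks:
            self.hooks.pop().remove()


module = ToyModule("model.layer")
cache = ToyCache()
mediator = ToyMediator()


def cache_output_hook(mod: ToyModule, inputs: Any, output: Any) -> None:
    cache.add(mod.path, "output", output)


mediator.hooks.append(module.register_forward_hook(cache_output_hook))
module(1)                # hook fires and records 2
module(10)               # persistent hook fires again and records 11
mediator.remove_hooks()  # drained through the mediator, not the cache
module(100)              # no hooks left, nothing is recorded
print(cache.store)       # {'model.layer.output': 11}
```

Keeping handles on the mediator means cleanup happens in one place when the run ends, while the cache stays a plain data container.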

PARAMETER DESCRIPTION
modules

Optional list of modules (Envoy objects or path strings) to cache. If None, all modules are cached.

TYPE: Optional[List[Union[Envoy, str]]] DEFAULT: None

device

Device to move cached tensors to (default: CPU).

TYPE: Optional[device] DEFAULT: device('cpu')

dtype

Optional dtype to convert cached tensors to.

TYPE: Optional[dtype] DEFAULT: None

detach

Whether to detach tensors from the computation graph.

TYPE: Optional[bool] DEFAULT: True

include_output

Whether to cache module outputs.

TYPE: bool DEFAULT: True

include_inputs

Whether to cache module inputs.

TYPE: bool DEFAULT: False

rename

Rename mapping for alias path resolution.

TYPE: Optional[Dict[str, str]] DEFAULT: None

alias

Alias mapping for CacheDict attribute access.

TYPE: Optional[Dict[str, str]] DEFAULT: None

device instance-attribute

device = device

dtype instance-attribute

dtype = dtype

detach instance-attribute

detach = detach

modules instance-attribute

modules = modules

include_output instance-attribute

include_output = include_output

include_inputs instance-attribute

include_inputs = include_inputs

cache instance-attribute

cache = save()

Entry dataclass

Entry(output: Optional[Any] = None, inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None)
output class-attribute instance-attribute
output: Optional[Any] = None
inputs class-attribute instance-attribute
inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None
input property
input

Returns the first positional argument from the module's cached inputs, or None if no inputs were cached.
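The Entry layout and its input property can be sketched as a small dataclass. This mirrors the documented fields; returning None when inputs exist but contain no positional arguments is an assumption of this sketch, not documented behavior.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class Entry:
    output: Optional[Any] = None
    # (positional args, keyword args) that were passed to the module
    inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None

    @property
    def input(self) -> Optional[Any]:
        if self.inputs is None:
            return None  # no inputs were cached
        args, _kwargs = self.inputs
        # Assumption: no positional arguments also yields None here.
        return args[0] if args else None


e = Entry(output="y", inputs=(("hidden", "mask"), {"use_cache": False}))
print(e.input)        # hidden
print(Entry().input)  # None
```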

CacheDict

CacheDict(data: Union[CacheDict, Dict[str, Entry]], path: str = '', alias: Dict[str, str] = dict(), rename: Dict[str, str] = dict(), alias_paths: Dict[str, str] = dict())

Bases: Dict

A dictionary subclass that provides convenient access to cached module activations.

This class extends the standard dictionary to provide both dictionary-style access and attribute-style access to cached activations. It supports hierarchical access to nested modules using dot notation and indexing for module lists.

Examples:

Access cached activations using dictionary keys:

>>> cache['model.transformer.h.0.attn']

Access using attribute notation:

>>> cache.model.transformer.h[0].attn

Access module outputs and inputs:

>>> cache.model.transformer.h[0].output
>>> cache.model.transformer.h[0].inputs
>>> cache.model.transformer.h[0].input  # First input argument

The class maintains an internal path that tracks the current location in the module hierarchy, allowing for intuitive navigation through nested modules.
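The path-tracking idea can be sketched with a small dict subclass in which attribute access and integer indexing extend a dotted path, and output reads the entry stored at that path. MiniCacheDict is a hypothetical, simplified stand-in; the real CacheDict also handles inputs, input, aliases and renames.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class Entry:
    output: Optional[Any] = None


class MiniCacheDict(dict):
    """Hypothetical simplification of CacheDict's path navigation."""

    def __init__(self, data: Dict[str, Entry], path: str = "") -> None:
        super().__init__(data)
        self._path = path

    def _child(self, part: str) -> "MiniCacheDict":
        new_path = f"{self._path}.{part}" if self._path else part
        return MiniCacheDict(self, new_path)

    def __getattr__(self, attr: str) -> "MiniCacheDict":
        # Only called when normal lookup fails, e.g. for module names.
        if attr.startswith("_"):
            raise AttributeError(attr)
        return self._child(attr)

    def __getitem__(self, key: Any) -> Any:
        if isinstance(key, int):  # h[0] -> "...h.0"
            return self._child(str(key))
        return dict.__getitem__(self, key)  # plain dictionary-style access

    @property
    def output(self) -> Any:
        return dict.__getitem__(self, self._path).output


cache = MiniCacheDict({
    "model.transformer.h.0.attn": Entry(output="attn-0"),
    "model.transformer.h.0": Entry(output="block-0"),
})
print(cache.model.transformer.h[0].attn.output)  # attn-0
print(cache["model.transformer.h.0"].output)     # block-0
```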

output property
output

Returns the output attribute from the Cache.Entry at the current path.

inputs property
inputs

Returns the inputs attribute from the Cache.Entry at the current path.

input property
input

Returns the input property from the Cache.Entry at the current path.

__iter__
__iter__()
__len__
__len__()
__contains__
__contains__(key)
keys
keys(alias: bool = False)
values
values()
items
items()
__repr__
__repr__()
__getitem__
__getitem__(key)
__getattr__
__getattr__(attr: str)
__getstate__
__getstate__()
__setstate__
__setstate__(state)

add

add(module_path: str, key: str, value: Any)

Add a value to the cache with optional transformations.

Called by the persistent cache hooks registered via hooks.cache_output_hook and hooks.cache_input_hook.

PARAMETER DESCRIPTION
module_path

The module's envoy path (e.g. "model.transformer.h.0").

TYPE: str

key

Which field of the entry to set: "output" or "inputs".

TYPE: str

value

The tensor value to store.

TYPE: Any
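A sketch of what add does conceptually: look up (or create) the entry for module_path, transform the value, and store it under key. This is hypothetical: plain numbers stand in for tensors, float() stands in for detach()/.to(device, dtype), and walking nested tuples, lists and dicts (so that (args, kwargs) inputs are covered) is an assumption about how the real method treats nested values.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple


@dataclass
class Entry:
    output: Optional[Any] = None
    inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None


class SketchCache:
    """Hypothetical cache in which numbers play the role of tensors."""

    def __init__(self, transform: Callable[[Any], Any]) -> None:
        self.transform = transform  # stands in for detach / device / dtype
        self.cache: Dict[str, Entry] = {}

    def _apply(self, value: Any) -> Any:
        # Recurse into containers; transform only the "tensor" leaves.
        if isinstance(value, tuple):
            return tuple(self._apply(v) for v in value)
        if isinstance(value, list):
            return [self._apply(v) for v in value]
        if isinstance(value, dict):
            return {k: self._apply(v) for k, v in value.items()}
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return self.transform(value)
        return value  # leave other leaves untouched

    def add(self, module_path: str, key: str, value: Any) -> None:
        if key not in ("output", "inputs"):
            raise ValueError(f"key must be 'output' or 'inputs', got {key!r}")
        entry = self.cache.setdefault(module_path, Entry())
        setattr(entry, key, self._apply(value))


cache = SketchCache(transform=float)  # float() mimics a dtype cast
cache.add("model.transformer.h.0", "inputs", ((3, "mask"), {"use_cache": True}))
cache.add("model.transformer.h.0", "output", 7)
entry = cache.cache["model.transformer.h.0"]
print(entry.output)  # 7.0
print(entry.inputs)  # ((3.0, 'mask'), {'use_cache': True})
```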

InterleavingTracer

InterleavingTracer(fn: Callable, model: Envoy, *args, backend: Backend = None, **kwargs)

Bases: Tracer

Tracer that manages the interleaving of model execution and interventions.

This class coordinates the execution of the model's forward pass and user-defined intervention functions through the Interleaver.
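To make "interleaving" concrete, here is a toy scheduler in which the intervention is a Python generator: it yields the name of the module whose output it wants, receives that output, and may yield a Replace to overwrite it. This is only an analogy under stated assumptions; nnsight's Interleaver does not use this generator protocol, and run_interleaved, Replace and the layer names are hypothetical. This sketch only serves requests in execution order.

```python
from dataclasses import dataclass
from typing import Any, Callable, Generator, List, Optional, Tuple


@dataclass
class Replace:
    """Hypothetical message: use this value instead of the module output."""
    value: Any


def _advance(gen: Generator[Any, Any, None], value: Any) -> Optional[Any]:
    # Resume the intervention; None means it has finished.
    try:
        return gen.send(value)
    except StopIteration:
        return None


def run_interleaved(layers: List[Tuple[str, Callable[[Any], Any]]], x: Any,
                    intervention: Generator[Any, Any, None]) -> Any:
    request = _advance(intervention, None)  # run until the first request
    for name, fn in layers:
        x = fn(x)  # one step of the "model"
        if request == name:
            reply = _advance(intervention, x)  # hand the output over
            if isinstance(reply, Replace):
                x = reply.value                       # intervention edited it
                reply = _advance(intervention, None)  # move to next request
            request = reply
    return x


saved: List[Any] = []


def my_intervention() -> Generator[Any, Any, None]:
    out = yield "layer1"             # pause until layer1 has run
    yield Replace(out * 10)          # swap in a new value
    saved.append((yield "layer2"))   # observe layer2's output


layers = [("layer1", lambda v: v + 1), ("layer2", lambda v: v * 2)]
result = run_interleaved(layers, 1, my_intervention())
print(result, saved)  # 40 [40]
```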

PARAMETER DESCRIPTION
fn

The function to execute (typically the model's forward pass).

TYPE: Callable

model

The model envoy to intervene on.

TYPE: Envoy

*args

Additional arguments to pass to the function.

DEFAULT: ()

**kwargs

Additional keyword arguments to pass to the function.

DEFAULT: {}

fn instance-attribute

fn = fn

model instance-attribute

model = model

mediators instance-attribute

mediators: List[Mediator] = []

batcher instance-attribute

batcher = _batcher_class()(*args, **kwargs)

interleaver property

interleaver: Interleaver

iter property

iter

__exit__

__exit__(exc_type, exc_value, traceback)

capture

capture(frame=None)

Capture the code block within the 'with' statement.

compile

compile() -> Callable

Compile the captured code block into a callable function.

RETURNS DESCRIPTION
Callable

A callable function that executes the captured code block.
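The general technique behind turning captured source into a callable can be sketched with the standard library: dedent the block, wrap it in a function definition, compile it once, and run it against a namespace that supplies its free variables. compile_block is a hypothetical helper; this page does not describe how the tracer's own compile builds its function, so treat this as an illustration of the idea only.

```python
import textwrap
from typing import Any, Callable, Dict


def compile_block(source: str) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    """Hypothetical helper: wrap a captured block in a function."""
    body = textwrap.indent(textwrap.dedent(source), "    ")
    fn_src = f"def __captured():\n{body}\n    return locals()\n"
    code = compile(fn_src, "<captured>", "exec")  # compile once

    def run(env: Dict[str, Any]) -> Dict[str, Any]:
        namespace = dict(env)      # free variables resolve as globals here
        exec(code, namespace)      # defines __captured in namespace
        return namespace["__captured"]()

    return run


fn = compile_block("""
    doubled = x * 2
    total = doubled + offset
""")
result = fn({"x": 5, "offset": 1})
print(result)  # {'doubled': 10, 'total': 11}
```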

get_frame

get_frame()

Get the frame of the tracer.

execute

execute(fn: Callable)

First executes the parent Tracer's execute method to set up the context, then creates an Interleaver to manage the interventions during model execution.

invoke

invoke(*args, **kwargs)

Create an Invoker to capture and execute an intervention function.

PARAMETER DESCRIPTION
*args

Additional arguments to pass to the intervention function.

DEFAULT: ()

**kwargs

Additional keyword arguments to pass to the intervention function.

DEFAULT: {}

RETURNS DESCRIPTION

An Invoker instance.

stop

stop()

Raise an EarlyStopException to stop the execution of the model.
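Stopping via an exception means the layers after the stop point never run. The sketch below shows that control flow with a locally defined EarlyStopException stand-in and a hypothetical run_model loop; it swallows the exception so that the run ends early without an error, which is an assumption about how a caller would treat it.

```python
from typing import Any, Callable, List, Tuple


class EarlyStopException(Exception):
    """Local stand-in for the exception raised by tracer.stop()."""


def run_model(layers: List[Tuple[str, Callable[[Any], Any]]], x: Any,
              stop_after: str) -> Tuple[Any, List[str]]:
    executed: List[str] = []
    try:
        for name, fn in layers:
            x = fn(x)
            executed.append(name)
            if name == stop_after:
                # Where an intervention calling stop() would raise.
                raise EarlyStopException()
    except EarlyStopException:
        pass  # end the run early without surfacing an error
    return x, executed


layers = [("embed", lambda v: v + 1),
          ("block0", lambda v: v * 2),
          ("lm_head", lambda v: v - 3)]
value, ran = run_model(layers, 1, stop_after="block0")
print(value, ran)  # 4 ['embed', 'block0']
```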

all

all()

next

next(step: int = 1)

cache

cache(modules: Optional[List[Union[Envoy, str]]] = None, device: Optional[device] = device('cpu'), dtype: Optional[dtype] = None, detach: Optional[bool] = True, include_output: bool = True, include_inputs: bool = False) -> Union[Dict, Object]

Create a cache that records module activations during execution.

Registers persistent hooks on the target modules (all modules if modules is None, otherwise only the specified subset). The hooks fire after any intervention hooks (mediator_idx=inf), so they capture post-intervention values. Hooks persist across generation steps and are automatically removed when the interleaver exits.
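Why mediator_idx=inf yields post-intervention values can be shown with a priority-ordered hook list: hooks run in ascending index order, and math.inf always sorts last, so the cache hook sees whatever earlier hooks produced. run_hooks and the two hook functions are hypothetical; real hook dispatch in nnsight is more involved.

```python
import math
from typing import Any, Callable, Dict, List, Tuple

Hook = Tuple[float, Callable[[Any], Any]]


def run_hooks(output: Any, hooks: List[Hook]) -> Any:
    # Lower index runs first; math.inf is guaranteed to run last.
    for _idx, hook in sorted(hooks, key=lambda h: h[0]):
        output = hook(output)
    return output


cached: Dict[str, Any] = {}


def intervention(out: Any) -> Any:  # mediator 0 edits the output
    return out * 100


def cache_hook(out: Any) -> Any:    # persistent cache hook at math.inf
    cached["layer"] = out
    return out  # observe only, do not modify


hooks: List[Hook] = [(math.inf, cache_hook), (0, intervention)]
final = run_hooks(5, hooks)
print(final, cached)  # 500 {'layer': 500}
```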

PARAMETER DESCRIPTION
modules

Modules to cache, given as Envoy objects or path strings. If None, caches all modules in the model.

TYPE: Optional[List[Union[Envoy, str]]] DEFAULT: None

device

Device to move cached tensors to (default: CPU).

TYPE: Optional[device] DEFAULT: device('cpu')

dtype

Optional dtype to convert cached tensors to.

TYPE: Optional[dtype] DEFAULT: None

detach

Whether to detach tensors from the computation graph.

TYPE: Optional[bool] DEFAULT: True

include_output

Whether to cache module outputs.

TYPE: bool DEFAULT: True

include_inputs

Whether to cache module inputs.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION

A Cache.CacheDict that is populated during execution.

TYPE: Union[Dict, Object]

barrier

barrier(n_participants: int)

nnsight barrier: a synchronization primitive for coordinating multiple concurrent invocations within a trace.

This works similarly to a threading.Barrier, but is designed for use with nnsight's model tracing and intervention system. A barrier pauses execution in each participating invocation until all participants have reached it, at which point all are released to continue. This is useful for synchronizing different invocations, for example to ensure that all have reached a certain point (such as after the embedding lookup) before proceeding to the next stage (such as generation or intervention).

Example usage:

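from nnsight import LanguageModel

gpt2 = LanguageModel("openai-community/gpt2")
# Illustrative prompt (assumed); it should tokenize to the same number of
# tokens as the placeholder prompt below so the embeddings can be swapped in.
MSG_prompt = "Madison Square Garden is located in the city of"

with gpt2.generate(max_new_tokens=3) as tracer:
    barrier = tracer.barrier(2)

    with tracer.invoke(MSG_prompt):
        # Capture the embeddings of the first prompt.
        embeddings = gpt2.transformer.wte.output
        barrier()
        output1 = gpt2.generator.output.save()

    with tracer.invoke("_ _ _ _ _ _ _ _ _"):
        # Wait until the first invocation has captured its embeddings.
        barrier()
        gpt2.transformer.wte.output = embeddings
        output2 = gpt2.generator.output.save()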

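In this example, both invocations pause at the barrier until both have reached it. This guarantees that the first invocation has captured embeddings before the second invocation assigns them to its own embedding output.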
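Since the barrier is described as working like threading.Barrier, the same handoff can be reproduced with plain Python threads as an analogy. This uses only the standard library and does not claim that nnsight's barrier is built on threading.Barrier; invocation_1, invocation_2 and the embedding values are hypothetical.

```python
import threading
from typing import Dict, List

barrier = threading.Barrier(2)
shared: Dict[str, List[float]] = {}
results: Dict[str, List[float]] = {}


def invocation_1() -> None:
    shared["embeddings"] = [0.1, 0.2, 0.3]  # pretend embedding lookup
    barrier.wait()                          # signal that embeddings are ready


def invocation_2() -> None:
    barrier.wait()                          # block until invocation_1 arrives
    results["patched"] = list(shared["embeddings"])


threads = [threading.Thread(target=invocation_1),
           threading.Thread(target=invocation_2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results["patched"])  # [0.1, 0.2, 0.3]
```

Because invocation_1 writes before it calls wait(), and invocation_2 reads only after wait() returns, the read can never observe a missing value.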

result

result() -> Object

The return value of the method being traced.

__getstate__

__getstate__()

Get the state of the tracer for serialization.

__setstate__

__setstate__(state)

Set the state of the tracer for deserialization.

ScanningTracer

ScanningTracer(fn: Callable, model: Envoy, *args, backend: Backend = None, **kwargs)

Bases: InterleavingTracer

A tracer that runs the model in fake tensor mode to validate operations and inspect tensor shapes.

This tracer uses PyTorch's FakeTensorMode to run the model without actual computation, allowing for shape validation and operation checking.

execute

execute(fn: Callable)

Execute the model in fake tensor mode.

This method:

1. Runs the model in fake tensor mode to validate operations.
2. Allows intervention code inside the scan context to access fake tensor shapes.
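The idea of validating shapes without real computation can be illustrated with a toy shape-only value. ShapeOnly is a hypothetical analogy and is not PyTorch's FakeTensorMode, which this tracer actually uses: it propagates shapes through a matrix multiply and raises on mismatched inner dimensions, but performs no arithmetic.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ShapeOnly:
    """Carries a shape but no data, loosely like a fake tensor."""
    shape: Tuple[int, ...]

    def __matmul__(self, other: "ShapeOnly") -> "ShapeOnly":
        if len(other.shape) != 2:
            raise ValueError("this sketch only supports a 2-D right operand")
        # Validate the inner dimensions without touching any values.
        if self.shape[-1] != other.shape[0]:
            raise ValueError(f"shape mismatch: {self.shape} @ {other.shape}")
        return ShapeOnly(self.shape[:-1] + other.shape[1:])


hidden = ShapeOnly((1, 5, 768))  # batch, tokens, hidden size
w_up = ShapeOnly((768, 3072))
print((hidden @ w_up).shape)     # (1, 5, 3072)
try:
    hidden @ ShapeOnly((512, 10))
except ValueError as err:
    print(err)                   # shape mismatch: (1, 5, 768) @ (512, 10)
```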

PARAMETER DESCRIPTION
fn

The function to execute (typically the model's forward pass).

TYPE: Callable

Barrier

Barrier(model: Envoy, n_participants: int)

model instance-attribute

model = model

n_participants instance-attribute

n_participants = n_participants

participants instance-attribute

participants: Set[str] = set()

__call__

__call__()