Info
Last Execution: 2026-04-14
| Package | Version |
|---|---|
| nnsight | 0.6.3 |
| Python | 3.12.12 |
| torch | 2.10.0+cu128 |
| transformers | 4.57.6 |
Cache
tracer.cache() collects activations from many modules in a single forward pass without writing a .save() for each one. It registers persistent post-intervention hooks on the modules you ask for, so values are recorded directly as the model runs.
If all you want to do is capture activations (no editing, no conditional logic), tracer.cache() is the fastest way to do it — faster than calling .output.save() on each module. See Why cache is fast below.
Setup

```python
import torch
import nnsight
from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
```
Caching All Modules
Calling tracer.cache() with no arguments captures the output of every module in the model. The returned object is a dict-like container keyed by module path.
```python
with model.trace("The Eiffel Tower is in the city of") as tracer:
    cache = tracer.cache()

print(f"Modules cached: {len(cache.keys())}")
print(f"First few keys: {list(cache.keys())[:5]}")
print(f"Layer 0 output shape: {cache['model.transformer.h.0'].output[0].shape}")
```

```
Modules cached: 151
First few keys: ['model.transformer.wte', 'model.transformer.wpe', 'model.transformer.drop', 'model.transformer.h.0.ln_1', 'model.transformer.h.0.attn.c_attn']
Layer 0 output shape: torch.Size([1, 10, 768])
```
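The keys look like the dotted names that PyTorch's nn.Module.named_modules() produces, with a "model." prefix. Below is a minimal sketch on a hypothetical toy model whose layout loosely mirrors GPT-2's naming. It assumes, based on the keys printed above, that the cache prepends "model." to each named_modules() path. The real cache can hold fewer entries than named_modules() lists, so use this only as a guide to the key format, for example when building path strings programmatically.

```python
import torch
from torch import nn

# Hypothetical toy layout loosely mirroring GPT-2 names. No forward() is
# defined because we only inspect module names and never run the model.
class Block(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d)
        self.mlp = nn.Linear(d, d)

class Transformer(nn.Module):
    def __init__(self, d: int, n_layers: int):
        super().__init__()
        self.h = nn.ModuleList(Block(d) for _ in range(n_layers))

toy = nn.Module()
toy.transformer = Transformer(d=8, n_layers=2)  # registers the submodule

# Assumption: cache keys are named_modules() paths prefixed with "model.".
# The root module has the empty name "", so it is skipped.
keys = [f"model.{name}" for name, _ in toy.named_modules() if name]
print(keys)
```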
Selective Caching
Pass a list of modules (Envoy references or path strings) to cache only what you need. This keeps memory usage low and is the typical pattern for real experiments.
```python
# By Envoy reference
with model.trace("The Eiffel Tower is in the city of") as tracer:
    cache = tracer.cache(modules=[
        model.transformer.h[0],
        model.transformer.h[5],
        model.transformer.h[11],
        model.lm_head,
    ])

print(f"Cached: {list(cache.keys())}")
```

```
Cached: ['model.transformer.h.0', 'model.transformer.h.5', 'model.transformer.h.11', 'model.lm_head']
```
```python
# By path string: useful when iterating programmatically
with model.trace("The Eiffel Tower is in the city of") as tracer:
    cache = tracer.cache(modules=[f"model.transformer.h.{i}.attn" for i in range(12)])

print(f"Cached {len(cache.keys())} attention modules")
```

```
Cached 12 attention modules
```
Outputs, Inputs, or Both
By default, cache() records only module outputs. Set include_inputs=True to capture each module's inputs as well. To record inputs alone, combine include_inputs=True with include_output=False.
```python
with model.trace("The Eiffel Tower is in the city of") as tracer:
    cache = tracer.cache(
        modules=[model.transformer.h[0], model.transformer.h[1]],
        include_inputs=True,
    )

# Layer 1's first input should equal layer 0's output
print("h[0].output == h[1].input:",
      torch.equal(cache["model.transformer.h.0"].output[0],
                  cache["model.transformer.h.1"].input))
```

```
h[0].output == h[1].input: True
```
The Entry object exposes three accessors:
- .output: the module's forward output
- .inputs: an (args, kwargs) tuple of all inputs
- .input: shorthand for the first positional argument
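The following plain-PyTorch sketch (PyTorch 2.0 or later) shows what an (args, kwargs) pair looks like. It uses register_forward_hook(..., with_kwargs=True), which passes the hook the positional args, keyword args, and output of each call. It only illustrates the shape of the data. Whether nnsight builds its Entry exactly this way is an assumption, and the Scale module and record dict are hypothetical.

```python
import torch
from torch import nn

class Scale(nn.Module):
    """Hypothetical module with one positional and one keyword argument."""
    def forward(self, x, factor=1.0):
        return x * factor

record = {}

def hook(module, args, kwargs, output):
    # args is a tuple of positional inputs; kwargs is a dict of keyword inputs.
    record["inputs"] = (args, kwargs)   # analogous to Entry.inputs
    record["input"] = args[0]           # first positional arg, like Entry.input
    record["output"] = output           # analogous to Entry.output

mod = Scale()
handle = mod.register_forward_hook(hook, with_kwargs=True)
x = torch.ones(2, 3)
mod(x, factor=2.0)
handle.remove()

args, kwargs = record["inputs"]
print(len(args), kwargs)
```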
Use Case: Logit Lens in One Pass
A classic mech-interp pattern: project every layer's hidden state through the final layer norm and lm_head to see what token the model would predict at each layer. With cache(), you grab all 12 hidden states in a single forward pass.
```python
prompt = "The Eiffel Tower is in the city of"

with model.trace(prompt) as tracer:
    cache = tracer.cache(modules=[layer for layer in model.transformer.h])

# Apply logit lens *outside* the trace: cheap arithmetic on the saved tensors.
# Cache stores them on CPU by default, so move each hidden state onto the
# device of the final-layer modules (ln_f, lm_head) for this read-out.
device = next(model.transformer.ln_f.parameters()).device

print(f"Prompt: {prompt!r}\n")
with torch.no_grad():
    for i in range(12):
        hs = cache[f"model.transformer.h.{i}"].output[0].to(device)
        logits = model.lm_head(model.transformer.ln_f(hs))
        top = logits[0, -1].argmax(dim=-1)
        print(f"Layer {i:2d}: {model.tokenizer.decode(top)!r}")
```
```
Prompt: 'The Eiffel Tower is in the city of'

Layer  0: ' the'
Layer  1: ' the'
Layer  2: ' the'
Layer  3: ' the'
Layer  4: ' the'
Layer  5: ' the'
Layer  6: ' East'
Layer  7: ' Ing'
Layer  8: ' Rome'
Layer  9: ' London'
Layer 10: ' Paris'
Layer 11: ' Paris'
```
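The loop above decodes one layer at a time. In a batched variant, you stack each layer's last-position hidden state into one tensor and pass the stack through the final norm and unembedding in a single call. Slicing the last position first also avoids computing logits for positions you never read. To keep the sketch self-contained, it replaces the cached GPT-2 activations with random tensors and replaces ln_f/lm_head with freshly initialized nn.LayerNorm/nn.Linear modules. The dimensions are also shrunk, so the predicted ids carry no meaning.

```python
import torch
from torch import nn

torch.manual_seed(0)
n_layers, seq_len, d_model, vocab = 12, 10, 16, 100  # shrunk toy sizes

# Stand-ins for cache[f"model.transformer.h.{i}"].output[0], each [1, seq, d_model].
hidden = [torch.randn(1, seq_len, d_model) for _ in range(n_layers)]
ln_f = nn.LayerNorm(d_model)                     # stand-in for model.transformer.ln_f
lm_head = nn.Linear(d_model, vocab, bias=False)  # stand-in for model.lm_head

with torch.no_grad():
    # Keep only the last position of every layer: [n_layers, d_model].
    last = torch.stack([hs[0, -1] for hs in hidden])
    logits = lm_head(ln_f(last))     # [n_layers, vocab] in one call
    top_ids = logits.argmax(dim=-1)  # one predicted token id per layer

print(logits.shape, top_ids.shape)
```

With the real model, you would pass top_ids to model.tokenizer.decode one id at a time, exactly as in the loop above.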
Caching During Generation
When used inside model.generate(), the cache records each generation step. The entry for a module becomes a list of Entry objects — one per step.
```python
with model.generate("The Eiffel Tower is in", max_new_tokens=4) as tracer:
    cache = tracer.cache(modules=[model.transformer.h[-1]])

entries = cache["model.transformer.h.11"]
print(f"Number of generation steps cached: {len(entries)}")
for i, entry in enumerate(entries):
    print(f" step {i}: hidden state shape {entry.output[0].shape}")
```
```
Number of generation steps cached: 4
 step 0: hidden state shape torch.Size([1, 7, 768])
 step 1: hidden state shape torch.Size([1, 1, 768])
 step 2: hidden state shape torch.Size([1, 1, 768])
 step 3: hidden state shape torch.Size([1, 1, 768])
```
The first step contains all prompt tokens; subsequent steps contain only the newly generated token (when KV caching is active).
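The first entry covers the whole prompt and each later entry covers one token fed back into the model. Concatenating the entries along the sequence dimension therefore gives one hidden state per processed position: the 7 prompt tokens plus the first 3 generated tokens. The 4th generated token is produced but never fed back in. This sketch uses random tensors shaped like the output above. The types.SimpleNamespace objects are hypothetical stand-ins for the cached Entry objects and mimic only .output.

```python
import torch
from types import SimpleNamespace

# Stand-ins shaped like the four generation steps printed above.
shapes = [(1, 7, 768), (1, 1, 768), (1, 1, 768), (1, 1, 768)]
entries = [SimpleNamespace(output=(torch.randn(*s),)) for s in shapes]

# Join along dim=1 (the sequence axis): 7 prompt positions + 3 generated ones.
full = torch.cat([entry.output[0] for entry in entries], dim=1)
print(full.shape)
```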
Verifying an Intervention
Cache hooks fire after intervention hooks, so they capture the post-edit values. This makes the cache a convenient way to confirm a patch landed.
```python
with model.trace("The Eiffel Tower is in the city of") as tracer:
    cache = tracer.cache(modules=[model.transformer.h[5]])
    # Zero out layer 5's hidden state
    model.transformer.h[5].output[0][:] = 0

is_zero = torch.all(cache["model.transformer.h.5"].output[0] == 0).item()
print(f"Layer 5 output is all zeros: {is_zero}")
```

```
Layer 5 output is all zeros: True
```
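Hook ordering is what makes this work. PyTorch runs a module's forward hooks in registration order, and when a hook returns a value, that value replaces the output seen by later hooks and by the caller. In the plain-PyTorch sketch below, an "editing" hook registered first zeroes the output, and a "recording" hook registered second therefore sees the edited value. This mimics the behavior described above. That nnsight's internals are arranged exactly this way is an assumption, and the hook names are hypothetical.

```python
import torch
from torch import nn

layer = nn.Linear(4, 4)
recorded = {}

def zero_edit(module, args, output):
    # Returning a tensor replaces the module's output for later hooks and the caller.
    return torch.zeros_like(output)

def record(module, args, output):
    # Registered after zero_edit, so it receives the already-edited output.
    recorded["out"] = output.detach().clone()

h1 = layer.register_forward_hook(zero_edit)
h2 = layer.register_forward_hook(record)
y = layer(torch.randn(2, 4))
h1.remove()
h2.remove()

print(torch.all(recorded["out"] == 0).item())
```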
Multiple Caches in One Trace
You can create separate caches for different module groups in the same trace. Each cache is independent of the others, and each call returns its own dict-like object.
```python
with model.trace("The Eiffel Tower is in the city of") as tracer:
    attn_cache = tracer.cache(modules=[layer.attn for layer in model.transformer.h])
    mlp_cache = tracer.cache(modules=[layer.mlp for layer in model.transformer.h])

print(f"Attention modules cached: {len(attn_cache.keys())}")
print(f"MLP modules cached: {len(mlp_cache.keys())}")
print(f"Sample attn key: {list(attn_cache.keys())[0]}")
print(f"Sample mlp key: {list(mlp_cache.keys())[0]}")
```

```
Attention modules cached: 12
MLP modules cached: 12
Sample attn key: model.transformer.h.0.attn
Sample mlp key: model.transformer.h.0.mlp
```
Memory Management with device and dtype
For large models or long sequences, the activations themselves can dominate memory. cache() accepts device and dtype arguments that are applied to every tensor as it is recorded — letting you offload to CPU and/or downcast to a smaller dtype on the fly.
```python
with model.trace("The Eiffel Tower is in the city of") as tracer:
    cache = tracer.cache(
        modules=[layer for layer in model.transformer.h],
        device=torch.device("cpu"),  # offload to CPU as values are recorded
        dtype=torch.float16,         # downcast to fp16 in the same step
    )

sample = cache["model.transformer.h.0"].output[0]
print(f"device: {sample.device}")
print(f"dtype: {sample.dtype}")
print(f"shape: {sample.shape}")
```

```
device: cpu
dtype: torch.float16
shape: torch.Size([1, 10, 768])
```
The defaults (device=cpu, detach=True) are already memory-friendly: tensors are moved off the GPU and detached from the autograd graph as soon as they're captured. Pass device=None if you want to keep them on the model's device.
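To size the savings, note that a tensor occupies numel() * element_size() bytes. For one cached GPT-2 hidden state of shape [1, 10, 768], downcasting from float32 to float16 halves that number. Multiplying by the count of cached modules estimates the total. The sketch below uses zero-filled tensors of that shape, since only the sizes matter. It assumes that dtype= performs a plain dtype conversion like .to(dtype=...).

```python
import torch

def nbytes(t: torch.Tensor) -> int:
    """Memory used by a tensor's elements, in bytes."""
    return t.numel() * t.element_size()

hs32 = torch.zeros(1, 10, 768, dtype=torch.float32)
hs16 = hs32.to(dtype=torch.float16)  # assumed equivalent to the dtype= downcast

per_layer_fp32 = nbytes(hs32)      # 7680 elements * 4 bytes
per_layer_fp16 = nbytes(hs16)      # 7680 elements * 2 bytes
total_fp16 = 12 * per_layer_fp16   # all 12 blocks of GPT-2 small

print(per_layer_fp32, per_layer_fp16, total_fp16)
```

Activations grow linearly with sequence length and batch size, so the same arithmetic scales directly to longer prompts.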
Access Patterns
The cache supports both dictionary-style and attribute-style access. They return the exact same data — pick whichever feels nicer.
```python
with model.trace("The Eiffel Tower is in the city of") as tracer:
    cache = tracer.cache(modules=[layer for layer in model.transformer.h] + [model.lm_head])

# Dict-style
a = cache["model.transformer.h.0"].output[0]
# Attribute-style
b = cache.model.transformer.h[0].output[0]

print("Same tensor:", torch.equal(a, b))
print("LM head shape:", cache.model.lm_head.output.shape)
```

```
Same tensor: True
LM head shape: torch.Size([1, 10, 50257])
```
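How can cache.model.transformer.h[0] resolve to the key "model.transformer.h.0"? One approach is a small view object that accumulates path segments on each attribute or index access and returns the stored value once the path matches a key. The AttrView class below is a hypothetical plain-Python sketch of that idea, not nnsight's actual implementation.

```python
class AttrView:
    """Hypothetical: map attribute/index access onto dotted dict keys."""

    def __init__(self, data: dict, prefix: str = ""):
        self._data = data
        self._prefix = prefix

    def _child(self, segment: str):
        path = f"{self._prefix}.{segment}" if self._prefix else segment
        # Return the stored value as soon as the path names a real key.
        if path in self._data:
            return self._data[path]
        return AttrView(self._data, path)

    def __getattr__(self, name: str):
        # Only called for names not found normally (i.e. not _data/_prefix).
        return self._child(name)

    def __getitem__(self, index: int):
        return self._child(str(index))


store = {"model.transformer.h.0": "entry-0", "model.lm_head": "entry-head"}
view = AttrView(store)
print(view.model.transformer.h[0])
print(view.model.lm_head)
```

This simple version stops at the first matching key. For example, if both "model.transformer.h.0" and "model.transformer.h.0.attn" were stored, it could not reach the nested key. A fuller design would need an object that acts as both an entry and a view.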
Why Cache Is Fast
Every time you write module.output.save(), your worker thread blocks on a value handoff with the model's main thread — the mediator wakes the worker, hands it the tensor, and the worker schedules the save before yielding control back. That coordination is cheap, but it's not free, and the cost scales with the number of .output/.input accesses you make.
tracer.cache() skips the mediator entirely. Instead of routing values through the worker thread, it registers PyTorch forward hooks directly on the modules. When the model runs, those hooks fire inline and write straight into the cache dict. There's no thread synchronization, no AST-extracted intervention code to compile, no per-access blocking.
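The following minimal plain-PyTorch sketch illustrates the hook-based approach. It registers one forward hook per named module. Each hook detaches the output, moves or casts it, and writes it into a dict keyed by module name during a single forward pass. The build_cache helper and the toy nn.Sequential are hypothetical. The sketch mirrors the idea described above, not nnsight's actual code.

```python
import torch
from torch import nn

def build_cache(model: nn.Module, names, device="cpu", dtype=None):
    """Hypothetical helper: record outputs of the named submodules via forward hooks."""
    cache, handles = {}, []
    modules = dict(model.named_modules())
    for name in names:
        def hook(module, args, output, name=name):  # bind this module's name
            # Runs inline during forward: no thread handoff, just a dict write.
            cache[name] = output.detach().to(device=device, dtype=dtype)
        handles.append(modules[name].register_forward_hook(hook))
    return cache, handles

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
cache, handles = build_cache(model, ["0", "2"], dtype=torch.float16)

with torch.no_grad():
    model(torch.randn(3, 4))  # both hooks fire during this single pass

for h in handles:
    h.remove()  # hooks registered this way stay active until removed

print({k: (tuple(v.shape), v.dtype) for k, v in cache.items()})
```

This sketch assumes every hooked module returns a single tensor. GPT-2 blocks return tuples, so a real version would need to handle nested outputs.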
The practical rule:
- Just collecting activations? Use tracer.cache(). It's the fast path.
- Need to edit, branch on, or otherwise use a value inside the trace? That's what .output / .save() are for.
You can mix the two in the same trace — cache the bulk of your activations and use .output / .save() only where you actually need to interact with a value.
When to use cache
- Bulk activation collection: grabbing hidden states from every layer for probing, logit lens, or SAE training data
- Selective sweeps: caching just attention or just MLP outputs across all layers
- Memory-constrained runs: device="cpu" + dtype=torch.float16 keeps GPU memory flat while you record
- Intervention verification: confirm a patch took effect by reading the post-intervention value out of the cache
- Generation tracking: collect per-step activations across decoding without writing custom hooks
- Speed: when you only need to read activations (not edit them), cache avoids the mediator round-trip that .save() incurs