NNsight Walkthrough#
The API for a transparent science on black-box AI#
In this era of large-scale deep learning, the most interesting AI models are massive black boxes that are hard to run. Ordinary commercial inference service APIs let us interact with huge models, but they do not let us access model internals.
The nnsight library is different: it provides full access to all neural network internals. When using nnsight together with a remote service like the National Deep Inference Fabric (NDIF), it is possible to run complex experiments on huge open models easily with fully transparent access.
Through NDIF and NNsight, our team aims to enable entire labs and independent researchers alike, as we believe a large, passionate, and collaborative community will produce the next big insights in this profoundly important field.
This walkthrough will teach you nnsight from the ground up, starting with the core mental model and building to advanced features.
Run an interactive version of this tutorial on Colab.
Table of Contents#
Getting Started - Setup and wrapping models
Intervening - Accessing and modifying activations
LLMs - LanguageModel, invokers, batching, and multi-token generation
Gradients - Accessing and modifying gradients
Advanced Features - Source tracing, caching, early stopping, scanning
Model Editing - Persistent modifications
Remote Execution - Running on NDIF
# 1. Getting Started
Let’s set up nnsight and run our first trace.
Installation#
[ ]:
# Install nnsight
!pip install nnsight
!pip install --upgrade transformers torch
from IPython.display import clear_output
clear_output()
A Tiny Model#
To demonstrate the core functionality and syntax of nnsight, we’ll define and use a tiny two-layer neural network.
Our little model is composed of two submodules - linear layers ‘layer1’ and ‘layer2’. We specify the sizes of each module and create a matching example input.
[1]:
from collections import OrderedDict
import torch
input_size = 5
hidden_dims = 10
output_size = 2
net = torch.nn.Sequential(
OrderedDict([
("layer1", torch.nn.Linear(input_size, hidden_dims)),
("layer2", torch.nn.Linear(hidden_dims, output_size)),
])
).requires_grad_(False)
# random input
input = torch.rand((1, input_size))
Wrapping with NNsight#
The core object of the nnsight package is NNsight. It wraps a given PyTorch model to enable investigation of its internals.
[2]:
from nnsight import NNsight
model = NNsight(net)
Printing a PyTorch model shows a named hierarchy of modules, which is very useful for knowing how to access sub-components directly. NNsight reflects the same hierarchy:
[3]:
print(model)
Sequential(
(layer1): Linear(in_features=5, out_features=10, bias=True)
(layer2): Linear(in_features=10, out_features=2, bias=True)
)
Python Contexts#
Before we actually get to using the model, let’s talk about Python contexts.
Python contexts define a scope using the with statement and are often used to create some object, or initiate some logic, that you later want to destroy or conclude.
The most common application is opening files:
with open('myfile.txt', 'r') as file:
text = file.read()
Python uses the with keyword to enter a context manager. This object defines logic to be run at the start of the with block, as well as logic to be run when exiting. When using with for a file, entering the context opens the file and exiting the context closes it. Being within the context means we can read from the file.
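To make the enter/exit logic concrete, here is a small context manager of our own, written with Python’s standard contextlib. This is purely for intuition and is unrelated to nnsight:
from contextlib import contextmanager

@contextmanager
def my_context():
    # Runs when the with block is entered
    print("Entering: set things up")
    try:
        yield "some value"
    finally:
        # Runs when the with block exits
        print("Exiting: clean things up")

with my_context() as value:
    print("Inside the context, value =", value)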
Simple enough! Now we can discuss how nnsight uses contexts to enable intuitive access into the internals of a neural network.
# 2. Intervening
Now let’s access the model’s internals using the tracing context.
The Tracing Context#
The main tool in nnsight is a context for tracing. We enter the tracing context by calling model.trace(<input>) on an NNsight model, which defines how we want to run the model. Inside the context, we will be able to customize how the neural network runs. The model is actually run upon exiting the tracing context:
[4]:
input = torch.rand((1, input_size))
with model.trace(input):
# Your intervention code goes here
# The model runs when the context exits
pass
But where’s the output? To get that, we’ll have to learn how to request it from within the tracing context.
The .input and .output Properties#
When we wrapped our neural network with the NNsight class, this added a couple of properties to each module in the model (including the root model itself). The two most important ones are .input and .output:
model.input # The input to the model
model.output # The output from the model
The names are self-explanatory. They correspond to the inputs and outputs of their respective modules during a forward pass. We can use these attributes inside the with block to access values at any point in the network.
Let’s try accessing the model’s output:
[5]:
with model.trace(input):
output = model.output
print(output)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[5], line 4
1 with model.trace(input):
2 output = model.output
----> 4 print(output)
NameError: name 'output' is not defined
Oh no, an error! The variable output isn’t defined outside the trace.
Why doesn’t our output have a value? Values accessed inside a trace only exist during the trace. They persist after the context only if we call .save() on them. This helps reduce memory costs - we only keep what we explicitly ask for.
Saving Values with .save()#
Adding .save() fixes the error:
[6]:
with model.trace(input):
output = model.output.save()
print(output)
tensor([[-0.1191, -0.5757]])
Success! We now have the model output. We just completed our first intervention using nnsight.
The .save() method tells nnsight “I want to use this value after the trace ends.”
💡 Tip: There’s also nnsight.save(value), which is the preferred alternative. It works on any value and doesn’t require the object to have a .save() method: output = nnsight.save(model.output). Both approaches work, but nnsight.save() is more explicit and works in more cases.
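For instance, here is a minimal sketch of nnsight.save() on our tiny model, mirroring the cell above (the doubled value is just for illustration):
import nnsight

with model.trace(input):
    # nnsight.save works on any value computed inside the trace,
    # including results of ordinary tensor operations
    output = nnsight.save(model.output)
    doubled = nnsight.save(model.output * 2)

print(output)
print(doubled)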
Accessing Submodule Outputs#
Just like we saved the model’s output, we can access any submodule’s output. Remember when we printed the model earlier? That showed us layer1 and layer2 - we can access those directly:
[7]:
with model.trace(input):
layer1_output = model.layer1.output.save()
layer2_output = model.layer2.output.save()
print("Layer 1 output:", layer1_output)
print("Layer 2 output:", layer2_output)
Layer 1 output: tensor([[-0.7436, 0.6647, -0.2002, -0.7936, -0.2744, -0.7671, 0.1406, 0.2293,
-0.1555, 0.6403]])
Layer 2 output: tensor([[-0.1191, -0.5757]])
Accessing Module Inputs#
We can also access the inputs to any module using .input:
| Property | Returns |
|---|---|
| .output | The module’s return value |
| .input | The first positional argument to the module |
| .inputs | All inputs as a tuple of (args, kwargs) |
[8]:
with model.trace(input):
layer2_input = model.layer2.input.save()
print("Layer 2 input:", layer2_input)
print("(Notice it equals layer1 output!)")
Layer 2 input: tensor([[-0.7436, 0.6647, -0.2002, -0.7936, -0.2744, -0.7671, 0.1406, 0.2293,
-0.1555, 0.6403]])
(Notice it equals layer1 output!)
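There is no cell for .inputs above, so here is a small sketch of it on our tiny model, assuming it bundles all of the module’s inputs as positional and keyword arguments, as described in the table:
import nnsight

with model.trace(input):
    # Assumed structure: every positional and keyword argument layer2 received
    layer2_inputs = nnsight.save(model.layer2.inputs)

print("All inputs to layer2:", layer2_inputs)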
Operations on Values#
Since you’re working with real tensors, you can apply any PyTorch operations:
[9]:
with model.trace(input):
layer1_out = model.layer1.output
# Apply operations - these are real tensor operations!
max_idx = torch.argmax(layer1_out, dim=1).save()
total = (model.layer1.output.sum() + model.layer2.output.sum()).save()
print("Max index:", max_idx)
print("Total:", total)
Max index: tensor([1])
Total: tensor(-1.9540)
The Core Paradigm: Interleaving#
When you write intervention code inside a with model.trace(...) block, here’s what actually happens:
1. Your code is captured - nnsight extracts the code inside the with block
2. The code is compiled into an executable function
3. Your code runs in parallel with the model - as the model executes its forward pass, your intervention code runs alongside it
4. Your code waits for values - when you access .output, your code pauses until the model reaches that point
5. The model provides values via hooks - PyTorch hooks inject values into your waiting code
6. Your code can modify values - before the forward pass continues, you can change activations
This process is called interleaving - your intervention code and the model’s forward pass take turns executing, synchronized at specific points (module inputs and outputs).
┌─────────────────────────────────────────────────────────────────────┐
│ Forward Pass (main) Intervention Code (your code) │
│ ───────────────────── ───────────────────────────── │
│ │
│ model(input) # Your code starts │
│ │ │ │
│ ▼ ▼ │
│ layer1.forward() hs = model.layer1.output │
│ │ │ │
│ │──── hook provides value ──────────►│ │
│ │ │ │
│ │◄─── your code continues ────────── │ │
│ │ (can modify value) │ │
│ ▼ ▼ │
│ layer2.forward() out = model.layer2.output │
│ │ │ │
│ ▼ ▼ │
│ return output # Your code finishes │
└─────────────────────────────────────────────────────────────────────┘
Key insight:
Because your code waits for values as the forward pass progresses, you must access modules in the order they execute.
✅ Correct: Access layer 0, then layer 5
with model.trace("Hello"):
layer0_out = model.layers[0].output.save() # Waits for layer 0
layer5_out = model.layers[5].output.save() # Then waits for layer 5
❌ Wrong: Access layer 5, then layer 0
with model.trace("Hello"):
layer5_out = model.layers[5].output.save() # Waits for layer 5
layer0_out = model.layers[0].output.save() # ERROR! Layer 0 already executed
# Raises OutOfOrderError
When you try to access a module that has already executed, nnsight raises an OutOfOrderError. This is because the forward pass has already moved past that point - you missed your chance to intercept that value.
Modification#
Not only can we view intermediate states of the model, we can modify them and see the effect on the output.
Use indexing with [:] for in-place modifications:
[10]:
with model.trace(input):
# Save original (clone first since we'll modify in-place)
before = model.layer1.output.clone().save()
# Zero out the first dimension
model.layer1.output[:, 0] = 0
# Save modified
after = model.layer1.output.save()
print("Before:", before)
print("After: ", after)
Before: tensor([[-0.7436, 0.6647, -0.2002, -0.7936, -0.2744, -0.7671, 0.1406, 0.2293,
-0.1555, 0.6403]])
After: tensor([[ 0.0000, 0.6647, -0.2002, -0.7936, -0.2744, -0.7671, 0.1406, 0.2293,
-0.1555, 0.6403]])
Replacement#
You can also replace an output entirely:
[11]:
with model.trace(input):
original = model.layer1.output.clone()
# Add noise to the activation
noise = 0.1 * torch.randn_like(original)
model.layer1.output = original + noise
modified = model.layer1.output.save()
print("Modified output:", modified)
Modified output: tensor([[-0.5552, 0.6286, -0.2002, -0.8105, -0.2760, -0.8521, 0.0105, 0.2263,
-0.1350, 0.5815]])
Error Handling#
If you make an error (like invalid indexing), nnsight provides clear error messages with line numbers:
[12]:
# This will fail because hidden_dims=10, so valid indices are 0-9
try:
with model.trace(input):
model.layer1.output[:, hidden_dims] = 0 # Index 10 is out of bounds!
except IndexError as e:
print("Caught error:", e)
Caught error:
Traceback (most recent call last):
File "/tmp/ipykernel_1571564/3815762194.py", line 4, in <module>
model.layer1.output[:, hidden_dims] = 0 # Index 10 is out of bounds!
IndexError: index 10 is out of bounds for dimension 1 with size 10
Debugging tips:
Use ``print()`` inside traces - it works normally and prints values as they’re computed
Use ``breakpoint()`` to drop into pdb and inspect values interactively
Toggle internal frames with nnsight.CONFIG.APP.DEBUG = True to see NNsight’s internal execution (helpful when the default traceback isn’t clear)
with model.trace(input):
out = model.layer1.output
print("Layer 1 shape:", out.shape) # Works!
breakpoint() # Drops into pdb - inspect `out`, etc.
# 3. LLMs
Now that we have the basics of nnsight under our belt, we can scale our model up and combine the techniques we’ve learned into more interesting experiments!
The NNsight class we used in Part 2 is very bare bones. It wraps a pre-defined model and does no pre-processing on the inputs we enter. It’s designed to be extended with more complex and powerful types of models.
For language models, nnsight provides LanguageModel, a subclass that greatly simplifies the process:
Automatic tokenization - pass strings directly, no manual tokenization needed
HuggingFace integration - load any model from the HuggingFace Hub by its ID
Generation support - built-in support for multi-token generation with .generate()
Batching - efficiently process multiple inputs in one forward pass
Let’s load GPT-2 and start experimenting!
Loading a Language Model#
While we could define and create a model to pass in directly, LanguageModel includes special support for HuggingFace language models - it automatically loads the model AND the appropriate tokenizer from a HuggingFace ID.
Under the hood, LanguageModel uses AutoModelForCausalLM.from_pretrained() to load the model. Any keyword arguments you pass are forwarded directly to HuggingFace, so you can use all the same options:
# Example with common HuggingFace kwargs:
model = LanguageModel(
"meta-llama/Llama-3.1-8B",
device_map="auto", # Distribute across GPUs
torch_dtype=torch.float16, # Set precision
trust_remote_code=True, # For custom model code
)
A note on model initialization:
device_map="auto"tells HuggingFace Accelerate to automatically distribute model layers across all available GPUs (and CPU if the model doesn’t fit). This is the recommended setting for large models.By default, nnsight uses lazy loading - the model isn’t loaded into memory until the first trace. Pass
dispatch=Trueto load immediately.
[13]:
from nnsight import LanguageModel
llm = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
print(llm)
GPT2LMHeadModel(
(transformer): GPT2Model(
(wte): Embedding(50257, 768)
(wpe): Embedding(1024, 768)
(drop): Dropout(p=0.1, inplace=False)
(h): ModuleList(
(0-11): 12 x GPT2Block(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): GPT2Attention(
(c_attn): Conv1D()
(c_proj): Conv1D()
(attn_dropout): Dropout(p=0.1, inplace=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): Conv1D()
(c_proj): Conv1D()
(act): NewGELUActivation()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=768, out_features=50257, bias=False)
(generator): Generator(
(streamer): Streamer()
)
)
Notice the model structure! GPT-2 has:
transformer.wte - token embeddings
transformer.h - a list of transformer blocks (layers 0-11)
lm_head - the output projection to vocabulary
With LanguageModel, you can pass strings directly - tokenization happens automatically:
[14]:
with llm.trace("The Eiffel Tower is in the city of"):
# Access hidden states from the last layer
hidden_states = llm.transformer.h[-1].output[0].save()
# Access the final logits
logits = llm.lm_head.output.save()
print("Hidden states shape:", hidden_states.shape)
print("Predicted next token:", llm.tokenizer.decode(logits[0, -1].argmax()))
Hidden states shape: torch.Size([1, 10, 768])
Predicted next token: Paris
Everything you learned with the tiny model applies here! The same .input, .output, and .save() patterns work. The key difference is you can pass strings directly.
💡 Note: GPT-2 transformer layers return tuples where [0] contains the hidden states. That’s why we use .output[0] instead of just .output.
Invokers and Batching#
So far we’ve been running one input at a time. But what if you want to process multiple inputs efficiently, or apply different interventions to each?
This is where invokers come in. When you call .trace() without an input, you can create multiple invokers - each one defines an input and the interventions for that input:
The key insight: all invokers are batched together into one forward pass. This is much more efficient than running separate traces!
[15]:
with llm.trace() as tracer:
# First invoker: run on "Paris" prompt
with tracer.invoke("The Eiffel Tower is in"):
paris_logits = llm.lm_head.output[:, -1].save()
# Second invoker: run on "London" prompt
with tracer.invoke("Big Ben is in"):
london_logits = llm.lm_head.output[:, -1].save()
# Both ran in ONE forward pass!
print("Paris prediction:", llm.tokenizer.decode(paris_logits.argmax()))
print("London prediction:", llm.tokenizer.decode(london_logits.argmax()))
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Paris prediction: the
London prediction: in
How Invokers Execute#
Invokers run serially - one after another, not in parallel. This means you can reference values from earlier invokes:
┌──────────────────────────────────────────────────────────────────┐
│ Invoke 1 starts │ Invoke 2 starts (after 1 finishes) │
│ │ │ │ │
│ ▼ │ ▼ │
│ Wait for wte.output │ Wait for wte.output │
│ │ │ │ │
│ ▼ │ ▼ │
│ Wait for lm_head.output │ Wait for lm_head.output │
│ │ │ │ │
│ ▼ │ ▼ │
│ Invoke 1 finishes │ Invoke 2 finishes │
└──────────────────────────────────────────────────────────────────┘
This enables powerful cross-prompt interventions - like patching activations from one prompt into another:
Why do we need ``barrier()`` here?
Both invokes access llm.transformer.wte.output. Without a barrier, invokes run serially - the first would complete entirely before the second starts. By the time the second invoke tries to use paris_embeddings, it wouldn’t be defined in scope!
The barrier synchronizes both invokes at a specific point, allowing them to share variables while both are accessing the same module.
[16]:
with llm.trace() as tracer:
barrier = tracer.barrier(2) # Create barrier for 2 participants
# First invoke: capture embeddings from "Paris" prompt
with tracer.invoke("The Eiffel Tower is in"):
paris_embeddings = llm.transformer.wte.output
barrier()
# Second invoke: patch those embeddings into a different prompt!
with tracer.invoke("_ _ _ _ _"): # Dummy tokens (same length)
barrier()
llm.transformer.wte.output = paris_embeddings # Inject Paris embeddings
patched_output = llm.lm_head.output[:, -1].save()
# The model now predicts as if it saw "The Eiffel Tower is in"!
print("Patched prediction:", llm.tokenizer.decode(patched_output.argmax()))
Patched prediction: in
Multi-Token Generation#
So far we’ve done single forward passes. But language models generate text by running multiple forward passes - one per token. This means the same modules are called multiple times!
Use .generate() instead of .trace() for multi-token generation:
[17]:
with llm.generate("The Eiffel Tower is in", max_new_tokens=3) as tracer:
output = llm.generator.output.save()
print(llm.tokenizer.decode(output[0]))
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation will be skipped.
The Eiffel Tower is in the middle of
Iterating Over Generation Steps with .iter#
During generation, modules are called once per token. What if you want to intervene or collect data at each step?
Use tracer.iter[:] to iterate over all generation steps. This is crucial whenever modules are called more than once - generation, diffusion steps, recurrent networks, etc:
[18]:
with llm.generate("The Eiffel Tower is in", max_new_tokens=3) as tracer:
tokens = list().save()
# Iterate over ALL generation steps
with tracer.iter[:]:
token = llm.lm_head.output[0, -1].argmax(dim=-1)
tokens.append(token)
print("Generated tokens:", llm.tokenizer.batch_decode(tokens))
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Generated tokens: [' the', ' middle', ' of']
tracer.iter accepts different patterns:
| Pattern | Meaning |
|---|---|
| tracer.iter[:] | All steps |
| tracer.iter[0] | First step only |
| tracer.iter[1:3] | Steps 1 and 2 |
| tracer.iter[::2] | Every other step |
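As a quick sketch of the indexed forms, assuming the slice patterns listed above, we can capture predictions for only the first two generation steps:
with llm.generate("The Eiffel Tower is in", max_new_tokens=4) as tracer:
    early_tokens = list().save()
    # Only generation steps 0 and 1 are captured here
    with tracer.iter[0:2]:
        early_tokens.append(llm.lm_head.output[0, -1].argmax(dim=-1))

print("First two predicted tokens:", llm.tokenizer.batch_decode(early_tokens))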
Conditional Per-Step Interventions#
Use as step_idx to get the current step index. This lets you apply different logic at different steps:
[19]:
with llm.generate("Hello", max_new_tokens=5) as tracer:
tokens = list().save()
with tracer.iter[:] as step_idx:
# Only intervene on step 2
if step_idx == 2:
llm.transformer.h[0].output[0][:] = 0 # Zero out layer 0
tokens.append(llm.lm_head.output[0, -1].argmax(dim=-1))
print(f"Generated {len(tokens)} tokens (step 2 had zeroed activations)")
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Generated 0 tokens (step 2 had zeroed activations)
💡 Key Takeaway: .iter works anywhere modules are called multiple times - not just LLM generation. It’s useful for diffusion model denoising steps, RNN time steps, or any iterative computation.
⚠️ Warning: Unbounded Iteration Footgun#
Critical: When using tracer.iter[:] or tracer.all(), code AFTER the iter block never executes!
These unbounded iterators don’t know when to stop - they wait forever for the “next” iteration. When generation finishes, any code after the iter block is skipped:
# WRONG - final_output never gets defined!
with model.generate("Hello", max_new_tokens=3) as tracer:
with tracer.iter[:]:
hidden = model.transformer.h[-1].output.save()
# ⚠️ THIS NEVER EXECUTES!
final_output = model.output.save()
print(final_output) # NameError: 'final_output' is not defined
Solution: Use a separate empty invoker for code that should run after iteration:
with model.generate("Hello", max_new_tokens=3) as tracer:
with tracer.invoke(): # First invoker handles iteration
with tracer.iter[:]:
hidden = model.transformer.h[-1].output.save()
with tracer.invoke(): # Second invoker runs after generation
final_output = model.output.save() # Now this works!
Section 3 Summary#
You’ve learned the core patterns for working with LLMs in nnsight:
LanguageModel - Load HuggingFace models with automatic tokenization
Invokers - Process multiple inputs efficiently in one batched forward pass
Cross-invoke sharing - Reference values from one invoke in another
Multi-token generation - Use .generate() instead of .trace()
Iteration with ``.iter`` - Intervene at each step when modules are called multiple times
These patterns form the foundation for interpretability research!
# 4. Gradients
nnsight supports gradient access and modification through a special backward tracing context. This is essential for gradient-based interpretability methods like attribution, saliency maps, and gradient-based steering.
Just like we use with model.trace() to intercept the forward pass, we use with loss.backward() to intercept the backward pass. The key insight: during backpropagation, gradients flow in reverse order - from the loss back through the model. So you must access .grad in the reverse order of how you accessed the tensors during the forward pass!
[20]:
with llm.trace("Hello"):
# FORWARD PASS: Access tensors in forward order
# First, get the tensor we want gradients for
hs = llm.transformer.h[-1].output[0]
hs.requires_grad_(True)
# Then compute the loss (comes after hidden states in forward pass)
logits = llm.lm_head.output
loss = logits.sum()
# BACKWARD PASS: Access gradients in REVERSE order!
# Gradients flow from loss → logits → hidden states → earlier layers
with loss.backward():
# hs.grad is available because we're going backwards from loss
grad = hs.grad.save()
print("Gradient shape:", grad.shape)
Gradient shape: torch.Size([1, 1, 768])
Understanding Gradient Order#
This is the same interleaving principle from the forward pass, but reversed:
Forward pass order: layer0 → layer1 → ... → layer11 → lm_head → loss
Backward pass order: loss → lm_head → layer11 → ... → layer1 → layer0
If you accessed layer5.output and layer10.output during the forward pass, you must access their gradients in reverse: layer10.grad first, then layer5.grad.
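For example, here is a minimal sketch following the same pattern as cell [20] but with two layers; the indices 5 and 10 are arbitrary choices for illustration:
with llm.trace("Hello"):
    # Forward pass: layer 5 executes before layer 10
    hs5 = llm.transformer.h[5].output[0]
    hs5.requires_grad_(True)
    hs10 = llm.transformer.h[10].output[0]
    hs10.requires_grad_(True)
    loss = llm.lm_head.output.sum()

    # Backward pass: gradients arrive in reverse, layer 10 before layer 5
    with loss.backward():
        grad10 = hs10.grad.save()
        grad5 = hs5.grad.save()

print("Layer 10 grad shape:", grad10.shape)
print("Layer 5 grad shape:", grad5.shape)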
Important rules for gradients:
.grad is only accessible inside a with tensor.backward(): context
.grad is a property of tensors, not modules
Get the tensor via .output before entering the backward context
Call .requires_grad_(True) on the tensor you want gradients for
Access gradients in reverse order of how you got the tensors
Modifying Gradients#
You can modify gradients just like activations - useful for techniques like gradient clipping or steering:
[21]:
with llm.trace("Hello"):
hs = llm.transformer.h[-1].output[0]
hs.requires_grad_(True)
logits = llm.lm_head.output
loss = logits.sum()
with loss.backward():
# Save original gradient
original_grad = hs.grad.clone().save()
# Modify gradient (e.g., zero it out)
hs.grad[:] = 0
# Save modified
modified_grad = hs.grad.save()
print("Original grad mean:", original_grad.mean().item())
print("Modified grad mean:", modified_grad.mean().item())
Original grad mean: -8.900960892788135e-06
Modified grad mean: 0.0
# 5. Advanced Features
Let’s explore some powerful advanced features that unlock deeper investigations.
[22]:
# Print source to discover available operations inside a module
print(llm.transformer.h[0].attn.source)
* @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
0 def forward(
1 self,
2 hidden_states: Optional[tuple[torch.FloatTensor]],
3 past_key_values: Optional[Cache] = None,
4 cache_position: Optional[torch.LongTensor] = None,
5 attention_mask: Optional[torch.FloatTensor] = None,
6 head_mask: Optional[torch.FloatTensor] = None,
7 encoder_hidden_states: Optional[torch.Tensor] = None,
8 encoder_attention_mask: Optional[torch.FloatTensor] = None,
9 output_attentions: Optional[bool] = False,
10 **kwargs,
11 ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
12 is_cross_attention = encoder_hidden_states is not None
13 if past_key_values is not None:
isinstance_0 -> 14 if isinstance(past_key_values, EncoderDecoderCache):
past_key_values_is_updated_get_0 -> 15 is_updated = past_key_values.is_updated.get(self.layer_idx)
16 if is_cross_attention:
17 # after the first generated id, we can subsequently re-use all key/value_layer from cache
18 curr_past_key_value = past_key_values.cross_attention_cache
19 else:
20 curr_past_key_value = past_key_values.self_attention_cache
21 else:
22 curr_past_key_value = past_key_values
23
24 if is_cross_attention:
hasattr_0 -> 25 if not hasattr(self, "q_attn"):
ValueError_0 -> 26 raise ValueError(
27 "If class is used as cross attention, the weights `q_attn` have to be defined. "
28 "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
29 )
self_q_attn_0 -> 30 query_states = self.q_attn(hidden_states)
31 attention_mask = encoder_attention_mask
32
33 # Try to get key/value states from cache if possible
34 if past_key_values is not None and is_updated:
35 key_states = curr_past_key_value.layers[self.layer_idx].keys
36 value_states = curr_past_key_value.layers[self.layer_idx].values
37 else:
self_c_attn_0 -> 38 key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
split_0 -> + ...
39 shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
key_states_view_0 -> 40 key_states = key_states.view(shape_kv).transpose(1, 2)
transpose_0 -> + ...
value_states_view_0 -> 41 value_states = value_states.view(shape_kv).transpose(1, 2)
transpose_1 -> + ...
42 else:
self_c_attn_1 -> 43 query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
split_1 -> + ...
44 shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
key_states_view_1 -> 45 key_states = key_states.view(shape_kv).transpose(1, 2)
transpose_2 -> + ...
value_states_view_1 -> 46 value_states = value_states.view(shape_kv).transpose(1, 2)
transpose_3 -> + ...
47
48 shape_q = (*query_states.shape[:-1], -1, self.head_dim)
query_states_view_0 -> 49 query_states = query_states.view(shape_q).transpose(1, 2)
transpose_4 -> + ...
50
51 if (past_key_values is not None and not is_cross_attention) or (
52 past_key_values is not None and is_cross_attention and not is_updated
53 ):
54 # save all key/value_layer to cache to be re-used for fast auto-regressive generation
55 cache_position = cache_position if not is_cross_attention else None
curr_past_key_value_update_0 -> 56 key_states, value_states = curr_past_key_value.update(
57 key_states, value_states, self.layer_idx, {"cache_position": cache_position}
58 )
59 # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
60 if is_cross_attention:
61 past_key_values.is_updated[self.layer_idx] = True
62
63 is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
64
65 using_eager = self.config._attn_implementation == "eager"
66 attention_interface: Callable = eager_attention_forward
67 if self.config._attn_implementation != "eager":
68 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
69
70 if using_eager and self.reorder_and_upcast_attn:
self__upcast_and_reordered_attn_0 -> 71 attn_output, attn_weights = self._upcast_and_reordered_attn(
72 query_states, key_states, value_states, attention_mask, head_mask
73 )
74 else:
attention_interface_0 -> 75 attn_output, attn_weights = attention_interface(
76 self,
77 query_states,
78 key_states,
79 value_states,
80 attention_mask,
81 head_mask=head_mask,
82 dropout=self.attn_dropout.p if self.training else 0.0,
83 is_causal=is_causal,
84 **kwargs,
85 )
86
attn_output_reshape_0 -> 87 attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
contiguous_0 -> + ...
self_c_proj_0 -> 88 attn_output = self.c_proj(attn_output)
self_resid_dropout_0 -> 89 attn_output = self.resid_dropout(attn_output)
90
91 return attn_output, attn_weights
92
5.1 Source Tracing#
Sometimes you need to access values inside a module’s forward pass, not just its inputs and outputs. The .source property rewrites the forward method to hook every operation, letting you access intermediate computations:
Source operations have the same interface as modules - .output, .input, .inputs:
[23]:
with llm.trace("Hello"):
# Access an internal operation by name
attn_output = llm.transformer.h[0].attn.source.attention_interface_0.output.save()
print("Attention output type:", type(attn_output))
Attention output type: <class 'tuple'>
5.2 Caching Activations#
Use tracer.cache() to automatically save all module outputs - no need to manually call .save() on each one:
[24]:
with llm.trace("Hello") as tracer:
cache = tracer.cache()
# Access cached values after the trace
print("Layer 0 output shape:", cache['model.transformer.h.0'].output[0].shape)
# Attribute-style access also works
print("Same thing:", cache.model.transformer.h[0].output[0].shape)
Layer 0 output shape: torch.Size([1, 1, 768])
Same thing: torch.Size([1, 1, 768])
5.3 Early Stopping#
If you only need early layers, stop execution early to save computation:
[25]:
with llm.trace("Hello") as tracer:
layer0 = llm.transformer.h[0].output[0].save()
tracer.stop() # Don't execute remaining layers
print("Early stop - only ran first layer")
print("Layer 0 shape:", layer0.shape)
Early stop - only ran first layer
Layer 0 shape: torch.Size([1, 1, 768])
You can also skip a module entirely with .skip(), passing the value it should output instead:
[26]:
with llm.trace("Hello"):
# Get layer 0 output
layer0_out = llm.transformer.h[0].output
# Skip layer 1 - use layer 0's output instead
llm.transformer.h[1].skip(layer0_out)
# Continue with rest of model
output = llm.lm_head.output.save()
print("Skipped layer 1!")
Skipped layer 1!
5.4 Scanning (Shape Inference)#
Use .scan() to get shapes without running the full model - useful for debugging:
[27]:
with llm.scan("Hello"):
hidden_dim = llm.transformer.h[0].output[0].shape[-1].save()
print("Hidden dimension:", hidden_dim)
Hidden dimension: 768
# 6. Model Editing
Create persistent model modifications that apply to all future traces:
[28]:
# First, get hidden states that predict "Paris"
with llm.trace("The Eiffel Tower is in the city of"):
paris_hidden = llm.transformer.h[-1].output[0][:, -1, :].save()
# Create an edited model that always uses these hidden states
with llm.edit() as llm_edited:
llm.transformer.h[-1].output[0][:, -1, :] = paris_hidden
# Original model: normal prediction
with llm.trace("Vatican is in the city of"):
original = llm.lm_head.output.argmax(dim=-1).save()
# Edited model: always predicts "Paris"!
with llm_edited.trace("Vatican is in the city of"):
modified = llm.lm_head.output.argmax(dim=-1).save()
print("Original:", llm.tokenizer.decode(original[0, -1]))
print("Edited: ", llm.tokenizer.decode(modified[0, -1]))
Original: Rome
Edited: Paris
Use llm.clear_edits() to remove all persistent edits.
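As a quick cleanup sketch, assuming clear_edits() can also be called on the edited copy to drop its edits:
# Drop the persistent edit from the edited copy (assumed to work like llm.clear_edits())
llm_edited.clear_edits()

with llm_edited.trace("Vatican is in the city of"):
    restored = llm_edited.lm_head.output.argmax(dim=-1).save()

# With the edit cleared, this should match the original model's prediction
print("After clear_edits:", llm.tokenizer.decode(restored[0, -1]))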
# 7. Remote Execution (NDIF)
nnsight can run interventions on large models hosted by the National Deep Inference Fabric (NDIF). Everything works the same - just add remote=True.
Setup#
Get your API key at https://login.ndif.us:
[ ]:
from nnsight import CONFIG
CONFIG.set_default_api_key("YOUR_API_KEY")
Check available models at https://nnsight.net/status/
Remote Tracing#
Load a large model and run remotely - your interventions execute on NDIF’s infrastructure:
[ ]:
import os
os.environ['HF_TOKEN'] = "YOUR_HUGGING_FACE_TOKEN"
llama = LanguageModel("meta-llama/Meta-Llama-3.1-8B")
# Just add remote=True - everything else is the same!
with llama.trace("The Eiffel Tower is in the city of", remote=True):
hidden_states = llama.model.layers[-1].output.save()
output = llama.output.save()
print("Hidden states shape:", hidden_states[0].shape)
Next Steps#
Congratulations! You’ve learned the core concepts of nnsight:
Wrapping models with NNsight and LanguageModel
Accessing activations with .output, .input, and .save()
Modifying activations with in-place and replacement patterns
The interleaving paradigm - your code runs alongside the model
Invokers and batching - efficient multi-input processing
Multi-token generation - .generate() and .iter for iterative operations
Gradients - with tensor.backward(): for gradient access
Advanced features - source tracing, caching, early stopping, scanning
Model editing - persistent modifications with .edit()
Remote execution - running on NDIF with remote=True
For more tutorials implementing classic interpretability techniques, visit nnsight.net/tutorials.
For deep technical details, see the NNsight.md design document.
Getting Involved!#
Both nnsight and NDIF are in active development. Join us:
Discord: discord.gg/6uFJmCSwW7
Forum: discuss.ndif.us
Twitter/X: @ndif_team
LinkedIn: National Deep Inference Fabric
We’d love to hear about your work using nnsight! 💟
[ ]:
print("Walkthrough complete! Visit nnsight.net for more tutorials.")