Info
Last Execution: 2026-02-17
| Package | Version |
|---|---|
| nnsight | 0.5.15 |
| Python | 3.12.3 |
| torch | 2.10.0+cu128 |
| transformers | 5.2.0 |
## Gradients

nnsight lets you access and intervene on gradients using the .backward() context. This is the foundation for attribution methods such as integrated gradients and attribution patching.
### Setup

```python
from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
```
### Accessing Gradients

To access gradients, first capture a tensor during the forward pass and enable gradients with .requires_grad_(True); then enter a with loss.backward(): context and access .grad on the tensor.
```python
with model.trace("The Eiffel Tower is in the city of"):
    hs = model.transformer.h[-1].output[0]
    hs.requires_grad_(True)

    logits = model.lm_head.output
    loss = logits.sum()

    with loss.backward():
        grad = hs.grad.save()

print(f"Gradient shape: {grad.shape}")
print(f"Gradient mean: {grad.abs().mean():.4f}")
```
```
Gradient shape: torch.Size([1, 10, 768])
Gradient mean: 472.7236
```
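A gradient captured this way is the raw ingredient for the attribution methods mentioned above. As a minimal, plain-PyTorch illustration (not the nnsight API), input-times-gradient attribution multiplies each input element by the gradient of a scalar score with respect to it:

```python
import torch

# Toy input-times-gradient attribution in plain PyTorch.
# `w` is an arbitrary illustrative weight vector, not anything from GPT-2.
x = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
w = torch.tensor([1.0, 2.0, 3.0])

score = (x * w).sum()  # stand-in for a logit of interest
score.backward()

attribution = x * x.grad  # input-times-gradient
print(attribution)        # tensor([ 0.5000, -2.0000,  6.0000])
```

The same pattern applies to the saved hidden-state gradients: elementwise products of activations and their gradients approximate each component's contribution to the loss.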
**Gradient access order is reversed**

Inside a .backward() context, you must access .grad in reverse execution order, following backpropagation. If you captured layer[0] and layer[-1] during the forward pass, access layer[-1].grad first inside the backward context.
### Gradients at Multiple Layers

You can capture gradients at multiple points. Access them in reverse layer order inside the backward context.
```python
with model.trace("The Eiffel Tower is in the city of"):
    hs_early = model.transformer.h[0].output[0]
    hs_early.requires_grad_(True)

    hs_late = model.transformer.h[-1].output[0]
    hs_late.requires_grad_(True)

    logits = model.lm_head.output

    with logits.sum().backward():
        # Reverse order: late layer first
        grad_late = hs_late.grad.save()
        grad_early = hs_early.grad.save()

print(f"Layer 0 gradient mean: {grad_early.abs().mean():.4f}")
print(f"Layer 11 gradient mean: {grad_late.abs().mean():.4f}")
```
```
Layer 0 gradient mean: 7570.3931
Layer 11 gradient mean: 472.7236
```
### Modifying Gradients

You can intervene on gradients just like forward-pass activations. Use in-place assignment to ablate or modify gradient values.
```python
with model.trace("The Eiffel Tower is in the city of"):
    hs = model.transformer.h[-1].output[0]
    hs.requires_grad_(True)

    logits = model.lm_head.output

    with logits.sum().backward():
        grad_before = hs.grad.clone().save()
        hs.grad[:] = 0
        grad_after = hs.grad.save()

print(f"Before ablation - mean: {grad_before.abs().mean():.4f}")
print(f"After ablation - mean: {grad_after.abs().mean():.4f}")
```
```
Before ablation - mean: 472.7236
After ablation - mean: 0.0000
```
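The effect of zeroing a gradient mid-backward can be seen in plain PyTorch with a backward hook (an analogous mechanism, not what nnsight does internally): replacing the gradient at an intermediate tensor cuts off everything upstream of it.

```python
import torch

# Plain-PyTorch analogue of hs.grad[:] = 0: a hook that replaces the
# gradient flowing through an intermediate tensor with zeros.
x = torch.tensor([1.0, 2.0], requires_grad=True)
h = x * 3

h.register_hook(lambda g: torch.zeros_like(g))  # ablate the grad at h

(h * h).sum().backward()
print(x.grad)  # tensor([0., 0.]) - upstream gradients are zeroed too
```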
### Multiple Backward Passes

Use retain_graph=True on the first backward call to keep the computation graph alive for additional passes.
```python
with model.trace("The Eiffel Tower is in the city of"):
    hs = model.transformer.h[-1].output[0]
    hs.requires_grad_(True)

    logits = model.lm_head.output

    with logits.sum().backward(retain_graph=True):
        grad1 = hs.grad.save()

    modified = logits * 2
    with modified.sum().backward():
        grad2 = hs.grad.save()

print(f"First backward grad mean: {grad1.abs().mean():.4f}")
print(f"Second backward grad mean: {grad2.abs().mean():.4f}")
```
```
First backward grad mean: 472.7236
Second backward grad mean: 945.4472
```
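The retain_graph mechanics come straight from PyTorch autograd. A standalone sketch (plain PyTorch, independent of nnsight) shows why the flag is needed: without it, the second backward through the same graph raises a RuntimeError. Note that plain PyTorch accumulates into .grad, hence the zero_() call between passes.

```python
import torch

x = torch.tensor([3.0], requires_grad=True)
y = x * x

y.backward(retain_graph=True)   # keep the graph alive for a second pass
first = x.grad.clone()          # dy/dx = 2x = 6

x.grad.zero_()                  # plain torch accumulates grads, so reset
(2 * y).backward()              # reuses the retained graph through y
second = x.grad.clone()         # d(2y)/dx = 4x = 12

print(first.item(), second.item())  # 6.0 12.0
```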
### Standalone Backward

The .backward() context works anywhere, even outside a model.trace(). Run a forward pass first to get the tensors, then trace the backward pass separately.
```python
import nnsight

with model.trace("The Eiffel Tower is in the city of"):
    hs = model.transformer.h[-1].output[0]
    hs.requires_grad_(True)
    hs = hs.save()
    logits = model.lm_head.output.save()

# Backward pass outside the trace
loss = logits.sum()
with loss.backward():
    grad = hs.grad.save()

print(f"Standalone backward grad shape: {grad.shape}")
```
```
Standalone backward grad shape: torch.Size([1, 10, 768])
```
**When to use standalone backward**

Standalone backward is useful when you want to inspect forward-pass results before deciding what loss to backpropagate through, or when you want to run multiple different backward passes on the same saved tensors.
**How standalone backward works**

When you import nnsight, it monkey-patches torch.Tensor.backward so that it can be used as a with context manager. This is what enables the .backward() tracing context to work on any tensor, anywhere in your code.
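The patching idea can be sketched in a few lines of plain PyTorch. This is a simplified illustration, not nnsight's actual implementation: the wrapper returns an object usable as a context manager, and here the backward pass simply runs on entry, whereas nnsight records it into its tracing graph.

```python
import torch

_original_backward = torch.Tensor.backward  # keep the real method

class _BackwardContext:
    """Minimal stand-in for a traced backward context (illustrative)."""

    def __init__(self, tensor):
        self.tensor = tensor

    def __enter__(self):
        _original_backward(self.tensor)  # run the backward pass on entry
        return self

    def __exit__(self, *exc):
        return False

# Monkey-patch: calling .backward() now returns a context manager.
torch.Tensor.backward = lambda self, *a, **kw: _BackwardContext(self)

x = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (x * x).sum()
with loss.backward():
    grad = x.grad

torch.Tensor.backward = _original_backward  # restore the real method
print(grad)  # tensor([2., 4.])
```

In this toy version a bare loss.backward() call would no longer run anything, which is one reason the real patch is considerably more involved.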