Attribution Patching#
📗 This tutorial is adapted from Neel Nanda’s blog post.
Activation patching is a method to determine how model components influence model computations (see our activation patching tutorial for more information). Although activation patching is a useful tool for circuit identification, it requires a separate forward pass through the model for each patched activation, making it time- and resource-intensive.
Attribution patching uses gradients to take a linear approximation to activation patching: every patch can be estimated simultaneously with just two forward passes and one backward pass, making the method much more scalable to large models.
You can find a Colab version of this tutorial here, or Neel’s version here.
Read more about an application of Attribution Patching in Attribution Patching Outperforms Automated Circuit Discovery. 📙
Setup#
If you are using Colab or haven’t yet installed NNsight, install the package:
!pip install -U nnsight
Import libraries
[12]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "plotly_mimetype+notebook_connected+colab+notebook"
from nnsight import LanguageModel
1️⃣ Indirect Object Identification (IOI) Patching Setup#
Indirect object identification (IOI) is the ability to infer the correct indirect object in a sentence, allowing one to complete sentences like “John and Mary went to the shops, John gave a bag to” with the correct answer “ Mary”. Understanding how language models like GPT-2 perform linguistic tasks like IOI helps us gain insights into their internal mechanisms and decision-making processes.
Here, we apply the IOI task to explore how GPT-2 small is performing IOI with attribution patching.
📚 Note: For more detail on the IOI task, check out the ARENA walkthrough.
[13]:
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
clear_output()
print(model)
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): WrapperModule()
)
Looking at the model architecture, we can see that GPT-2 small has 12 layers (GPT2Blocks), each containing a 12-head attention module and an MLP. We will use attribution patching to approximate the contribution of each layer and each attention head to the IOI task.
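If you prefer to read these dimensions programmatically rather than from the printout, the underlying HuggingFace config should be reachable through the wrapped model (a minimal sketch, assuming model.config forwards to GPT-2’s config):

# Expected for GPT-2 small: 12 layers, 12 heads, hidden size 768
print(model.config.n_layer, model.config.n_head, model.config.n_embd)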
We next define 8 IOI prompts. Each prompt is paired with a corrupted variation in which the repeated name is swapped, flipping the correct indirect object.
[14]:
prompts = [
    "When John and Mary went to the shops, John gave the bag to",
    "When John and Mary went to the shops, Mary gave the bag to",
    "When Tom and James went to the park, James gave the ball to",
    "When Tom and James went to the park, Tom gave the ball to",
    "When Dan and Sid went to the shops, Sid gave an apple to",
    "When Dan and Sid went to the shops, Dan gave an apple to",
    "After Martin and Amy went to the park, Amy gave a drink to",
    "After Martin and Amy went to the park, Martin gave a drink to",
]

# Answers are each formatted as (correct, incorrect):
answers = [
    (" Mary", " John"),
    (" John", " Mary"),
    (" Tom", " James"),
    (" James", " Tom"),
    (" Dan", " Sid"),
    (" Sid", " Dan"),
    (" Martin", " Amy"),
    (" Amy", " Martin"),
]

# Tokenize clean and corrupted inputs:
clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]

# The associated corrupted input is the prompt after the current clean prompt
# for even indices, or the prompt prior to the current clean prompt for odd indices
corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

# Tokenize answers for each prompt:
answer_token_indices = torch.tensor(
    [
        [model.tokenizer(answers[i][j])["input_ids"][0] for j in range(2)]
        for i in range(len(answers))
    ]
)
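As a quick sanity check (a minimal sketch), you can verify that each answer string, with its leading space, maps to a single GPT-2 token, which is why taking index 0 of the tokenized answer above recovers the full name:

# Each answer (with its leading space) should tokenize to exactly one id
for correct, incorrect in answers:
    for ans in (correct, incorrect):
        print(f"{ans!r} -> {model.tokenizer(ans)['input_ids']}")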
Next, we create a function to calculate the mean logit difference for the correct vs incorrect answer tokens.
[15]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()
We then calculate the logit difference for both the clean and the corrupted baselines.
[16]:
clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()
CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")
CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")
Clean logit diff: 2.8138
Corrupted logit diff: -2.8138
Now let’s define an ioi_metric function to evaluate patched IOI performance, normalized to our clean and corrupted baselines.
[17]:
def ioi_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (
        CLEAN_BASELINE - CORRUPTED_BASELINE
    )
print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")
Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000
2️⃣ Attribution Patching Over Components#
Attribution patching is a technique that uses gradients to take a linear approximation to activation patching. The key assumption is that the corrupted run is a locally linear function of its activations.
We take the gradient of the patch metric (ioi_metric) with respect to the corrupted activations, treating a patch as moving an activation from corrupted_x to corrupted_x + (clean_x - corrupted_x). The estimated change in the patch metric is then (corrupted_grad_x * (clean_x - corrupted_x)).sum(). All we need to do is take a backward pass on the corrupted prompt with respect to the patch metric and cache the gradients with respect to each activation.
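In other words, attribution patching is a first-order Taylor expansion of the patch metric around the corrupted activations:

metric(corrupted_x + (clean_x - corrupted_x)) ≈ metric(corrupted_x) + (corrupted_grad_x * (clean_x - corrupted_x)).sum()

The second term is exactly the attribution score we compute below for each candidate patch.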
Let’s see how this breaks down in NNsight!
A note on c_proj: Most HuggingFace models don’t expose individual attention heads as separate modules to hook. Instead, the linear layer c_proj implicitly combines the “projection per attention head” and the “sum over attention heads” operations. See this snippet from ARENA for more information.
TL;DR: We will use the input to c_proj for causal interventions on particular attention heads.
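Concretely, the input to c_proj stacks all heads along its last dimension, giving shape [batch, pos, n_heads * d_head] (for GPT-2 small: 12 heads of dimension 64, i.e. 768). A minimal sketch, using a stand-in tensor, of how a per-head view can be recovered with einops:

# Stand-in for the tensor feeding into c_proj in GPT-2 small
z = torch.randn(1, 15, 768)
per_head = einops.rearrange(z, "batch pos (head dim) -> batch pos head dim", head=12)
print(per_head.shape)  # torch.Size([1, 15, 12, 64])

This is the same (head dim) split that einops.reduce performs in the attribution code below.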
[18]:
clean_out = []
corrupted_out = []
corrupted_grads = []

with model.trace() as tracer:
    # Using nnsight's tracer.invoke context, we can batch the clean and the
    # corrupted runs into the same tracing context, allowing us to access
    # information generated within each of these runs within one forward pass
    with tracer.invoke(clean_tokens) as invoker_clean:
        # Gather each layer's attention
        for layer in model.transformer.h:
            # Get clean attention output for this layer
            # across all attention heads
            attn_out = layer.attn.c_proj.input
            clean_out.append(attn_out.save())

    with tracer.invoke(corrupted_tokens) as invoker_corrupted:
        # Gather each layer's attention and gradients
        for layer in model.transformer.h:
            # Get corrupted attention output for this layer
            # across all attention heads
            attn_out = layer.attn.c_proj.input
            corrupted_out.append(attn_out.save())

            # save corrupted gradients for attribution patching
            corrupted_grads.append(attn_out.grad.save())

        # Let's get the logits for the model's output
        # for the corrupted run
        logits = model.lm_head.output.save()

        # Our metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(logits.cpu())

        # We also need to run a backwards pass to
        # update gradient values
        value.backward()
Next, for a given activation we compute (corrupted_grad_act * (clean_act - corrupted_act)).sum(). We use einops.reduce to sum this product over the appropriate dimensions. Here we want a score per attention head, so we take the final token position, split the hidden dimension into (head, dim), and sum over the batch and dim axes, leaving one value per head.
[19]:
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):
    residual_attr = einops.reduce(
        corrupted_grad.value[:, -1, :] * (clean.value[:, -1, :] - corrupted.value[:, -1, :]),
        "batch (head dim) -> head",
        "sum",
        head=12,
        dim=64,
    )

    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )
[20]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Attention Heads",
    labels={"x": "Head", "y": "Layer", "color": "Norm. Logit Diff"},
)
fig.show()
Here, we see that the early layer attention heads may not be important for IOI.
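To read specific heads off the heatmap, one option is to rank (layer, head) pairs by the magnitude of their attribution scores. A minimal sketch (numpy is not imported above, so we import it here):

import numpy as np

scores = np.stack(patching_results)  # shape [n_layers, n_heads]
top = np.argsort(np.abs(scores), axis=None)[::-1][:10]
for idx in top:
    layer, head = divmod(int(idx), scores.shape[1])
    print(f"Layer {layer}, Head {head}: {scores[layer, head]:+.3f}")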
3️⃣ Attribution Patching Over Position#
One benefit of attribution patching is efficiency. Activation patching requires a separate forward pass per activation patched while every attribution patch can be done simultaneously in two forward passes and one backward pass. Attribution patching makes patching much more scalable to large models and can serve as a useful heuristic to find the interesting activations to patch.
In practice, while this approximation is decent when patching in “small” activations like head outputs, it performs significantly worse when patching in “big” activations like those in the residual stream.
Using the same outputs we cached above, we can get the contribution at each token position by keeping the position dimension and summing over everything else (the batch and hidden dimensions). Although this is a rough estimate, it’s a quick approximation of the attention mechanism’s contribution at each token position.
Note: in our specific case here, patching across positions does NOT reflect the entire residual stream, just the attention output feeding into c_proj (i.e., it excludes the MLPs).
[21]:
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):
    residual_attr = einops.reduce(
        corrupted_grad.value * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )

    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )
[22]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Token Position",
    labels={"x": "Token Position", "y": "Layer", "color": "Norm. Logit Diff"},
)
fig.show()
This result looks similar to our previous result using activation patching but is much less precise, as expected!
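To map the token-position axis back to actual words, note that all eight prompts tokenize to the same length (otherwise the batched tokenization above would have failed), so the positions line up across prompts. A minimal sketch that labels each position using the first clean prompt:

# Label each position with the corresponding token of the first clean prompt
token_labels = [model.tokenizer.decode(t) for t in clean_tokens[0].tolist()]
print(list(enumerate(token_labels)))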