Attribution Patching#

📗 This tutorial is adapted from Neel Nanda’s blog post.

Activation patching is a method to determine how model components influence model computations (see our activation patching tutorial for more information). Although activation patching is a useful tool for circuit identification, it requires a separate forward pass through the model for each patched activation, making it time- and resource-intensive.

Attribution patching uses gradients to take a linear approximation to activation patching; every patch can be estimated simultaneously with just two forward passes and one backward pass, making it much more scalable to large models.

You can find a colab version of the tutorial here or Neel’s version here.

Read more about an application of Attribution Patching in Attribution Patching Outperforms Automated Circuit Discovery. 📙

Setup#

If you are using Colab or haven’t yet installed NNsight, install the package:

!pip install -U nnsight
[1]:
try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    !pip install -U nnsight

Import libraries

[2]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab" if is_colab else "plotly_mimetype+notebook_connected+notebook"


from nnsight import LanguageModel

[3]:
import nnsight
print(nnsight.__version__)
0.4.1

1️⃣ Indirect Object Identification (IOI) Patching#

Indirect object identification (IOI) is the ability to infer the correct indirect object in a sentence, allowing one to complete sentences like “John and Mary went to the shops, John gave a bag to” with the correct answer “ Mary”. Understanding how language models like GPT-2 perform linguistic tasks like IOI helps us gain insights into their internal mechanisms and decision-making processes.

Here, we apply the IOI task to explore how GPT-2 small is performing IOI with attribution patching.

📚 Note: For more detail on the IOI task, check out the ARENA walkthrough.

[4]:
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
clear_output()
print(model)
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): Generator(
    (streamer): Streamer()
  )
)

Looking at the model architecture, we can see that GPT-2 small has 12 layers, each a GPT2Block whose attention module has 12 attention heads. We will use attribution patching to approximate the contribution of each layer and each attention head to the IOI task.
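
As a quick sanity check, we can confirm that GPT-2 small actually performs the IOI task before we start patching. This extra snippet is not part of the original tutorial; it just reuses the trace=False shortcut shown below to grab the model’s next-token prediction for one example prompt.

# Hypothetical sanity check: does GPT-2 small complete an IOI prompt
# with the indirect object?
test_prompt = "When John and Mary went to the shops, John gave the bag to"
test_logits = model.trace(test_prompt, trace=False).logits
top_token_id = test_logits[0, -1].argmax(dim=-1)
print(repr(model.tokenizer.decode(top_token_id.item())))  # we expect " Mary"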

We next define 8 IOI prompts, pairing each prompt with one corrupted variant (i.e., the repeated name is swapped, which flips the correct indirect object).

[5]:
prompts = [
    "When John and Mary went to the shops, John gave the bag to",
    "When John and Mary went to the shops, Mary gave the bag to",
    "When Tom and James went to the park, James gave the ball to",
    "When Tom and James went to the park, Tom gave the ball to",
    "When Dan and Sid went to the shops, Sid gave an apple to",
    "When Dan and Sid went to the shops, Dan gave an apple to",
    "After Martin and Amy went to the park, Amy gave a drink to",
    "After Martin and Amy went to the park, Martin gave a drink to",
]

# Answers are each formatted as (correct, incorrect):
answers = [
    (" Mary", " John"),
    (" John", " Mary"),
    (" Tom", " James"),
    (" James", " Tom"),
    (" Dan", " Sid"),
    (" Sid", " Dan"),
    (" Martin", " Amy"),
    (" Amy", " Martin"),
]

# Tokenize clean and corrupted inputs:
clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]
# The associated corrupted input is the prompt after the current clean prompt
# for even indices, or the prompt prior to the current clean prompt for odd indices
corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

# Tokenize answers for each prompt:
answer_token_indices = torch.tensor(
    [
        [model.tokenizer(answers[i][j])["input_ids"][0] for j in range(2)]
        for i in range(len(answers))
    ]
)

Next, we create a function to calculate the mean logit difference for the correct vs incorrect answer tokens.

[7]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

We then calculate the logit difference for both the clean and the corrupted baselines.

[8]:
clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")
Clean logit diff: 2.8138
Corrupted logit diff: -2.8138

Now let’s define an ioi_metric function to evaluate patched IOI performance, normalized so that the clean baseline is 1 and the corrupted baseline is 0.

[9]:
def ioi_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (
        CLEAN_BASELINE - CORRUPTED_BASELINE
    )

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")
Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000

2️⃣ Attribution Patching Over Components#

Attribution patching is a technique that uses gradients to take a linear approximation to activation patching. The key assumption is that the corrupted run is a locally linear function of its activations.

We thus take the gradient of the patch metric (ioi_metric) with respect to the corrupted activations, treating a patch as changing corrupted_x to corrupted_x + (clean_x - corrupted_x). The change in the patch metric is then approximated by the first-order term (corrupted_grad_x * (clean_x - corrupted_x)).sum(). All we need to do is run a backward pass on the corrupted prompt with respect to the patch metric and cache the gradients with respect to each activation.
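
To build intuition for this first-order approximation, here is a minimal toy sketch (not part of the tutorial’s pipeline; toy_metric, clean_x, and corrupted_x are made-up stand-ins) comparing the exact effect of a patch with its gradient-based estimate:

import torch

# A made-up scalar "metric" of an activation vector
def toy_metric(act):
    return (act ** 2).sum()

corrupted_x = torch.randn(4, requires_grad=True)
clean_x = corrupted_x.detach() + 0.01 * torch.randn(4)  # a small "patch"

# Exact change in the metric from patching corrupted_x -> clean_x
exact_change = toy_metric(clean_x) - toy_metric(corrupted_x)

# Attribution patching estimate: gradient at the corrupted point
# dotted with the activation difference
toy_metric(corrupted_x).backward()
approx_change = (corrupted_x.grad * (clean_x - corrupted_x.detach())).sum()

print(exact_change.item(), approx_change.item())  # nearly equal for small patches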

Let’s see how this breaks down in NNsight!

A note on c_proj: Most HuggingFace models don’t have nice individual attention head representations to hook. Instead, we have the linear layer c_proj, which implicitly combines the “projection per attention head” and the “sum over attention heads” operations. See this snippet from ARENA for more information.

TL;DR: We will use the input to c_proj for causal interventions on a particular attention head.
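
For intuition about what that input looks like: the input to c_proj has shape [batch, pos, n_heads * d_head], i.e., the concatenated per-head outputs, so it can be split back into per-head slices. Below is a minimal sketch with GPT-2 small’s dimensions (the tensor and sequence length here are made up for illustration):

import torch
import einops

attn_in = torch.randn(8, 15, 768)  # [batch, pos, n_heads * d_head] for GPT-2 small
per_head = einops.rearrange(
    attn_in, "batch pos (head dim) -> batch pos head dim", head=12, dim=64
)
print(per_head.shape)  # torch.Size([8, 15, 12, 64])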

[10]:
clean_out = []
corrupted_out = []
corrupted_grads = []

with model.trace() as tracer:
    # Using nnsight's tracer.invoke context, we can batch the clean and the
    # corrupted runs into the same tracing context, allowing us to access
    # information generated within each of these runs within one forward pass

    with tracer.invoke(clean_tokens) as invoker_clean:
        # Gather each layer's attention
        for layer in model.transformer.h:
            # Get clean attention output for this layer
            # across all attention heads
            attn_out = layer.attn.c_proj.input
            clean_out.append(attn_out.save())

    with tracer.invoke(corrupted_tokens) as invoker_corrupted:
        # Gather each layer's attention and gradients
        for layer in model.transformer.h:
            # Get corrupted attention output for this layer
            # across all attention heads
            attn_out = layer.attn.c_proj.input
            corrupted_out.append(attn_out.save())
            # save corrupted gradients for attribution patching
            corrupted_grads.append(attn_out.grad.save())

        # Let's get the logits for the model's output
        # for the corrupted run
        logits = model.lm_head.output.save()

        # Our metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(logits.cpu())

        # We also need to run a backwards pass to
        # update gradient values
        value.backward()

Next, for a given activation we compute (corrupted_grad_act * (clean_act - corrupted_act)).sum(). We use einops.reduce to sum activations over the appropriate dimensions. In this case, we want to estimate the effect of each attention head, so we select the final token position and sum over the batch and per-head dimensions, keeping the head dimension.

[11]:
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value[:,-1,:] * (clean.value[:,-1,:] - corrupted.value[:,-1,:]),
        "batch (head dim) -> head",
        "sum",
        head = 12,
        dim = 64,
    )

    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )
[12]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Attention Heads",
    labels={"x": "Head", "y": "Layer","color":"Norm. Logit Diff"},

)

fig.show()

Here, we see that the early layer attention heads may not be important for IOI.

3️⃣ Attribution Patching Over Position#

One benefit of attribution patching is efficiency. Activation patching requires a separate forward pass per activation patched while every attribution patch can be done simultaneously in two forward passes and one backward pass. Attribution patching makes patching much more scalable to large models and can serve as a useful heuristic to find the interesting activations to patch.

In practice, while this approximation is decent when patching in “small” activations like head outputs, it performs significantly worse when patching in “big” activations like those found in the residual stream.

Using the same outputs we cached above, we can get the contribution at each token position by keeping the position dimension and summing over the batch and hidden dimensions. Although this is messy, it’s a quick approximation of the attention mechanism’s contribution at each token position.

Note: in our specific case here, patching across positions does NOT reflect the entire residual stream, just the post-attention output (i.e., excludes MLPs).

[13]:
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )

    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )
[14]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Token Position",
    labels={"x": "Token Position", "y": "Layer","color":"Norm. Logit Diff"},

)

fig.show()

This result looks similar to our previous result using activation patching but is much less precise, as expected!

Remote Attribution Patching#

Now that we know how to run an attribution patching experiment in NNsight, let’s go over how you can use NDIF’s publicly hosted models to further scale your research!

We’re going to run the same experiment, but now using Llama 3.1 8B. Completing this section of the tutorial will require you to configure NNsight for remote execution if you haven’t already.

Remote Setup#

[15]:
from nnsight import CONFIG

if is_colab:
    # include your HuggingFace Token and NNsight API key on Colab secrets
    from google.colab import userdata
    NDIF_API = userdata.get('NDIF_API')
    HF_TOKEN = userdata.get('HF_TOKEN')

    CONFIG.set_default_api_key(NDIF_API)
    !huggingface-cli login --token $HF_TOKEN

clear_output()
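
If you are running locally rather than on Colab, you can set your NDIF API key directly; here is a minimal sketch with a placeholder key (and assuming you have already authenticated with Hugging Face for gated models):

from nnsight import CONFIG

# Placeholder value; replace with your actual NDIF API key
CONFIG.set_default_api_key("YOUR_NDIF_API_KEY")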
[28]:
import torch
import nnsight
from nnsight import LanguageModel

Next, let’s load the Llama 3.1 8B model, once again using NNsight’s LanguageModel wrapper. Because we’ll be running the model on NDIF’s remote servers, no need to specify a device_map!

[29]:
# Load model
llm = LanguageModel("meta-llama/Meta-Llama-3.1-8B")
print(llm)
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
  (generator): Generator(
    (streamer): Streamer()
  )
)

IOI Task Setup#

We’ve already defined some prompts in the above tutorial, but we’ll have to re-tokenize them for Llama 3.1 8B, since it uses a different tokenizer than GPT-2.

[41]:
# Tokenize clean and corrupted inputs:
clean_tokens = llm.tokenizer(prompts, return_tensors="pt")["input_ids"]
# The associated corrupted input is the prompt after the current clean prompt
# for even indices, or the prompt prior to the current clean prompt for odd indices
corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

# Tokenize answers for each prompt.
# Note: the Llama tokenizer prepends a BOS token, so the answer token is at
# index 1 here (for GPT-2 it was at index 0).
answer_token_indices = torch.tensor(
    [
        [llm.tokenizer(answers[i][j])["input_ids"][1] for j in range(2)]
        for i in range(len(answers))
    ]
)

Next, we’ll establish clean & corrupted baselines for our IOI metric, using the model’s clean and corrupted logits and the get_logit_diff function defined earlier.

[33]:
clean_logits = llm.trace(clean_tokens, trace=False, remote=True)
corrupted_logits = llm.trace(corrupted_tokens, trace=False, remote=True)

clean_logits = clean_logits['logits']
corrupted_logits = corrupted_logits['logits']

CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"\n\nClean logit diff: {CLEAN_BASELINE:.4f}")

CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")
2025-02-06 14:34:29,807 a81d37fa-2304-4023-ad29-81c9f54063ab - RECEIVED: Your job has been received and is waiting approval.
2025-02-06 14:34:30,130 a81d37fa-2304-4023-ad29-81c9f54063ab - APPROVED: Your job was approved and is waiting to be run.
2025-02-06 14:34:30,344 a81d37fa-2304-4023-ad29-81c9f54063ab - RUNNING: Your job has started running.
2025-02-06 14:34:33,597 a81d37fa-2304-4023-ad29-81c9f54063ab - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 46.5M/46.5M [00:07<00:00, 6.05MB/s]
2025-02-06 14:34:42,033 357e3fb9-23bf-47ae-9cea-12f11b62de31 - RECEIVED: Your job has been received and is waiting approval.
2025-02-06 14:34:42,326 357e3fb9-23bf-47ae-9cea-12f11b62de31 - APPROVED: Your job was approved and is waiting to be run.
2025-02-06 14:34:42,541 357e3fb9-23bf-47ae-9cea-12f11b62de31 - RUNNING: Your job has started running.
2025-02-06 14:34:46,309 357e3fb9-23bf-47ae-9cea-12f11b62de31 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 46.5M/46.5M [00:05<00:00, 9.16MB/s]


Clean logit diff: 5.6875
Corrupted logit diff: -5.6875

We’ve also already defined our ioi_metric function. Let’s plug in our logit values.

[34]:
print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")
Clean Baseline is 1: 0.5273
Corrupted Baseline is 0: 0.4727

Remote Attribution Patching#

Great! We have some baselines. Now, let’s run the attribution patching pipeline on Llama 3.1 8B. We can’t copy the code exactly, because Llama has a different module structure than GPT-2, but we’re following the same steps: a clean run and a corrupted run as separate invoke contexts within one tracing context.

[35]:
clean_out = []
corrupted_out = []
corrupted_grads = []

with llm.trace(remote=True) as tracer:
    # Using nnsight's tracer.invoke context, we can batch the clean and the
    # corrupted runs into the same tracing context, allowing us to access
    # information generated within each of these runs within one forward pass

    with tracer.invoke(clean_tokens) as invoker_clean:
        # We need to set requires_grad to True on the input for remote execution
        llm.model.layers[0].self_attn.o_proj.input.requires_grad = True
        # Gather each layer's attention
        for layer in llm.model.layers:
            # Get clean attention output for this layer
            # across all attention heads
            attn_out = layer.self_attn.o_proj.input
            clean_out.append(attn_out.save())

    with tracer.invoke(corrupted_tokens) as invoker_corrupted:
        # Gather each layer's attention and gradients
        for layer in llm.model.layers:
            # Get corrupted attention output for this layer
            # across all attention heads
            attn_out = layer.self_attn.o_proj.input
            corrupted_out.append(attn_out.save())
            # save corrupted gradients for attribution patching
            corrupted_grads.append(attn_out.grad.save())

        # Let's get the logits for the model's output
        # for the corrupted run
        logits = llm.lm_head.output.save()

        # Our IOI metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(logits.cpu())

        # We also need to run a backwards pass to
        # update gradient values
        value.backward()
2025-02-06 14:35:04,105 7373f906-bedf-4c24-b53c-b559b94c3211 - RECEIVED: Your job has been received and is waiting approval.
2025-02-06 14:35:04,468 7373f906-bedf-4c24-b53c-b559b94c3211 - APPROVED: Your job was approved and is waiting to be run.
2025-02-06 14:35:04,706 7373f906-bedf-4c24-b53c-b559b94c3211 - RUNNING: Your job has started running.
2025-02-06 14:35:25,407 7373f906-bedf-4c24-b53c-b559b94c3211 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 250M/250M [00:46<00:00, 5.38MB/s]

Awesome! Let’s take a look at attention head contributions across layers.

[37]:
# format data for plotting across attention heads
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value[:,-1,:] * (clean.value[:,-1,:] - corrupted.value[:,-1,:]),
        "batch (head dim) -> head",
        "sum",
        head = 32,
        dim = 128,
    )

    patching_results.append(
        (residual_attr.float()).detach().numpy()
    )
[38]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Attention Heads",
    labels={"x": "Head", "y": "Layer","color":"Norm. Logit Diff"},

)

fig.show()

Next, let’s check out the contribution of the attention output at each token position across layers.

[39]:
# format data for plotting across input tokens
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )

    patching_results.append(
        (residual_attr.detach().cpu().float()).numpy()
    )
[40]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Token Position",
    labels={"x": "Token Position", "y": "Layer","color":"Norm. Logit Diff"},

)

fig.show()

Great! We’ve now successfully performed an attribution patching experiment on both GPT-2 and Llama 3.1 8B.