Attribution Patching

📗 This tutorial is adapted from Neel Nanda’s blog post, here.

📘 Attribution patching was used in Attribution Patching Outperforms Automated Circuit Discovery.

📙 You can find a Colab version of the tutorial here, or Neel’s version here.

📚 Note: You can skip over section one, which sets up the IOI scenario. For more detail on this task, check out the ARENA walkthrough.

Setup (Ignore)

[12]:
import einops
import torch
import plotly.express as px

from nnsight import LanguageModel

1️⃣ IOI Patching Setup

Here, we set up the indirect object identification (IOI) task to demonstrate the use of attribution patching. The IOI circuit is responsible for the model’s ability to complete sentences like "John and Mary went to the shops, John gave a bag to" with the correct token " Mary".

[13]:
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
[14]:
prompts = [
    "When John and Mary went to the shops, John gave the bag to",
    "When John and Mary went to the shops, Mary gave the bag to",
    "When Tom and James went to the park, James gave the ball to",
    "When Tom and James went to the park, Tom gave the ball to",
    "When Dan and Sid went to the shops, Sid gave an apple to",
    "When Dan and Sid went to the shops, Dan gave an apple to",
    "After Martin and Amy went to the park, Amy gave a drink to",
    "After Martin and Amy went to the park, Martin gave a drink to",
]

answers = [
    (" Mary", " John"),
    (" John", " Mary"),
    (" Tom", " James"),
    (" James", " Tom"),
    (" Dan", " Sid"),
    (" Sid", " Dan"),
    (" Martin", " Amy"),
    (" Amy", " Martin"),
]

clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]

corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

answer_token_indices = torch.tensor(
    [
        [model.tokenizer(answers[i][j])["input_ids"][0] for j in range(2)]
        for i in range(len(answers))
    ]
)
[15]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Clean logit diff: 2.8138
Corrupted logit diff: -2.8138

Our metric will be a linear function of the logit difference, where the logit difference is the difference in logits between the indirect object’s name and the subject’s name (e.g., logit(Mary) - logit(John)). We rescale it so that the corrupted baseline maps to 0 and the clean baseline maps to 1.

[16]:
def ioi_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (
        CLEAN_BASELINE - CORRUPTED_BASELINE
    )

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")
Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000
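
To make the gather indexing and the rescaling concrete, here is a minimal sketch on a toy tensor. The vocabulary size, logits, token ids, and the names toy_logit_diff and toy_metric are all made up for illustration; only the indexing pattern mirrors get_logit_diff and ioi_metric.

```python
import torch

# Toy logits: batch of 2 prompts, 1 position, vocabulary of 4 tokens.
toy_logits = torch.tensor(
    [
        [[0.0, 3.0, 1.0, 0.0]],  # prompt 0
        [[2.0, 0.0, 0.0, 5.0]],  # prompt 1
    ]
)
# Each row is (correct token id, incorrect token id); values are made up.
toy_answers = torch.tensor([[1, 2], [3, 0]])


def toy_logit_diff(logits, answer_ids):
    last = logits[:, -1, :]  # keep the final position only
    correct = last.gather(1, answer_ids[:, 0].unsqueeze(1))
    incorrect = last.gather(1, answer_ids[:, 1].unsqueeze(1))
    return (correct - incorrect).mean()


def toy_metric(logit_diff, clean_baseline, corrupted_baseline):
    # Rescale so the corrupted baseline maps to 0 and the clean baseline to 1.
    return (logit_diff - corrupted_baseline) / (clean_baseline - corrupted_baseline)


toy_diff = toy_logit_diff(toy_logits, toy_answers)
print(toy_diff.item())  # (3 - 1) and (5 - 2), averaged: 2.5
print(toy_metric(0.0, clean_baseline=2.5, corrupted_baseline=-2.5))  # 0.5
```

A logit difference of zero sits halfway between the two baselines, so it maps to 0.5 under this rescaling.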

2️⃣ Attribution Patching Over Components

Attribution patching is a technique that uses gradients to take a linear approximation to activation patching.

The key assumption is that the patch metric is a locally linear function of the activations on the corrupted run. We treat patching an activation as the update corrupted_x -> corrupted_x + (clean_x - corrupted_x), and take the gradient of the patch metric with respect to that activation on the corrupted run. The estimated change in the patch metric is then (corrupted_grad_x * (clean_x - corrupted_x)).sum().
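
As a sanity check on this idea, the sketch below compares the gradient-based estimate with real activation patching on a tiny hand-built network. Everything here (the weights w1 and w2, the functions hidden and metric_from_hidden, and the input values) is a hypothetical stand-in for the model and the patch metric, not part of the tutorial’s setup.

```python
import torch

# Hypothetical two-layer toy network with fixed, deterministic weights.
w1 = torch.linspace(-1, 1, 32).reshape(4, 8)
w2 = torch.linspace(-1, 1, 8).reshape(8, 1)


def hidden(x):
    return torch.tanh(x @ w1)  # the "activation" we patch


def metric_from_hidden(h):
    return torch.tanh(h @ w2).sum()  # the "patch metric"


clean_x = torch.tensor([[0.5, -0.3, 0.8, 0.1]])
corrupted_x = clean_x + torch.tensor([[0.02, -0.01, 0.015, -0.02]])

clean_h = hidden(clean_x).detach()
corrupted_h = hidden(corrupted_x).detach().requires_grad_(True)

# Backward pass on the corrupted run gives d(metric)/d(activation).
corrupted_metric = metric_from_hidden(corrupted_h)
corrupted_metric.backward()

# Attribution patching: first-order estimate of the metric change.
attribution = (corrupted_h.grad * (clean_h - corrupted_h.detach())).sum()

# Activation patching: actually swap in the clean activation and re-run.
with torch.no_grad():
    true_change = metric_from_hidden(clean_h) - corrupted_metric.detach()

print(f"attribution estimate: {attribution.item():.6f}")
print(f"true patching effect: {true_change.item():.6f}")
```

The two numbers agree closely here because the clean and corrupted activations are near each other. The approximation is expected to degrade as the gap between them grows or the metric becomes more strongly nonlinear in the activation.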

All we need to do is run a backward pass on the corrupted prompt with respect to the patching metric, and cache the gradients with respect to the activations. Let’s see how this breaks down in NNsight!

A note on c_proj


Most HuggingFace models don’t expose individual attention heads as separate modules to hook. Instead, GPT-2 has the linear layer c_proj, which implicitly combines the “projection per attention head” and the “sum over attention heads” operations. Its input is the concatenation of the per-head outputs, so we hook c_proj.input and split it by head afterwards. See this snippet from ARENA for more information.
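
The claim that one linear layer over the concatenated heads equals “project each head, then sum” can be checked numerically. This is a minimal sketch with random stand-in tensors in GPT-2 small’s shapes; z and w_o are hypothetical names, the bias is omitted, and no model weights are loaded.

```python
import torch

torch.manual_seed(0)

n_heads, d_head, d_model = 12, 64, 768  # GPT-2 small shapes

# Stand-in for the concatenated per-head outputs at one token position,
# i.e. the input to c_proj. Random values, for illustration only.
z = torch.randn(n_heads * d_head)
# Stand-in for the c_proj weight, laid out as (in_features, out_features).
w_o = torch.randn(n_heads * d_head, d_model) / (n_heads * d_head) ** 0.5

# What c_proj does: a single matmul over the concatenated heads.
combined = z @ w_o

# The same result as "project each head, then sum over heads".
z_heads = z.reshape(n_heads, d_head)
w_heads = w_o.reshape(n_heads, d_head, d_model)
per_head = torch.einsum("hd,hdm->hm", z_heads, w_heads)  # (head, d_model)
summed = per_head.sum(dim=0)

print(torch.allclose(combined, summed, atol=1e-4))
```

This is why slicing the c_proj input into 12 contiguous chunks of width 64 gives a per-head view, even though the module itself never materializes separate heads.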


[17]:
clean_out = []
corrupted_out = []
corrupted_grads = []

with model.trace() as tracer:

    with tracer.invoke(clean_tokens) as invoker_clean:

        for layer in model.transformer.h:
            attn_out = layer.attn.c_proj.input
            clean_out.append(attn_out.save())

    with tracer.invoke(corrupted_tokens) as invoker_corrupted:

        for layer in model.transformer.h:
            attn_out = layer.attn.c_proj.input
            corrupted_out.append(attn_out.save())
            corrupted_grads.append(attn_out.grad.save())

        logits = model.lm_head.output.save()
        # Our metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(logits.cpu())
        value.backward()

Then, for a given activation we compute (corrupted_grad_act * (clean_act - corrupted_act)).sum(). Below, we use einops.reduce to split the hidden dimension into heads and sum over the remaining dimensions. In this case, we want to estimate the effect of specific components, so we take the final token position and keep one value per head, summing over the batch and the per-head dimension rather than keeping a per-position breakdown.

[18]:
patching_results = []

for corrupted_grad, corrupted, clean in zip(corrupted_grads, corrupted_out, clean_out):

    # Attribution at the final token position, split into 12 heads of width 64.
    # Summing over batch and the per-head dimension leaves one value per head.
    residual_attr = einops.reduce(
        corrupted_grad.value[:,-1,:] * (clean.value[:,-1,:] - corrupted.value[:,-1,:]),
        "batch (head dim) -> head",
        "sum",
        head = 12,
        dim = 64,
    )

    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )
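
The reduction pattern "batch (head dim) -> head" used above can be checked on a toy tensor. The sketch below writes the same reduction with a plain torch reshape, on the assumption that head is the outer axis of the grouped dimension, as in the einops pattern. The sizes and values are made up.

```python
import torch

batch, n_heads, d_head = 2, 3, 4  # toy sizes, not GPT-2's

# Toy "grad * (clean - corrupted)" product at the final position,
# shape (batch, n_heads * d_head). Head h fills its chunk with the value h + 1.
product = torch.arange(1, n_heads + 1, dtype=torch.float32)
product = product.repeat_interleave(d_head).expand(batch, -1)

# Split the last axis into (head, dim), then sum over batch and dim.
per_head = product.reshape(batch, n_heads, d_head).sum(dim=(0, 2))

print(per_head.tolist())  # [8.0, 16.0, 24.0]
```

Each head’s score is its chunk value times 4 entries times 2 batch rows, which confirms that contiguous chunks of the hidden dimension are attributed to the same head.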
[19]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Patching Over Attention Heads"
)

fig.update_layout(
    xaxis_title="Head",
    yaxis_title="Layer"
)

fig.show()
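
Besides reading the heatmap, it can help to list the heads with the largest attribution in either direction. The helper below is a hypothetical sketch, not part of NNsight or the original tutorial, and it runs on a made-up 3 x 4 matrix rather than real results. It should accept patching_results as built above, since that is a list of equal-length arrays.

```python
import numpy as np

# Toy stand-in for patching_results: 3 layers x 4 heads, made-up values.
toy_results = np.array(
    [
        [0.01, -0.02, 0.00, 0.03],
        [0.40, 0.05, -0.30, 0.02],
        [-0.10, 0.60, 0.01, -0.04],
    ]
)


def top_heads(results, k=3):
    """Return (layer, head, score) for the k entries with largest |score|."""
    results = np.asarray(results)
    flat_order = np.argsort(-np.abs(results), axis=None)[:k]
    layers, heads = np.unravel_index(flat_order, results.shape)
    return [(int(l), int(h), float(results[l, h])) for l, h in zip(layers, heads)]


for layer, head, score in top_heads(toy_results):
    print(f"layer {layer}, head {head}: {score:+.2f}")
```

Ranking by absolute value keeps heads with a strongly negative attribution in view, which a ranking by raw score would push to the bottom.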