Activation Patching#
You can find an interactive Colab version of this tutorial here.
Activation patching is a technique used to understand how different parts of a model contribute to its behavior. In an activation patching experiment, we modify or "patch" the activations of certain model components and observe the impact on model output.
Activation patching experiments typically follow these steps:
Baseline Run: Run the model and record original activations.
Corrupted Run: Run the model with a counterfactual (i.e., corrupted) prompt and record the altered activations and output.
Patching: Replace activations at the model component of interest with alternate activations (or zeros, which is sometimes referred to as ablation).
By systematically testing different components this way, researchers can determine how information flows through the model. One common use case is circuit identification, where a circuit is a subgraph of a full model that is responsible for a specific and human-interpretable task (e.g., detecting whether an input is in English). Activation patching can help identify which model components are essential for model performance on a given task.
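To make these steps concrete, here is a minimal sketch of a single patch written with the nnsight tracing API used throughout this tutorial. The layer index and token position are arbitrary illustrative choices, and the prompts are the ones introduced below; the full, systematic experiment is built up step by step in the rest of the tutorial.

from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="auto")

clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

# 1) Baseline run: record one layer's output on the clean prompt.
with model.trace(clean_prompt):
    clean_act = model.transformer.h[5].output[0].save()

# 2) + 3) Corrupted run with a patch: overwrite that layer's output at the
# final token position with the clean activation and read off the logits.
with model.trace(corrupted_prompt):
    model.transformer.h[5].output[0][:, -1, :] = clean_act[:, -1, :]
    patched_logits = model.lm_head.output.save()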
In this tutorial, we use nnsight
to perform a simple activation patching experiment on an indirect object identification (IOI) task.
Note: IOI Task#
Activation patching was used to find the Indirect Object Identification (IOI) circuit in GPT-2 small. IOI is a natural language task in which a model predicts the indirect object in a sentence. IOI tasks typically involve identifying the indirect object from two names introduced in an initial dependent clause. One name (e.g., "Mary") is the subject (S1), and the other name (e.g., "John") is the indirect object (IO). In the main clause, a second occurrence of the subject (S2) typically performs an action involving the exchange of an item. The sentence always ends with the preposition "to," and the task is to correctly complete it by identifying the non-repeated name (IO).
In this exercise, we will use the following "clean" prompt:
"After John and Mary went to the store, Mary gave a bottle of milk to"
This prompt's correct answer (and thus its indirect object) is: " John"
We will also use a corrupted prompt to see how activation patching works. The corrupted prompt swaps which name is repeated, changing the indirect object, so we can observe how the model responds to this change.
"After John and Mary went to the store, John gave a bottle of milk to"
This prompt's correct answer (and thus its indirect object) is: " Mary"
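To summarize the roles in these prompts: in the clean prompt, "Mary" is the repeated subject (S1/S2) and "John" is the non-repeated name, so the correct completion is " John"; in the corrupted prompt, "John" is repeated instead, so the correct completion is " Mary".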
Set Up#
If using Colab, install NNsight:
!pip install -U nnsight
[1]:
from IPython.display import clear_output
[2]:
import nnsight
from nnsight import CONFIG
Let's start with our imports:
[3]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "plotly_mimetype+notebook_connected+colab+notebook"
from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy
[4]:
# Load gpt2
model = LanguageModel("openai-community/gpt2", device_map="auto")
clear_output()
[5]:
print(model)
GPT2LMHeadModel(
(transformer): GPT2Model(
(wte): Embedding(50257, 768)
(wpe): Embedding(1024, 768)
(drop): Dropout(p=0.1, inplace=False)
(h): ModuleList(
(0-11): 12 x GPT2Block(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): GPT2SdpaAttention(
(c_attn): Conv1D()
(c_proj): Conv1D()
(attn_dropout): Dropout(p=0.1, inplace=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): Conv1D()
(c_proj): Conv1D()
(act): NewGELUActivation()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=768, out_features=50257, bias=False)
(generator): WrapperModule()
)
Next up, we define our clean prompt and our corrupted prompt. Because a prompt may engage many different feature circuits (e.g., circuits responsible for IOI, detecting whether the language is English, or prompt refusal), it is essential to choose a counterfactual prompt whose only changes are directly related to your feature of interest.
Here, we switch the name of the repeated subject, thus swapping out the indirect object for our IOI task:
[6]:
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = (
"After John and Mary went to the store, John gave a bottle of milk to"
)
We then use the tokenizer on the two words of interest ("John" and "Mary") to find the tokens that represent them. That way we can grab the predictions for these two tokens and compare them. Because our prompts don't end in a space, make sure to add a space before each word (i.e., the combined space + word token is what we're looking for).
[7]:
correct_index = model.tokenizer(" John")["input_ids"][0] # includes a space
incorrect_index = model.tokenizer(" Mary")["input_ids"][0] # includes a space
print(f"' John': {correct_index}")
print(f"' Mary': {incorrect_index}")
' John': 1757
' Mary': 5335
Patching Experiment#
Now we can run the actual patching intervention! What does this even mean?
We now have two prompts, a "clean" one and a "corrupted" one. Intuitively, the model output for each of these prompts should be different: we'd expect the model to answer "John" for the clean prompt and "Mary" for the corrupted prompt.
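As a quick optional check, we can confirm that expectation by decoding each prompt's top next-token prediction, using the same tracing pattern we rely on below:

with model.trace(clean_prompt):
    clean_last_logits = model.lm_head.output[0, -1].save()

with model.trace(corrupted_prompt):
    corrupted_last_logits = model.lm_head.output[0, -1].save()

# We expect " John" for the clean prompt and " Mary" for the corrupted prompt.
print(model.tokenizer.decode(clean_last_logits.value.argmax().item()))
print(model.tokenizer.decode(corrupted_last_logits.value.argmax().item()))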
In this experiment, we run the model with the clean prompt as input and then (1) record each layer's output (i.e., the residual stream after that layer) and (2) calculate the logit difference between the correct and incorrect answers for this run. Next, we calculate the logit difference between the correct and incorrect answers for the corrupted prompt.
Step 1: Clean Run#
First, we run the model with the clean prompt:
"After John and Mary went to the store, Mary gave a bottle of milk to"
During this clean run, we collect the final output of each layer. We also record the logit difference in the final model output between the correct answer token " John"
and the incorrect token " Mary"
.
[8]:
N_LAYERS = len(model.transformer.h)

# Clean run
with model.trace(clean_prompt) as tracer:
    clean_tokens = tracer.invoker.inputs[0]['input_ids'][0]

    # Get hidden states of all layers in the network.
    # We index the output at 0 because it's a tuple where the first index is the hidden state.
    clean_hs = [
        model.transformer.h[layer_idx].output[0].save()
        for layer_idx in range(N_LAYERS)
    ]

    # Get logits from the lm_head.
    clean_logits = model.lm_head.output

    # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
    clean_logit_diff = (
        clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
    ).save()
Step 2: Corrupted Run#
Next, we run the model using the corrupted input prompt:
"After John and Mary went to the store, John gave a bottle of milk to"
During this corrupted run, we collect the logit difference in the final model output between the correct and incorrect answer tokens.
Note: because we are testing changes induced by the corrupted prompt, the target answers remain the same as in the clean run. That is, the correct token is still " John"
and the incorrect token is still " Mary"
.
[9]:
# Corrupted run
with model.trace(corrupted_prompt) as tracer:
    corrupted_logits = model.lm_head.output

    # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
    corrupted_logit_diff = (
        corrupted_logits[0, -1, correct_index]
        - corrupted_logits[0, -1, incorrect_index]
    ).save()
Step 3: Activation Patching Intervention#
Finally, we perform our activation patching procedure. For each layer of the model, we loop through every token position in the prompt. At each layer and token position, we run a forward pass on the corrupted prompt and patch in the corresponding activation recorded from the clean run. We then record the logit difference between the correct and incorrect answer tokens for that patched run.
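To make these results comparable across layers and token positions, each patched logit difference is normalized against the clean and corrupted baselines, so that 0 means the patch had no effect (behavior matches the corrupted run) and 1 means the patch fully restored clean behavior:

patched_result = (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)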
[10]:
# Activation Patching Intervention
ioi_patching_results = []

# Iterate through all the layers
for layer_idx in range(len(model.transformer.h)):
    _ioi_patching_results = []

    # Iterate through all tokens
    for token_idx in range(len(clean_tokens)):

        # Patching corrupted run at given layer and token
        with model.trace(corrupted_prompt) as tracer:

            # Apply the patch from the clean hidden states to the corrupted hidden states.
            model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

            patched_logits = model.lm_head.output

            patched_logit_diff = (
                patched_logits[0, -1, correct_index]
                - patched_logits[0, -1, incorrect_index]
            )

            # Calculate the improvement in the correct token after patching.
            patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                clean_logit_diff - corrupted_logit_diff
            )

            _ioi_patching_results.append(patched_result.save())

    ioi_patching_results.append(_ioi_patching_results)
Note: Optimize workflow with NNsight batching#
Although we broke up the workflow for ease of understanding, we can use nnsight to further speed up computation: the whole experiment can happen in a single forward pass by splitting the inputs across multiple invocation calls, which nnsight batches together.
[13]:
N_LAYERS = len(model.transformer.h)

# Enter nnsight tracing context
with model.trace() as tracer:

    # Clean run
    with tracer.invoke(clean_prompt) as invoker:
        clean_tokens = invoker.inputs[0]['input_ids'][0]

        # No need to call .save() as we don't need the values after the run, just within the experiment run.
        clean_hs = [
            model.transformer.h[layer_idx].output[0]
            for layer_idx in range(N_LAYERS)
        ]

        # Get logits from the lm_head.
        clean_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
        clean_logit_diff = (
            clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
        ).save()

    # Corrupted run
    with tracer.invoke(corrupted_prompt) as invoker:
        corrupted_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
        corrupted_logit_diff = (
            corrupted_logits[0, -1, correct_index]
            - corrupted_logits[0, -1, incorrect_index]
        ).save()

    ioi_patching_results = []

    # Iterate through all the layers
    for layer_idx in range(len(model.transformer.h)):
        _ioi_patching_results = []

        # Iterate through all tokens
        for token_idx in range(len(clean_tokens)):

            # Patching corrupted run at given layer and token
            with tracer.invoke(corrupted_prompt) as invoker:

                # Apply the patch from the clean hidden states to the corrupted hidden states.
                model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                patched_logits = model.lm_head.output

                patched_logit_diff = (
                    patched_logits[0, -1, correct_index]
                    - patched_logits[0, -1, incorrect_index]
                )

                # Calculate the improvement in the correct token after patching.
                patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                    clean_logit_diff - corrupted_logit_diff
                )

                _ioi_patching_results.append(patched_result.save())

        ioi_patching_results.append(_ioi_patching_results)
Visualize Results#
Let's define a function to plot our activation patching results.
[11]:
def plot_ioi_patching_results(model,
                              ioi_patching_results,
                              x_labels,
                              plot_title="Normalized Logit Difference After Patching Residual Stream on the IOI Task"):

    # Convert saved proxies to plain floats before plotting.
    ioi_patching_results = util.apply(ioi_patching_results, lambda x: x.value.item(), Proxy)

    fig = px.imshow(
        ioi_patching_results,
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": "Position", "y": "Layer", "color": "Norm. Logit Diff"},
        x=x_labels,
        title=plot_title,
    )

    return fig
Let's see how the patching intervention changes the logit difference! We'll use a heatmap to examine how the logit difference changes after patching each layer's output across token positions.
[12]:
print(f"Clean logit difference: {clean_logit_diff.value:.3f}")
print(f"Corrupted logit difference: {corrupted_logit_diff.value:.3f}")
clean_decoded_tokens = [model.tokenizer.decode(token) for token in clean_tokens]
token_labels = [f"{token}_{index}" for index, token in enumerate(clean_decoded_tokens)]
fig = plot_ioi_patching_results(model, ioi_patching_results, token_labels, "Patching GPT-2-small Residual Stream on IOI task")
fig.show()
Clean logit difference: 4.124
Corrupted logit difference: -2.272