Activation Patching

Here we use nnsight to run an activation patching experiment on the Indirect Object Identification (IOI) task.

First, let's do our imports.

[1]:
import plotly.express as px

from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy
[2]:
# Load gpt2
model = LanguageModel("openai-community/gpt2", device_map="cuda:0")

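Next, we define our clean prompt and our corrupted prompt. They differ only in who gives the milk ("Mary" vs. "John"), which flips the expected indirect object.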

[3]:
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = (
    "After John and Mary went to the store, John gave a bottle of milk to"
)

We run the tokenizer on the two words of interest, “John” and “Mary”, to find the token IDs that represent them. That way we can grab the model's predictions for these two tokens and compare them. Make sure to add a space before each word: our prompts don’t end with a space, so the token we're looking for is the combined space + word token.

[4]:
correct_index = model.tokenizer(" John")["input_ids"][0]
incorrect_index = model.tokenizer(" Mary")["input_ids"][0]

print(f"' John': {correct_index}")
print(f"' Mary': {incorrect_index}")
' John': 1757
' Mary': 5335

Now we do the actual patching intervention!
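
As a mental model for what follows, here is a minimal sketch of patching on a toy "model" in plain Python. Everything in it is hypothetical and unrelated to nnsight or GPT-2: the function `run_toy_model`, the two layers that map each activation x to 2x + 1, and the numeric inputs are all assumptions made for illustration.

```python
# Toy sketch of activation patching. Everything here is hypothetical:
# a "layer" simply maps each per-token activation x to 2x + 1.

def run_toy_model(tokens, patch=None):
    """Run a two-layer toy model over a list of numbers.

    patch: optional (layer_idx, token_idx, value). After the given layer,
    the activation at token_idx is overwritten with value.
    Returns (output, cache), where cache[layer_idx] holds that layer's activations.
    """
    hidden = [float(t) for t in tokens]
    cache = []
    for layer_idx in range(2):
        hidden = [2 * h + 1 for h in hidden]
        if patch is not None and patch[0] == layer_idx:
            _, token_idx, value = patch
            hidden[token_idx] = value
        cache.append(list(hidden))
    return sum(hidden), cache


clean_tokens = [1, 2, 3]
corrupted_tokens = [1, 5, 3]  # differs from the clean input at position 1 only

clean_out, clean_cache = run_toy_model(clean_tokens)
corrupted_out, _ = run_toy_model(corrupted_tokens)

# Patch the clean activation into the corrupted run at (layer 0, position 1).
patched_out, _ = run_toy_model(corrupted_tokens, patch=(0, 1, clean_cache[0][1]))

# Patching a position where the two inputs agree changes nothing.
unchanged_out, _ = run_toy_model(corrupted_tokens, patch=(0, 0, clean_cache[0][0]))

print(clean_out, corrupted_out, patched_out, unchanged_out)
```

In this toy, positions never interact, so a single patch at the differing position restores the clean output. In a real transformer, attention moves information between positions, which is why the experiment below sweeps every (layer, position) pair.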

Thanks to nnsight, the whole experiment can happen in one forward pass by breaking up inputs into multiple invocation calls and batching them.

[7]:
N_LAYERS = model.config.n_layer

# Enter nnsight tracing context
with model.trace() as tracer:

    # Clean run
    with tracer.invoke(clean_prompt) as invoker:
        clean_tokens = model.input[1]["input_ids"].squeeze()

        # Get hidden states of all layers in the network.
        # We index the output at 0 because it's a tuple where the first index is the hidden state.
        # No need to call .save() as we don't need the values after the run, just within the experiment run.
        clean_hs = [
            model.transformer.h[layer_idx].output[0]
            for layer_idx in range(N_LAYERS)
        ]

        # Get logits from the lm_head.
        clean_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
        clean_logit_diff = (
            clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
        ).save()

    # Corrupted run
    with tracer.invoke(corrupted_prompt) as invoker:
        corrupted_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
        corrupted_logit_diff = (
            corrupted_logits[0, -1, correct_index]
            - corrupted_logits[0, -1, incorrect_index]
        ).save()

    ioi_patching_results = []

    # Iterate through all the layers
    for layer_idx in range(N_LAYERS):
        _ioi_patching_results = []

        # Iterate through all tokens
        for token_idx in range(len(clean_tokens)):

            # Patching corrupted run at given layer and token
            with tracer.invoke(corrupted_prompt) as invoker:

                # Apply the patch from the clean hidden states to the corrupted hidden states.
                model.transformer.h[layer_idx].output[0].t[token_idx] = clean_hs[
                    layer_idx
                ].t[token_idx]

                patched_logits = model.lm_head.output

                patched_logit_diff = (
                    patched_logits[0, -1, correct_index]
                    - patched_logits[0, -1, incorrect_index]
                )

                # Normalize the patched logit difference: 0 means no change from the
                # corrupted run, 1 means the clean logit difference is fully restored.
                patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                    clean_logit_diff - corrupted_logit_diff
                )

                _ioi_patching_results.append(patched_result.save())

        ioi_patching_results.append(_ioi_patching_results)


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
[8]:
print(f"Clean logit difference: {clean_logit_diff.value:.3f}")
print(f"Corrupted logit difference: {corrupted_logit_diff.value:.3f}")

ioi_patching_results = util.apply(ioi_patching_results, lambda x: x.value.item(), Proxy)

clean_tokens = [model.tokenizer.decode(token) for token in clean_tokens]
token_labels = [f"{token}_{index}" for index, token in enumerate(clean_tokens)]

fig = px.imshow(
    ioi_patching_results,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    labels={"x": "Position", "y": "Layer"},
    x=token_labels,
    title="Normalized Logit Difference After Patching Residual Stream on the IOI Task",
)

fig.show()
Clean logit difference: 4.124
Corrupted logit difference: -2.272
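
To help read the heatmap, the normalization applied to each cell can be sketched in plain Python using the two logit differences printed above. The helper name `normalized_logit_diff` and the "halfway" value are assumptions introduced for illustration; only the formula and the two printed numbers come from the notebook.

```python
def normalized_logit_diff(patched, clean, corrupted):
    """Rescale a patched logit difference so corrupted maps to 0 and clean maps to 1."""
    return (patched - corrupted) / (clean - corrupted)


# Values printed by the cell above.
clean_diff = 4.124
corrupted_diff = -2.272

# A patch that fully restores the clean behavior scores 1.
restored = normalized_logit_diff(clean_diff, clean_diff, corrupted_diff)

# A patch that leaves the corrupted behavior unchanged scores 0.
no_effect = normalized_logit_diff(corrupted_diff, clean_diff, corrupted_diff)

# Hypothetical patch that lands midway between the two runs (score near 0.5).
halfway = normalized_logit_diff(
    (clean_diff + corrupted_diff) / 2, clean_diff, corrupted_diff
)

print(restored, no_effect, round(halfway, 3))
```

Values above 1 or below 0 are also possible: they mean the patch pushed the logit difference past the clean run, or made it worse than the corrupted run.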