Causal Mediation Analysis II: Explaining LLMs¶
In this tutorial, we'll walk through how to apply causal mediation analysis to uncover the internal processes of LLMs.
Before we begin, we encourage you to read through the causal mediation analysis introductory tutorial, which goes over key concepts in causal mediation analysis and uses them to explain a toy arithmetic neural network.
In this tutorial, we'll broaden our horizons to explain an LLM. Along the way, we'll cover:
- Information flow: how to track the passage of information from the input to the output.
- Attention head plots: understanding which components might be responsible for model behavior.
- Gender bias in LMs: what causal mediation analysis can teach us about bias in LMs.
📗 Prefer to use Colab? Follow the tutorial here!
0️⃣ Setup¶
from IPython.display import clear_output
# install nnsight to analyze neural network internals
!pip install -U nnsight
clear_output()
Note
In this tutorial, we use the Llama-3.2 1B model. Before starting the tutorial, please go to the model's Hugging Face page and request permission to use the model. Then, log in to this notebook with your Hugging Face access token.
from huggingface_hub import notebook_login
notebook_login()
# load model
import nnsight
model = nnsight.LanguageModel(
    'meta-llama/Llama-3.2-1B',
    device_map='auto'
)
clear_output()
1️⃣ Causal mediation & information flow¶
In this tutorial, we'll apply causal mediation analysis to understand how an LLM recalls information about cities. For example, when we start a sentence with "Paris is in the country of ____", an LLM knows that the next token is "France". Where does the model store this information?
# does our model know where Paris is?
import torch
base_prompt = "Paris is in the country of"
with torch.no_grad():
    with model.trace(base_prompt):
        # get logits over final token
        base_logits = model.output.logits[:, -1, :].save()

print('Prompt:', base_prompt)
print(f'Model completion:{model.tokenizer.decode(base_logits.argmax(dim=-1))}')
Prompt: Paris is in the country of
Model completion: France
Somewhere, somehow, our model stores information about Paris that allows it to recall the country France. Can we intervene on the neurons that mediate the connection between Paris and France to make our model think that Paris is in a completely different country, like Brazil?
To do this, we'll follow the same causal mediation analysis steps that we covered in the previous tutorial.
First, create a base and source pair where only the city changes. Then, at each layer and token position,
1. Run a forward pass on the source input and store the value of the model's activation.
2. Run a forward pass on the base input and intervene on the activation by patching in its source value from (1).
3. Collect the result of the intervened run and check: did patching change the value of the base output?

We'll use Paris and Rio as our base and source cities.
base_prompt = "Paris is in the country of"
source_prompt = "Rio is in the country of"
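One practical caveat worth checking before patching (a hedged sketch, not part of the tutorial's code): position-wise patching assumes the base and source prompts tokenize to the same length, so that token positions line up one-to-one. A minimal helper that validates this on plain token-id lists might look like the following; the token ids here are made-up stand-ins, not real Llama tokenizer output.

```python
def check_alignment(base_ids, source_ids):
    """Return positions where two equal-length token-id lists differ.

    Raises if the lengths differ, since position-wise patching
    would then compare misaligned tokens.
    """
    if len(base_ids) != len(source_ids):
        raise ValueError(
            f"prompts tokenize to different lengths: "
            f"{len(base_ids)} vs {len(source_ids)}"
        )
    return [i for i, (b, s) in enumerate(zip(base_ids, source_ids)) if b != s]

# hypothetical ids standing in for the real tokenizer output
base_ids = [128000, 60704, 374, 304, 279, 3224, 315]    # "Paris is in the country of"
source_ids = [128000, 61058, 374, 304, 279, 3224, 315]  # "Rio is in the country of"
diff_positions = check_alignment(base_ids, source_ids)  # → [1]: only the city token differs
```

In our Paris/Rio pair, only the city token differs, which is exactly what we want: any causal effect we find can be attributed to that single change.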
Let's go through causal mediation analysis step by step!
# 1. collect source activations
with torch.no_grad():
    with model.trace(source_prompt):
        # get hidden states of all layers in the network
        # we index the output at 0 because it's a tuple where the first element is the hidden state
        source_hidden_states = [
            layer.output[0].save()
            for layer in model.model.layers
        ]
This step might take 1-2 minutes, because we're looping through layers & token positions.
# 2 and 3. intervene and collect intervention results
# we'll keep track of the probability of outputting Brazil vs. the probability of outputting France
source_country = model.tokenizer(" Brazil")["input_ids"][1]  # " Brazil" includes a leading space; [1] skips the BOS token
base_country = model.tokenizer(" France")["input_ids"][1]
num_tokens = len(model.tokenizer(base_prompt).input_ids)  # number of tokens in the prompt

causal_effects = []
# iterate through all the layers
for layer_idx in range(model.config.num_hidden_layers):
    causal_effect_per_layer = []
    # iterate through all token positions
    for token_idx in range(num_tokens):
        with torch.no_grad():
            with model.trace(base_prompt):
                # 2. change the value of the base activation to the source value
                model.model.layers[layer_idx].output[0][:, token_idx, :] = \
                    source_hidden_states[layer_idx][:, token_idx, :]
                # 3. get intervened output & compare to base output
                intervened_logits = model.output.logits[:, -1, :]
                intervened_probs = intervened_logits.softmax(dim=-1)
                # track how much we moved towards Brazil & away from France
                # -> the higher, the more causal effect!
                intervened_prob_diff = (
                    intervened_probs[0, source_country] - intervened_probs[0, base_country]
                ).item().save()
        causal_effect_per_layer.append(intervened_prob_diff)
    causal_effects.append(causal_effect_per_layer)
Let's check out our results! The darker a cell is, the greater the difference between $p(\text{Brazil})$ and $p(\text{France})$, and hence the more causal effect the residual stream at that layer / token position has in mediating the path between the city and its country.
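As a toy illustration of this metric (separate from the tutorial's model code, with made-up logits), here's the probability-difference computation on a three-token vocabulary, using a plain-Python softmax:

```python
import math

def softmax(logits):
    """Convert a list of logits to probabilities."""
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def prob_diff(logits, source_idx=1, base_idx=0):
    """p(source token) - p(base token): positive means the patch
    pushed the model towards the source country."""
    probs = softmax(logits)
    return probs[source_idx] - probs[base_idx]

# toy vocabulary: index 0 = "France", index 1 = "Brazil", index 2 = other
base_logits = [4.0, 1.0, 0.0]     # unpatched run strongly prefers "France"
patched_logits = [1.5, 3.5, 0.0]  # after an effective patch, "Brazil" dominates

effect_before = prob_diff(base_logits)    # negative: "France" still wins
effect_after = prob_diff(patched_logits)  # positive: patch moved mass to "Brazil"
```

A cell in the heatmap is just this scalar computed from the patched run's final-token logits.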
# visualize our results!
import plotly.express as px

fig = px.imshow(
    causal_effects,
    x=['<|begin_of_text|>', 'Rio -> Paris', 'is', 'in', 'the', 'country', 'of'],
    y=list(range(model.config.num_hidden_layers)),
    template='simple_white',
    color_continuous_scale=[[0, '#FFFFFF'], [1, '#A59FD9']]
)
fig.update_layout(
    xaxis_title='token',
    yaxis_title='layer',
    yaxis=dict(autorange='min')
)
fig
Let's break this plot down. These types of residual stream causal mediation plots come up a lot, so they're useful to understand! (For example, see the applied tutorial on the geometry of truth!)
We'll start from the lower layers and work our way up.
- Total effect: at the very first layer, the city token ("Rio" -> "Paris") has a causal effect. This makes sense, since it corresponds to the total effect of the input - changing just that token in the input changes the model's prediction of the next token.
- Intermediate tokens: in this example, the other tokens in the sentence (up until the last one) don't have significant causal effect. Depending on your intuition, this also might make some sense - the model isn't storing any information related to Paris in those tokens.
- Final layers of city token: this is where things start to get interesting. At the very last layers (layer 9 and above), changing the residual stream from "Paris" to "Rio" doesn't actually change the model's output! In a sense, the information is "no longer there", or rather, the model no longer has need for it, so it doesn't get used. Where did the information about Paris' country go then?
- Final layers of final token: at some point, the model has to output the next token! And with modern autoregressive models, only the information in the final token can be used in the final layer. Hence, at some point during computation, the model must move information from previous tokens into the final token position. That's what's likely going on at layer 9! (keep this layer number in mind - we'll return to it soon).
Information flow plots are a great way to visualize where interesting computation is happening in the model. They don't quite tell us how the model is solving the task, but they do help us understand which components to inspect!
Next, we'll try to look into the "how" a little more by patching specific components in the model. We'll do a search over the entire model, but you might already be able to guess which layers will contain the mechanism we're looking for.
2️⃣ Intervening on specific model components¶
In the last section, we saw how residual stream patching can help us identify where information flows within a language model. In this section, we'll go a step deeper and try to find the components responsible for moving this information. It will lead to some very pretty-looking attention plots!
We know that our language model has to move information from the city token to the final token before it can predict the next token in the sequence. And there's only one type of component that can perform this operation: attention heads. This ability to move information between tokens makes attention heads quite powerful, and a useful component to analyze when interpreting a model's behavior.
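To make the "moving information" intuition concrete, here's a toy single-head attention step in plain Python (illustrative numbers only): the output at the final position is a softmax-weighted sum of value vectors from all positions, so a head whose final-position query attends to the city token effectively copies that token's information into the final position.

```python
import math

def attention_output(query_scores, values):
    """Toy attention for a single query position: softmax the
    attention scores, then take the weighted sum of the per-position
    value vectors."""
    exps = [math.exp(s) for s in query_scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]

# 3 token positions; position 0 holds the "city" information in its value vector
values = [[1.0, 0.0], [0.0, 0.2], [0.0, 0.1]]
# the final position's query scores attend almost entirely to position 0
scores = [6.0, 0.0, 0.0]

out = attention_output(scores, values)
# out[0] is close to 1.0: the city vector was copied into the final position
```

Patching a head's output, as we do below, amounts to swapping in the weighted sum this head computed on the source prompt.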
Thankfully, causal mediation analysis is agnostic to our selection of model component! We can follow the same steps to find attention heads that causally mediate the recall of a city's country. For each attention head, we will
1. Run a forward pass on the source input and store the value of the head's activation.
2. Run a forward pass on the base input and intervene on the head's activation by patching in its source value from (1).
3. Collect the result of the intervened run and check: did patching the attention head change the value of the base output?
# 1. collect source activations
with torch.no_grad():
    with model.trace(source_prompt):
        # index into attention head outputs
        # (note: we chose the o-projection input,
        # but you can try intervening on q/k/v-proj too!)
        source_hidden_states = [
            layer.self_attn.o_proj.input.save()
            for layer in model.model.layers
        ]
We need to change some small things to make sure that we're indexing into the right attention head, but otherwise it's almost the same code as the residual stream interventions in the previous section!
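The main change is the indexing: the o-projection input is a concatenation of per-head outputs, so head `h` occupies columns `[h * attn_dim, (h + 1) * attn_dim)` of the hidden dimension. A small sketch of that arithmetic, using sizes that (to the best of our knowledge) match Llama-3.2 1B's config:

```python
# sizes assumed to match Llama-3.2 1B's config (an assumption, not read from the model)
hidden_size = 2048
num_attention_heads = 32
attn_dim = hidden_size // num_attention_heads  # per-head width: 64

def head_slice(head_index, attn_dim):
    """Column range occupied by one head in the concatenated head outputs."""
    return slice(head_index * attn_dim, (head_index + 1) * attn_dim)

s = head_slice(5, attn_dim)  # head 5 occupies columns [320, 384)

# the per-head slices tile the hidden dimension exactly, with no overlap or gap
covered = sum(
    head_slice(h, attn_dim).stop - head_slice(h, attn_dim).start
    for h in range(num_attention_heads)
)  # equals hidden_size
```

Patching one slice therefore swaps exactly one head's contribution while leaving the other heads' outputs untouched.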
# 2 and 3. intervene and collect intervention results
# we'll keep track of the probability of outputting Brazil vs. the probability of outputting France
source_country = model.tokenizer(" Brazil")["input_ids"][1]  # " Brazil" includes a leading space; [1] skips the BOS token
base_country = model.tokenizer(" France")["input_ids"][1]
attn_dim = model.config.hidden_size // model.config.num_attention_heads

causal_effects = []
# iterate through all the layers
for layer_idx in range(model.config.num_hidden_layers):
    causal_effect_per_layer = []
    # iterate through all attention heads
    for head_index in range(model.config.num_attention_heads):
        with torch.no_grad():
            with model.trace(base_prompt):
                # 2. change the value of the base activation to the source value
                attention_value = model.model.layers[layer_idx].self_attn.o_proj.input
                # change value only at this attention head's slice (across all tokens)
                attention_value[:, :, head_index * attn_dim:(head_index + 1) * attn_dim] = \
                    source_hidden_states[layer_idx][:, :, head_index * attn_dim:(head_index + 1) * attn_dim]
                # put the modified attention output back in
                model.model.layers[layer_idx].self_attn.o_proj.input = attention_value
                # 3. get intervened output & compare to base output
                intervened_logits = model.output.logits[:, -1, :]
                intervened_probs = intervened_logits.softmax(dim=-1)
                # track how much we moved towards Brazil & away from France
                # -> the higher, the more causal effect!
                intervened_prob_diff = (
                    intervened_probs[0, source_country] - intervened_probs[0, base_country]
                ).item().save()
        causal_effect_per_layer.append(intervened_prob_diff)
    causal_effects.append(causal_effect_per_layer)
# visualize our results!
import plotly.express as px

fig = px.imshow(
    causal_effects,
    x=list(range(model.config.num_attention_heads)),
    y=list(range(model.config.num_hidden_layers)),
    template='simple_white',
    color_continuous_scale=[[0, '#FFFFFF'], [1, '#A59FD9']]
)
fig.update_layout(
    xaxis_title='attention head',
    yaxis_title='layer',
    yaxis=dict(autorange='min')
)
fig