Distributed Alignment Search (DAS): Searching for Linearly Encoded Concepts in Model Representations#
Imagine we want to edit a model to think that Paris is in the country of Brazil, without changing whatever else the model knows about Paris (e.g., its language, continent, …). Which representations in the model encode this fact about Paris?
In this tutorial, we’ll go over Distributed Alignment Search, or DAS, which helps us automatically identify a set of linear subspaces in a model’s representations that encode a particular concept.
Note
We encourage using “light mode” to view this tutorial, since the color blocks are harder to read in “dark mode”. You can also follow along on this Colab notebook.
Before we begin!
These are good things to know before we begin the tutorial
Activation patching - check out the activation patching tutorial here!
Things we’ll talk about
In case you want to tell people what you learned today!
DAS - method for finding linear subspaces of model representations that store a particular concept.
RAVEL - evaluation framework for localizing concepts in model activations.
Let’s do this!
[1]:
import plotly.io as pio
from IPython.display import clear_output
try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    pio.renderers.default = "colab"
    !pip install nnsight==0.5.0.dev
else:
    pio.renderers.default = "plotly_mimetype+notebook_connected+notebook"

clear_output()
Note
In this tutorial, we use the Llama-3.2 1B model. Before starting the tutorial, please go to the model’s huggingface page and request permission to use the model. Then, log in to this notebook with your huggingface access token.
[ ]:
from huggingface_hub import notebook_login
notebook_login()
[2]:
# (try to) set seeds for reproducibility
import random
import torch
random.seed(12)
torch.manual_seed(12)
torch.cuda.manual_seed(12)
Making surgical edits - residual streams capture too much information#
When we ask a language model questions about the city of Paris, it seems to know the city’s country, its continent, and its language. Yet where are these properties of Paris stored?
One way to start investigating this is activation patching. When we patch the layer-8 residual stream activation at the Paris token (using the corresponding activation from a prompt about Rio), we change Paris's country from France to Brazil.
[3]:
# load model
import nnsight
from IPython.display import clear_output
model = nnsight.LanguageModel("meta-llama/Llama-3.2-1B", device_map="auto")
clear_output()
[4]:
# base run - does our model know where Paris is?
import torch
base_prompt = "Paris is in the country of"
# get logits from the model's output
with torch.no_grad():
    with model.trace(base_prompt) as tracer:
        base_logits = model.output.logits[:, -1, :].save()

# apply softmax to convert logits to probability distribution over tokens
base_probs = torch.softmax(base_logits, dim=-1)
top_completions = torch.topk(base_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
Some parameters are on the meta device because they were offloaded to the cpu.
France (0.65)
the (0.05)
love (0.01)
[5]:
# source run - collect representations for a city from a different country
source_prompt = "Rio is in the country of"
source_country = model.tokenizer(" Brazil")["input_ids"][1]  # id of the " Brazil" token (index 0 is the BOS token)
source_hidden_states = []
with torch.no_grad():
    with model.trace(source_prompt) as tracer:
        # get hidden states of all layers in the network.
        # we index the output at 0 because it's a tuple where the first index is the hidden state.
        for layer in model.model.layers:
            source_hidden_states.append(layer.output[0].save())
[6]:
# patched run - by patching at layer 8 over Paris, we change its country from France to Brazil!
TOKEN_INDEX = 1
LAYER_INDEX = 8
with model.trace(base_prompt) as tracer:
    # patch the base run's activation at layer 8 with the source (Rio) activation
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]
    patched_logits = model.output.logits[:, -1, :].save()

patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
Brazil (0.61)
the (0.05)
Portugal (0.01)
However, we also accidentally edit other facts about Paris, such as its continent and language!
[7]:
# by changing Paris's country, we also changed its continent!
TOKEN_INDEX = 1
LAYER_INDEX = 8
new_base_prompt = "Paris is in the continent of"
with model.trace(new_base_prompt) as tracer:
    # apply the same patch we did before
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = \
        source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]
    patched_logits = model.output.logits[:, -1, :].save()

patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
South (0.55)
America (0.11)
North (0.10)
[8]:
# as well as its language!
new_base_prompt = "Paris is a city whose main language is"
with model.trace(new_base_prompt) as tracer:
    # apply the same patch we did before
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = \
        source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]
    patched_logits = model.output.logits[:, -1, :].save()

patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
Portuguese (0.57)
Spanish (0.12)
English (0.10)
Takeaway
We need to find a way to make our patching more precise. One way to do this is to patch a unit of computation that’s smaller than the whole residual stream component. There are many reasonable options, such as patching sets of neurons. In this tutorial, we’ll look at how we can patch linear subspaces of a model’s representation.
Want to know more?
This example came from RAVEL, a benchmark that measures whether interpretability methods can localize specific concepts (e.g., country vs. language) in a model’s internal activations. Check out the paper and dataset for a full analysis of current interpretability methods & areas for future work!
Choosing the right unit of computation - how do models represent concepts?#
What are we patching to begin with? Let’s take a look at the source activations we collected.
[9]:
source_activations = source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]
source_activations
[9]:
tensor([[ 0.0111, -0.0206, -0.2613, ..., -0.0281, -0.1300, 0.0346]],
device='cuda:0')
[10]:
source_activations.shape
[10]:
torch.Size([1, 2048])
Each token position holds a single 2048-dimensional vector. Can we break this residual stream activation down into smaller, meaningful units of computation?
One idea is to look at single neurons - that is, single indices within the large 2048-dimensional vector.
Another idea, motivated by the Linear Representation Hypothesis, is that transformer-based neural networks tend to use linear subspaces as units of computation. Thinking of a model's activation as one giant vector, perhaps each concept is encoded in its own linear subspace of that vector.
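To make this concrete, here's a toy sketch of what it would mean to read and write a concept that's linearly encoded along a single direction. (Everything here is made up for illustration: `v` is a random stand-in for a hypothetical "country" direction.)
[ ]:
import torch

# a hypothetical unit-norm direction that (we pretend) encodes "country"
d = 8
v = torch.randn(d)
v = v / v.norm()

h_base = torch.randn(d)    # stand-in for a "Paris" activation
h_source = torch.randn(d)  # stand-in for a "Rio" activation

# read: the concept's value is the component of the activation along v
base_val, source_val = h_base @ v, h_source @ v

# write: swap in the source's component along v, leaving the part of the
# activation orthogonal to v untouched
h_patched = h_base + (source_val - base_val) * v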
To patch a set of neurons, we could simply index into the ones we think encode important concepts in the model. However, enumerating all subsets of neurons to find the right ones is computationally infeasible (there are 2^2048 of them).
[11]:
# change the list of indices to try a set of neurons to patch!
NEURON_INDICES = [0, 1, 2, 4]
base_prompt = "Paris is in the country of"
with model.trace(base_prompt) as tracer:
    # Apply the patch from the source hidden states to the base hidden states
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, NEURON_INDICES] = \
        source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, NEURON_INDICES]
    patched_logits = model.output.logits[:, -1, :]
    patched_probs = torch.softmax(patched_logits, dim=-1).save()

top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
France (0.64)
the (0.05)
love (0.01)
To patch a set of linear subspaces, we can follow a similar procedure, with a slight twist…
First, we rotate our base and source vectors. This creates two new vectors whose entries are linear combinations of the original neurons. Next, we patch dimensions in this rotated space, just as we would in the regular set-up. Lastly, we rotate the patched vector back, so that it's in the same basis as the original run.
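In symbols: if b and s are the base and source activations, R is the rotation matrix, and M is a 0/1 mask that keeps the first k rotated coordinates, the patched activation is h′ = Rᵀ(M ⊙ Rs + (1 − M) ⊙ Rb). The patching function below computes exactly this, with k = N_PATCHING_DIMS.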
[12]:
# construct a rotation matrix (model_dim x model_dim)
MODEL_HIDDEN_DIM = 2048
rotator = torch.nn.Linear(MODEL_HIDDEN_DIM, MODEL_HIDDEN_DIM, bias=False)
torch.nn.init.orthogonal_(rotator.weight)
rotator = torch.nn.utils.parametrizations.orthogonal(rotator).to(model.device)
clear_output()
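Why a rotation? Because an orthogonal matrix's inverse is simply its transpose, the transformation is trivially invertible - we can always rotate the patched vector back. As a quick, purely illustrative sanity check on the `rotator` we just built:
[ ]:
# the parametrized weight should be (numerically) orthogonal: R @ R.T ≈ I
with torch.no_grad():
    R = rotator.weight
    identity = torch.eye(MODEL_HIDDEN_DIM, device=R.device)
    print(torch.allclose(R @ R.T, identity, atol=1e-4))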
[13]:
# play around with how many linear dimensions we patch!
N_PATCHING_DIMS = 1
base_prompt = "Paris is in the country of"
def patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=False):
    grad_env = torch.enable_grad if with_grad else torch.no_grad
    with grad_env():
        with model.trace(base_prompt) as tracer:
            # rotate the base representation
            base = model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :].clone()
            rotated_base = rotator(base)
            # rotate the source representation
            source = source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]
            rotated_source = rotator(source)
            # patch the first n dimensions in the rotated space
            # (NOTE: same as `rotated_base[:, :N_PATCHING_DIMS] = rotated_source[:, :N_PATCHING_DIMS]`,
            # but written with `torch.cat` so the gradient flows)
            rotated_patch = torch.cat([
                rotated_source[:, :N_PATCHING_DIMS],
                rotated_base[:, N_PATCHING_DIMS:]
            ], dim=1)
            # rotate the patched vector back to the original space
            # (nn.Linear computes x @ weight.T, so the inverse rotation is x @ weight)
            patch = torch.matmul(rotated_patch, rotator.weight)
            # replace base with patch
            model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = patch
            patched_logits = model.output.logits[:, -1, :].save()
    return patched_logits

patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=False)
patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
France (0.24)
the (0.06)
Belgium (0.02)
Want to know more?
You may have suspected this, but there’s nothing particularly special about a linear rotation! Maybe the model uses the magnitude of a vector, instead of its direction, to do meaningful computation. We can think about different intermediate transformations that might expose interesting units of computation. Here are some key properties that we need these transformations to have:
invertible - we need to be able to “undo” the transformation to return to the original representation space from the transformed space
separable - we don’t want concepts to interfere with each other during the transformation
To learn about more of their properties and their theoretical grounding, check out the causal abstraction theory paper!
Hm, changing our unit of computation from neurons to linear subspaces didn't seem to help much on its own… Patching the first few dimensions of a random rotation didn't successfully edit the model's representation of Paris's country.
How do we automatically search for the linear subspaces we care about?
Takeaway
There are different potentially meaningful units of computations in a model’s representation. Thinking about the model representation as one giant multi-dimensional vector, we can try to patch linear subspaces of the model’s representation by first rotating it to a different space.
How do we know which linear subspaces to patch? This is where DAS comes in!
Enter DAS - automatically finding relevant linear subspaces#
By rotating the hidden representations of our model, we can patch different linear subspaces. But how can we find the right linear subspace to patch?
It turns out we can directly optimize our rotation matrix to do this! Let's train the rotation so that patching makes the model output "Brazil" as the country of Paris instead of "France".
[ ]:
# let's train our rotation matrix so that the patched output is Brazil instead of France
from tqdm import trange

# optimize only the rotation parameters (the LLM stays frozen)
optimizer = torch.optim.Adam(rotator.parameters())

# use a language modeling loss - increase the likelihood of outputting Brazil
loss_fn = torch.nn.CrossEntropyLoss()
counterfactual_answer = torch.tensor([model.tokenizer(" Brazil")["input_ids"][1]]).to(model.device)

with trange(10) as progress_bar:  # train for 10 epochs
    for epoch in progress_bar:
        optimizer.zero_grad()
        # get patched logits using our rotation matrix
        patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=True)
        # cross-entropy loss - make the last token Brazil instead of France
        loss = loss_fn(patched_logits, counterfactual_answer)
        progress_bar.set_postfix({'loss': loss.item()})
        loss.backward()
        optimizer.step()
10%|█ | 1/10 [02:41<22:09, 147.70s/it, loss=0.273]
Looks like training our rotation matrix did the job! Now, patching from Rio to Paris changes Paris’s country from France to Brazil.
[ ]:
base_prompt = "Paris is in the country of"
patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=False)
patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
Brazil (0.21)
France (0.08)
the (0.07)
But did it interfere with other facts about Paris, such as its continent or language? Doesn’t look like it!
[ ]:
new_base_prompt = "Paris is in the continent of"
patched_logits = patch_linear_subspaces(rotator, new_base_prompt, source_hidden_states, with_grad=False)
patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
Europe (0.43)
Africa (0.19)
North (0.13)
[ ]:
new_base_prompt = "Paris is a city whose main language is"
patched_logits = patch_linear_subspaces(rotator, new_base_prompt, source_hidden_states, with_grad=False)
patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
French (0.40)
English (0.15)
Italian (0.14)
Want to know more?
If there are concepts that we know we want to keep the same, we can train DAS with a multi-task objective (i.e., “edit this property” + “keep this other property the same”). See the RAVEL paper for more details!
Takeaway
How can we patch certain concepts in a model’s representation, such as the country of Paris, without messing with other concepts stored in the model, such as Paris’s continent or language?
DAS to the rescue! By searching over linear subspaces, DAS finds a subspace in the model's representation that, when patched, edits the target concept. The resulting patch is more precise: by patching individual linear subspaces, we have a better chance of editing only the specific concept we're looking for.
Multi-Task DAS#
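As mentioned in the note above, if there are properties we know we want to preserve, we can train DAS with an extra loss term for each of them. Below, we train the rotation with two objectives at once: change the patched country prediction to "Brazil", while keeping the patched continent prediction at "Europe".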
[ ]:
# let's train our rotation matrix with a multi-task objective:
# change Paris's country to Brazil while keeping its continent Europe
from tqdm import trange

# optimize only the rotation parameters (the LLM stays frozen)
optimizer = torch.optim.Adam(rotator.parameters())

# use a language modeling loss - increase the likelihood of outputting Brazil
loss_fn = torch.nn.CrossEntropyLoss()
counterfactual_answer = torch.tensor([model.tokenizer(" Brazil")["input_ids"][1]]).to(model.device)

# we can directly specify things we want to stay the same!
new_base_prompt = "Paris is in the continent of"
new_base_answer = torch.tensor([model.tokenizer(" Europe")["input_ids"][1]]).to(model.device)

with trange(10) as progress_bar:  # train for 10 epochs
    for epoch in progress_bar:
        optimizer.zero_grad()
        # get loss for counterfactual behavior (what we want to CHANGE)
        patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=True)
        counterfactual_loss = loss_fn(patched_logits, counterfactual_answer)
        # get loss for base behavior (what we want to STAY THE SAME)
        patched_logits = patch_linear_subspaces(rotator, new_base_prompt, source_hidden_states, with_grad=True)
        new_base_loss = loss_fn(patched_logits, new_base_answer)
        # can add more examples of base behavior to keep the same if we want!
        # ...
        # add up all the losses
        loss = counterfactual_loss + new_base_loss
        progress_bar.set_postfix({'loss': loss.item()})
        loss.backward()
        optimizer.step()