Distributed Alignment Search (DAS): Searching for Linearly Encoded Concepts in Model Representations#

Imagine we want to edit a model to think that Paris is in the country of Brazil, without changing anything else the model knows about Paris (e.g., its language, continent, …). Which representations in the model encode this fact about Paris?

In this tutorial, we’ll go over Distributed Alignment Search, or DAS, which helps us automatically identify a set of linear subspaces in a model’s representations that encode a particular concept.

Note

We encourage using “light mode” to view this tutorial, since the color blocks are harder to read in “dark mode”. You can also follow along on this Colab notebook.

Before we begin!

These are good things to know before we begin the tutorial:

  • Activation patching - check out the activation patching tutorial here!

Things we’ll talk about

In case you want to tell people what you learned today!

  • DAS - a method for finding linear subspaces of a model’s representations that store a particular concept.

  • RAVEL - an evaluation framework for localizing concepts in model activations.

Let’s do this!

[1]:
import plotly.io as pio
from IPython.display import clear_output

try:
  import google.colab
  is_colab = True
except ImportError:
  is_colab = False

if is_colab:
  pio.renderers.default = "colab"
  !pip install nnsight==0.5.0.dev
else:
  pio.renderers.default = "plotly_mimetype+notebook_connected+notebook"

clear_output()

Note

In this tutorial, we use the Llama-3.2 1B model. Before starting the tutorial, please go to the model’s Hugging Face page and request access to the model. Then, log in to this notebook with your Hugging Face access token.

[ ]:
from huggingface_hub import notebook_login

notebook_login()
[2]:
# (try to) set seeds for reproducibility
import random
import torch

random.seed(12)
torch.manual_seed(12)
torch.cuda.manual_seed(12)

Making surgical edits - residual streams capture too much information#

When we ask a language model questions about the city of Paris, it seems to know the city’s country, its continent, and its language. Yet where are these properties of Paris stored?

One way to start investigating this is activation patching. When we patch the layer-8 residual stream activation at the Paris token with the corresponding activation from a prompt about Rio, the model’s answer for Paris’s country changes from France to Brazil.

[Figure: two forward runs of the model, with an arrow patching the residual stream activation of Rio into Paris; after the intervention is applied, the model outputs Brazil.]

[3]:
# load model
import nnsight
from IPython.display import clear_output
model = nnsight.LanguageModel("meta-llama/Llama-3.2-1B", device_map="auto")
clear_output()
[4]:
# base run - does our model know where Paris is?
import torch

base_prompt = "Paris is in the country of"

# get logits from the model's output
with torch.no_grad():
  with model.trace(base_prompt) as tracer:
    base_logits = model.output.logits[:, -1, :].save()

# apply softmax to convert logits to probability distribution over tokens
base_probs = torch.softmax(base_logits, dim=-1)

top_completions = torch.topk(base_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
  print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
Some parameters are on the meta device because they were offloaded to the cpu.
 France (0.65)
 the (0.05)
 love (0.01)
[5]:
# source run - collect representations for a city from a different country
source_prompt = "Rio is in the country of"
source_country = model.tokenizer(" Brazil")["input_ids"][1] # id of " Brazil" (with a leading space); index 0 is the BOS token

source_hidden_states = []
with torch.no_grad():
  with model.trace(source_prompt) as tracer:
    # get hidden states of all layers in the network.
    # each layer's output is a tuple whose first element is the hidden state.
    for layer in model.model.layers:
      source_hidden_states.append(layer.output[0].save())
[6]:
# patched run - by patching at layer 8 over Paris, we change its country from France to Brazil!
TOKEN_INDEX = 1 # position of the city token ("Paris"/"Rio"); index 0 is BOS
LAYER_INDEX = 8

with model.trace(base_prompt) as tracer:
  # patch the base run's layer-8 activation at the Paris token with the source (Rio) activation
  model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]

  patched_logits = model.output.logits[:, -1, :].save()

patched_probs = torch.softmax(patched_logits, dim=-1)

top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
  print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
 Brazil (0.61)
 the (0.05)
 Portugal (0.01)

However, we also accidentally edit other facts about Paris, such as its continent and language!

[7]:
# by changing Paris's country, we also changed its continent!
TOKEN_INDEX = 1
LAYER_INDEX = 8

new_base_prompt = "Paris is in the continent of"

with model.trace(new_base_prompt) as tracer:
  # apply the same patch we did before
  model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = \
    source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]

  patched_logits = model.output.logits[:, -1, :].save()

patched_probs = torch.softmax(patched_logits, dim=-1)

top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
  print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
 South (0.55)
 America (0.11)
 North (0.10)
[8]:
# as well as its language!
new_base_prompt = "Paris is a city whose main language is"

with model.trace(new_base_prompt) as tracer:
  # apply the same patch we did before
  model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = \
    source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]

  patched_logits = model.output.logits[:, -1, :].save()

patched_probs = torch.softmax(patched_logits, dim=-1)

top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
  print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
 Portuguese (0.57)
 Spanish (0.12)
 English (0.10)

Takeaway

We need to find a way to make our patching more precise. One way to do this is to patch a unit of computation that’s smaller than the whole residual stream component. There are many reasonable options, such as patching sets of neurons. In this tutorial, we’ll look at how we can patch linear subspaces of a model’s representation.

Want to know more?

This example came from RAVEL, a benchmark that measures whether interpretability methods can localize specific concepts (e.g., country vs. language) in a model’s internal activations. Check out the paper and dataset for a full analysis of current interpretability methods & areas for future work!

Choosing the right unit of computation - how do models represent concepts?#

What are we patching to begin with? Let’s take a look at the source activations we collected.

[9]:
source_activations = source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]
source_activations
[9]:
tensor([[ 0.0111, -0.0206, -0.2613,  ..., -0.0281, -0.1300,  0.0346]],
       device='cuda:0')
[10]:
source_activations.shape
[10]:
torch.Size([1, 2048])

Can we break down the residual stream activation into smaller, meaningful units of computation?

One idea is to look at single neurons - that is, single indices within the large 2048-dimensional vector.

Another idea, motivated by the Linear Representation Hypothesis, is that transformer-based neural networks tend to use linear subspaces as units of computation. Thinking of a model’s activation as one giant vector, perhaps each concept is encoded in its own linear subspace of that vector.

[Figure: an activation represented as a linear vector, with subspaces encoding concepts such as the country & language of Paris.]
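To make this concrete, here’s a toy sketch. The two “concept directions” below are invented purely for illustration - they are not Llama’s actual encoding. If each concept lives in its own linear subspace, we can read it off with a projection and edit it by moving along its direction, without disturbing the other concept.

[ ]:
# toy illustration of the Linear Representation Hypothesis
# (hypothetical directions - NOT the model's actual encoding)
import torch

d = 8
country_dir = torch.zeros(d); country_dir[0] = 1.0    # pretend "country" lives along this direction
language_dir = torch.zeros(d); language_dir[1] = 1.0  # pretend "language" lives along this one

# an activation = a sum of concept values along their directions (+ a little noise)
activation = 2.0 * country_dir - 1.0 * language_dir + 0.01 * torch.randn(d)

# reading a concept = projecting onto its direction
print(activation @ country_dir)   # ~ 2.0 (the "country" value)
print(activation @ language_dir)  # ~ -1.0 (the "language" value)

# editing "country" = moving along country_dir only - "language" is untouched
edited = activation + (5.0 - activation @ country_dir) * country_dir
print(edited @ country_dir, edited @ language_dir)  # ~ 5.0, ~ -1.0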

To patch a set of neurons, we could simply index into the ones we think encode important concepts in the model. However, enumerating all subsets of neurons is computationally infeasible - with a hidden size of 2048, there are 2^2048 possible subsets.

[Figure: patching the first 3 neurons of the activations of Rio into Paris.]

[11]:
# change the list of indices to try a set of neurons to patch!
NEURON_INDICES = [0, 1, 2, 4]

base_prompt = "Paris is in the country of"

with model.trace(base_prompt) as tracer:
  # Apply the patch from the source hidden states to the base hidden states
  model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, NEURON_INDICES] = \
    source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, NEURON_INDICES]

  patched_logits = model.output.logits[:, -1, :]

  patched_probs = torch.softmax(patched_logits, dim=-1).save()

top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
  print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
 France (0.64)
 the (0.05)
 love (0.01)

To patch a set of linear subspaces, we can follow a similar procedure, with a slight twist…

First, we rotate our base and source vectors. This creates two new vectors whose entries are linear combinations of the original neurons. Next, we patch dimensions of the rotated vectors just as we would in the regular set-up. Lastly, we rotate the patched vector back, so that it’s in the same basis as the original run.

[Figure: a patch between a source and base vector, where both vectors are first rotated; the resulting patch is then un-rotated back to the original basis.]
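Here’s a minimal sketch of those three steps outside the model, on made-up 4-dimensional vectors and a random orthogonal matrix. Because an orthogonal matrix is inverted by transposing it, un-rotating exactly undoes the rotation - so if we patch zero dimensions, we get the base vector back unchanged.

[ ]:
# toy rotate -> patch -> unrotate on made-up vectors
import torch

d, k = 4, 2                                # hidden size, number of dims to patch
Q, _ = torch.linalg.qr(torch.randn(d, d))  # a random orthogonal matrix
base, source = torch.randn(1, d), torch.randn(1, d)

rotated_base = base @ Q.T      # 1) rotate the base vector
rotated_source = source @ Q.T  # 1) rotate the source vector
rotated_patch = torch.cat(     # 2) patch the first k rotated dims
  [rotated_source[:, :k], rotated_base[:, k:]], dim=1)
patch = rotated_patch @ Q      # 3) unrotate (Q is orthogonal, so the inverse of x @ Q.T is y @ Q)

# sanity check: with zero dims patched, we recover the base vector
assert torch.allclose((base @ Q.T) @ Q, base, atol=1e-5)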

[12]:
# construct a rotation matrix (model_dim x model_dim)
MODEL_HIDDEN_DIM = 2048

rotator = torch.nn.Linear(MODEL_HIDDEN_DIM, MODEL_HIDDEN_DIM, bias=False)
torch.nn.init.orthogonal_(rotator.weight)

rotator = torch.nn.utils.parametrizations.orthogonal(rotator).to(model.device)
clear_output()
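As a quick, optional sanity check, we can confirm that the parametrization really keeps the weight orthogonal - this is what guarantees the rotation is invertible:

[ ]:
# optional check: the parametrized weight satisfies W @ W.T = I
with torch.no_grad():
  W = rotator.weight
  identity = torch.eye(MODEL_HIDDEN_DIM, device=W.device, dtype=W.dtype)
  print(torch.allclose(W @ W.T, identity, atol=1e-4))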
[13]:
# play around with how many linear dimensions we patch!
N_PATCHING_DIMS = 1

base_prompt = "Paris is in the country of"

def patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=False):
  grad_env = torch.enable_grad if with_grad else torch.no_grad
  with grad_env():
    with model.trace(base_prompt) as tracer:
      # rotate the base representation
      base = model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :].clone()
      rotated_base = rotator(base)

      # rotate the source representation
      source = source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]
      rotated_source = rotator(source)

      # patch the first n dimensions in the rotated space
      # (NOTE: equivalent to `rotated_base[:, :N_PATCHING_DIMS] = rotated_source[:, :N_PATCHING_DIMS]`, but the concat lets gradients flow to the rotator)
      rotated_patch = torch.cat([
        rotated_source[:, :N_PATCHING_DIMS],
        rotated_base[:, N_PATCHING_DIMS:]
      ], dim=1)

      # unrotate the patched vector back to the original space
      # (the weight is orthogonal, so the inverse of `x @ W.T` is `y @ W`)
      patch = torch.matmul(rotated_patch, rotator.weight)

      # replace base with patch
      model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = patch

      patched_logits = model.output.logits[:, -1, :].save()
  return patched_logits

patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=False)
patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
  print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
 France (0.24)
 the (0.06)
 Belgium (0.02)

Want to know more?

You may have suspected this, but there’s nothing particularly special about a linear rotation! Maybe the model uses the magnitude of a vector, instead of its direction, to do meaningful computation. We can think about different intermediate transformations that might expose interesting units of computation. Here are some key properties that we need these transformations to have:

  • invertible - we need to be able to “undo” the transformation to return to the original representation space from the transformed space

  • separable - we don’t want concepts to interfere with each other during the transformation

To learn about more of their properties and their theoretical grounding, check out the causal abstraction theory paper!

Hm, changing our unit of computation from neurons to linear subspaces didn’t seem to help much… Patching the first few dimensions of a random rotation didn’t successfully edit the model’s representation of Paris’s country.

How do we automatically search for the linear subspaces we care about?

Takeaway

There are different potentially meaningful units of computation in a model’s representation. Thinking of the model representation as one giant multi-dimensional vector, we can try to patch linear subspaces of the model’s representation by first rotating it to a different basis.

How do we know which linear subspaces to patch? This is where DAS comes in!

Enter DAS - automatically finding relevant linear subspaces#

By rotating the hidden representations of our model, we can patch different linear subspaces. But how can we find the right linear subspace to patch?

Turns out, we can directly optimize our rotation matrix to do this! Let’s train the rotation matrix to maximize the likelihood of “Brazil”, instead of “France”, as the country of Paris.

[ ]:
# let's train our rotation matrix so that the patch output is Brazil instead of France
from tqdm import trange

# optimize only the rotation parameters (the LLM stays frozen)
optimizer = torch.optim.Adam(rotator.parameters())

# use a language modeling loss - increase the likelihood of outputting Brazil
loss_fn = torch.nn.CrossEntropyLoss()

counterfactual_answer = torch.tensor([model.tokenizer(" Brazil")["input_ids"][1]]).to(model.device)

with trange(10) as progress_bar: # train for 10 steps on this single example
  for epoch in progress_bar:
    optimizer.zero_grad()

    # get patched logits using our rotation matrix
    patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=True)

    # cross entropy loss - make last token be Brazil instead of France
    loss = loss_fn(patched_logits, counterfactual_answer)
    progress_bar.set_postfix({'loss': loss.item()})
    loss.backward()
    optimizer.step()
 10%|█         | 1/10 [02:41<22:09, 147.70s/it, loss=0.273]

Looks like training our rotation matrix did the job! Now, patching from Rio to Paris changes Paris’s country from France to Brazil.

[ ]:
base_prompt = "Paris is in the country of"

patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=False)
patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
  print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
 Brazil (0.21)
 France (0.08)
 the (0.07)

But did it interfere with other facts about Paris, such as its continent or language? Doesn’t look like it!

[ ]:
new_base_prompt = "Paris is in the continent of"

patched_logits = patch_linear_subspaces(rotator, new_base_prompt, source_hidden_states, with_grad=False)
patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
  print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
 Europe (0.43)
 Africa (0.19)
 North (0.13)
[ ]:
new_base_prompt = "Paris is a city whose main language is"

patched_logits = patch_linear_subspaces(rotator, new_base_prompt, source_hidden_states, with_grad=False)
patched_probs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_probs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
  print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')
 French (0.40)
 English (0.15)
 Italian (0.14)

Want to know more?

If there are concepts that we know we want to keep the same, we can train DAS with a multi-task objective (i.e., “edit this property” + “keep this other property the same”). See the RAVEL paper for more details!

Takeaway

How can we patch certain concepts in a model’s representation, such as the country of Paris, without messing with other concepts stored in the model, such as Paris’s continent or language?

DAS to the rescue! By searching over sets of linear subspaces, DAS finds a linear subspace in the model that, when patched, edits the model’s concept. The resulting patch is more precise - by patching individual linear subspaces, we have a better chance of ensuring that only the specific concept we’re looking for gets edited.

Multi-Task DAS#

To wrap up, let’s implement the multi-task objective mentioned above: we train the rotation matrix to change Paris’s country while explicitly keeping its continent the same.

[ ]:
# train the rotation matrix so the patch changes Paris's country to Brazil while keeping its continent Europe
from tqdm import trange

# optimize only the rotation parameters (the LLM stays frozen)
optimizer = torch.optim.Adam(rotator.parameters())

# use a language modeling loss - increase the likelihood of outputting Brazil
loss_fn = torch.nn.CrossEntropyLoss()

counterfactual_answer = torch.tensor([model.tokenizer(" Brazil")["input_ids"][1]]).to(model.device)

# we can directly specify things we want to stay the same!
new_base_prompt = "Paris is in the continent of"
new_base_answer = torch.tensor([model.tokenizer(" Europe")["input_ids"][1]]).to(model.device)

with trange(10) as progress_bar: # train for 10 steps
  for epoch in progress_bar:
    optimizer.zero_grad()

    # get loss for counterfactual behavior (what we want to CHANGE)
    patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=True)
    counterfactual_loss = loss_fn(patched_logits, counterfactual_answer)

    # get loss for base behavior (what we want to STAY THE SAME)
    patched_logits = patch_linear_subspaces(rotator, new_base_prompt, source_hidden_states, with_grad=True)
    new_base_loss = loss_fn(patched_logits, new_base_answer)

    # can add more examples of base behavior to keep the same if we want!
    # ...

    # add up all losses together
    loss = counterfactual_loss + new_base_loss

    progress_bar.set_postfix({'loss': loss.item()})
    loss.backward()
    optimizer.step()
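After training, we can re-run the earlier probes to check both objectives at once. Here’s a minimal sketch reusing `patch_linear_subspaces` from above (exact probabilities will vary from run to run):

[ ]:
# check both objectives: the country should flip to Brazil,
# while the continent should stay Europe
for prompt in ["Paris is in the country of", "Paris is in the continent of"]:
  patched_logits = patch_linear_subspaces(rotator, prompt, source_hidden_states, with_grad=False)
  patched_probs = torch.softmax(patched_logits, dim=-1)
  top_completions = torch.topk(patched_probs, 3, sorted=True)
  print(prompt)
  for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'  {model.tokenizer.decode(i.item())} ({v.item():.2f})')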