The Geometry of Truth¶
Emergent Linear Structure in Large Language Model Representations of True/False Datasets
Modern LMs seem to be able to reason about true and false statements. But how do they represent factuality internally?
In this tutorial, we'll follow Marks and Tegmark (2023) The Geometry of Truth, which finds linear structure in representations of true and false statements!
Specifically, we will investigate:
- Information flow: where does an LM store information about the truth of a sentence?
- Visualizing activations: using PCA to visualize low dimensions of LM internals.
- Difference in means and why linear probing might not always be the right choice!
- Steering LMs both small and large!
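As a quick preview of the difference-in-means idea (a toy sketch with made-up 2-D numbers, not the paper's data or the model's activations): the candidate "truth direction" is simply the mean activation of true statements minus the mean activation of false ones.

```python
# hypothetical toy activations (2-D for illustration only; real
# activations live in the model's hidden dimension)
true_acts = [[1.0, 2.0], [3.0, 2.0]]
false_acts = [[0.0, 0.0], [2.0, 0.0]]

def mean(vectors):
    """element-wise mean of a list of equal-length vectors"""
    n = len(vectors)
    return [sum(v[i] for v in vectors) / n for i in range(len(vectors[0]))]

# difference-in-means direction: mean(true) - mean(false)
direction = [t - f for t, f in zip(mean(true_acts), mean(false_acts))]
print(direction)  # [1.0, 2.0]
```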
If you're reading along, our tutorial will roughly recreate Figures 1 and 2, and then give an in-depth explanation of Figure 4 from the paper.
📗 Prefer to use Colab? Follow the tutorial here!
0️⃣ Setup¶
Run this code before we begin!
import plotly.io as pio
from IPython.display import clear_output

try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    pio.renderers.default = "colab"
    !pip install nnsight==0.5.0.dev
else:
    pio.renderers.default = "plotly_mimetype+notebook_connected+notebook"

clear_output()
Note
In this tutorial, we use the Llama-3.2 3B model. Before starting the tutorial, please go to the model's huggingface page and request permission to use the model. Then, log in to this notebook with your huggingface access token.
from huggingface_hub import notebook_login
notebook_login()
# (try to) set seeds for reproducibility
import random
import torch
random.seed(12)
torch.manual_seed(12)
torch.cuda.manual_seed(12)
# util functions for tutorial
class COLORS:
    """keep consistent plotting colors"""
    LIGHT_BLUE = "#46B1E1"
    BLUE = "#156082"
    LIGHT_ORANGE = "#F2AA84"
    ORANGE = "#E97132"
    PURPLE = "#A02B93"
    GREEN = "#4EA72E"

def hex_to_rgba(hex_color, alpha=1.):
    """convert hex to rgb with opacity parameter"""
    hex_color = hex_color.lstrip('#')
    rgb = tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
    return f'rgba({rgb[0]}, {rgb[1]}, {rgb[2]}, {alpha})'

def rindex(lst, value):
    """get the rightmost index of a value in a list."""
    return len(lst) - 1 - lst[::-1].index(value)
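To see how `rindex` works: reversing the list turns the rightmost occurrence into the leftmost one, so the built-in `list.index` finds it, and we map that position back to an index in the original list. A standalone copy of the logic:

```python
# rindex logic: search the reversed list, then map the position back
lst = [7, 2, 7, 3]
rightmost = len(lst) - 1 - lst[::-1].index(7)
print(rightmost)  # 2, the index of the *last* 7
```

We'll use this later to locate the start of the final example in a few-shot prompt.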
1️⃣ Information flow¶
LMs seem to be able to classify true and false statements in-context. But where is this information stored? In this section, we'll use activation patching to identify which token and layer activations play a causal role in representing the truth of a statement.
# load model
import nnsight
from IPython.display import clear_output
model = nnsight.LanguageModel("meta-llama/Llama-3.2-3B", device_map="auto")
clear_output()
# let's set up a few-shot prompt to see if models can reason about factuality
PROMPT_TEMPLATE = """The city of Tokyo is in Japan. This statement is: TRUE
The city of Hanoi is in Poland. This statement is: FALSE
{statement} This statement is:"""
source_statement = "The city of Toronto is in Canada." # true
source_prompt = PROMPT_TEMPLATE.format(statement=source_statement)
base_statement = "The city of Chicago is in Canada." # false
base_prompt = PROMPT_TEMPLATE.format(statement=base_statement)
# this is a false statement
print(base_prompt)
The city of Tokyo is in Japan. This statement is: TRUE
The city of Hanoi is in Poland. This statement is: FALSE
The city of Chicago is in Canada. This statement is:
Let's put our model to the test - can it identify true and false statements? We'll use a single example where we only vary the city: "the city of Toronto is in Canada" (true) vs. "the city of Chicago is in Canada" (false).
# does the model know that Chicago isn't in Canada?
with torch.no_grad():
    with model.trace(base_prompt) as trace:
        # save the model's output logits
        logits = model.output.logits.save()

# what's the model's response?
print(base_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Chicago is in Canada. This statement is: FALSE
# does the model know that Toronto is in Canada?
source_activations = []
with torch.no_grad():
    with model.trace(source_prompt) as trace:
        # let's save the intermediate activations - we'll use them in the next step!
        for layer in model.model.layers:
            source_activations.append(layer.output[0].save())
        # save the model's output logits
        logits = model.output.logits.save()

print(source_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Toronto is in Canada. This statement is: TRUE
Activation Patching¶
Okay, so our model can tell a true statement from a false one. Where does it store this information? Let's use activation patching between our two examples to see which token and layer activations causally mediate the model's truth judgment.
If you're not familiar with activation patching, we strongly encourage you to check out the activation patching tutorial on nnsight! We'll leave some comments to explain the process.
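Before running it on the LM, here is a minimal, purely illustrative sketch of the patching recipe (toy stand-in functions, not the model code): cache an intermediate value from the source run, then substitute it into the base run and see how the output changes.

```python
# toy two-stage "model": early layers feed an intermediate activation
# into late layers (stand-ins for illustration only)
def early_layers(x):
    return x * 2   # stand-in for layers 0..k

def late_layers(h):
    return h + 1   # stand-in for layers k+1..L

source_h = early_layers(5)               # 1) cache the source activation
base_out = late_layers(early_layers(3))  # ordinary base run
patched_out = late_layers(source_h)      # 2) base run with source activation patched in
print(base_out, patched_out)
```

If patching the activation moves the output toward the source run's answer, that activation carries the information we're looking for; the real code below does exactly this, one (layer, token) position at a time.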
# run activation patching from source (true) -> base (false)
# and measure P(TRUE) - P(FALSE)
from tqdm import trange

true_token_id = model.tokenizer(" TRUE").input_ids[1]    # index 1: index 0 is the BOS token
false_token_id = model.tokenizer(" FALSE").input_ids[1]

source_prompt_ids = model.tokenizer(source_prompt).input_ids
newline_token_id = model.tokenizer('\n').input_ids[1]
last_example_index = rindex(source_prompt_ids, newline_token_id) + 1  # get start of final example

patching_results = []  # save interchange intervention results
for layer_index in trange(model.config.num_hidden_layers):  # loop through layers
    patching_per_layer = []
    for token_index in range(last_example_index, len(source_prompt_ids)):  # loop through final-example tokens
        with torch.no_grad():
            with model.trace(base_prompt):
                # patch source -> base
                model.model.layers[layer_index].output[0][:, token_index, :] = source_activations[layer_index][:, token_index, :]
                # get model output
                patched_probs = model.output.logits[:, -1].softmax(dim=-1)  # convert logits to probs with softmax
                # get probability of generating true vs. false answer
                patched_true_prob = patched_probs[0, true_token_id].item()
                patched_false_prob = patched_probs[0, false_token_id].item()
                # save difference btw true & false answers
                patched_diff = patched_true_prob - patched_false_prob
                patching_per_layer.append(patched_diff.save())
    patching_results.append(patching_per_layer)
100%|██████████| 28/28 [24:45<00:00, 53.05s/it]
Let's plot our results! The darker the color, the larger the causal effect of the residual-stream activation at that token/layer position.
# plot results
import plotly.express as px

# convert token indices to token strings
base_token_ids = model.tokenizer(base_prompt).input_ids
token_strings = [
    f"{model.tokenizer.decode(base_token_ids[t])}" + " " * i  # pad with spaces so repeated tokens get distinct labels
    for i, t in enumerate(range(last_example_index, len(base_token_ids)))
]

fig = px.imshow(
    patching_results,
    y=list(range(model.config.num_hidden_layers)),
    template='simple_white',
    color_continuous_scale=[[0, '#FFFFFF'], [1, COLORS.BLUE]],
    aspect='auto',
)
fig.update_layout(
    xaxis=dict(ticktext=token_strings, tickvals=list(range(len(token_strings)))),
    xaxis_title='tokens',
    yaxis=dict(autorange='min'),
    yaxis_title='layers',
)
fig