The Geometry of Truth#
Emergent Linear Structure in Large Language Model Representations of True/False Datasets
Modern LMs seem to be able to reason about true and false statements. But how is factuality encoded in their internal representations?
In this tutorial, we’ll follow Marks and Tegmark (2023) The Geometry of Truth, which finds linear structure in representations of true and false statements!
Specifically, we will investigate:
Information flow: where does an LM store information about the truth of a sentence?
Visualizing activations: using PCA to visualize LM internals in low dimensions.
Difference in means: why logistic regression probes might not always be the right choice!
Steering LMs both small and large!
If you’re reading along, our tutorial will roughly recreate Figures 1 and 2, and then give an in-depth explanation of Figure 4 from the paper.
📗 Prefer to use Colab? Follow the tutorial here!
0️⃣ Setup#
Run this code before we begin!
[1]:
import plotly.io as pio
from IPython.display import clear_output
try:
import google.colab
is_colab = True
except ImportError:
is_colab = False
if is_colab:
pio.renderers.default = "colab"
!pip install nnsight==0.5.0.dev
else:
pio.renderers.default = "plotly_mimetype+notebook_connected+notebook"
clear_output()
Note
In this tutorial, we use the Llama-3.2 3B model. Before starting the tutorial, please go to the model’s Hugging Face page and request permission to use the model. Then, log in to this notebook with your Hugging Face access token.
[ ]:
from huggingface_hub import notebook_login
notebook_login()
[3]:
# (try to) set seeds for reproducibility
import random
import torch
random.seed(12)
torch.manual_seed(12)
torch.cuda.manual_seed(12)
[4]:
# util functions for tutorial
class COLORS:
"""keep consistent plotting colors"""
LIGHT_BLUE = "#46B1E1"
BLUE = "#156082"
LIGHT_ORANGE = "#F2AA84"
ORANGE = "#E97132"
PURPLE = "#A02B93"
GREEN = "#4EA72E"
def hex_to_rgba(hex_color, alpha=1.):
"""convert hex to rgb with opacity parameter"""
hex_color = hex_color.lstrip('#')
rgb = tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
return f'rgba({rgb[0]}, {rgb[1]}, {rgb[2]}, {alpha})'
def rindex(lst, value):
"""get the rightmost index of a value in a list."""
return len(lst) - 1 - lst[::-1].index(value)
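As a quick illustration of rindex (not part of the original notebook): unlike list.index, it returns the index of the last occurrence of a value.
[ ]:
# rindex finds the rightmost occurrence, list.index finds the leftmost
assert [7, 2, 7, 5].index(7) == 0
assert rindex([7, 2, 7, 5], 7) == 2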
1️⃣ Information flow#
LMs seem to be able to classify true and false statements in-context. But where is this information stored? In this section, we’ll use activation patching to find out which token and layer activations play a causal role in representing the truth of a statement.
[5]:
# load model
import nnsight
from IPython.display import clear_output
model = nnsight.LanguageModel("meta-llama/Llama-3.2-3B", device_map="auto")
clear_output()
[6]:
# let's set up a few-shot prompt to see if models can reason about factuality
PROMPT_TEMPLATE = """The city of Tokyo is in Japan. This statement is: TRUE
The city of Hanoi is in Poland. This statement is: FALSE
{statement} This statement is:"""
source_statement = "The city of Toronto is in Canada." # true
source_prompt = PROMPT_TEMPLATE.format(statement=source_statement)
base_statement = "The city of Chicago is in Canada." # false
base_prompt = PROMPT_TEMPLATE.format(statement=base_statement)
# this is a false statement
print(base_prompt)
The city of Tokyo is in Japan. This statement is: TRUE
The city of Hanoi is in Poland. This statement is: FALSE
The city of Chicago is in Canada. This statement is:
Let’s put our model to the test - can it identify true and false statements? We’ll use a single example where we only vary the city: “the city of Toronto is in Canada” (true) vs. “the city of Chicago is in Canada” (false).
[7]:
# does the model know that Chicago isn't in Canada?
with torch.no_grad():
with model.trace(base_prompt) as trace:
# save the model's output logits
logits = model.output.logits.save()
# what's the model's response?
print(base_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
Some parameters are on the meta device because they were offloaded to the disk and cpu.
The city of Chicago is in Canada. This statement is:
FALSE
[8]:
# does the model know that Toronto is in Canada?
source_activations = []
with torch.no_grad():
with model.trace(source_prompt) as trace:
# let's save the intermediate activations - we'll use them in the next step!
for layer in model.model.layers:
source_activations.append(layer.output[0].save())
# save the model's output logits
logits = model.output.logits.save()
print(source_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Toronto is in Canada. This statement is:
TRUE
Activation Patching#
Okay, so our model seems to know a true statement from a false one. Where does it store this information? Let’s use activation patching between our two examples to see which token and layer activations causally mediate the truth of the statement.
If you’re not familiar with activation patching, we strongly encourage you to check out the nnsight activation patching tutorial! We’ll leave some comments to explain the process.
[10]:
# run activation patching from source (true) -> base (false)
# and measure P(TRUE) - P(FALSE)
from tqdm import trange
true_token_id = model.tokenizer(" TRUE").input_ids[1]
false_token_id = model.tokenizer(" FALSE").input_ids[1]
source_prompt_ids = model.tokenizer(source_prompt).input_ids
newline_token_id = model.tokenizer('\n').input_ids[1]
last_example_index = rindex(source_prompt_ids, newline_token_id) + 1 # get start of final example
patching_results = [] # save interchange intervention accuracies
for layer_index in trange(model.config.num_hidden_layers): # loop through layers
patching_per_layer = []
for token_index in range(last_example_index, len(source_prompt_ids)): # loop through the final statement's tokens
with torch.no_grad():
with model.trace(base_prompt):
# patch source -> base
model.model.layers[layer_index].output[0][:, token_index, :] = source_activations[layer_index][:, token_index, :]
# get model output
patched_probs = model.output.logits[:, -1].softmax(dim=-1) # convert logits to probs with softmax
# get probability of generating true vs. false answer
patched_true_prob = patched_probs[0, true_token_id].item()
patched_false_prob = patched_probs[0, false_token_id].item()
# save difference btw true & false answers
patched_diff = patched_true_prob - patched_false_prob
patching_per_layer.append(patched_diff.save())
patching_results.append(patching_per_layer)
100%|██████████| 28/28 [24:45<00:00, 53.05s/it]
Let’s plot our results! The darker the color, the larger the effect of patching the residual stream activation at that token/layer position.
[12]:
# plot results
import plotly.express as px
# convert token indices to token strings
base_token_ids = model.tokenizer(base_prompt).input_ids
token_strings = [
f"{model.tokenizer.decode(base_token_ids[t])}" + " " * i
for i, t in enumerate(range(last_example_index, len(base_token_ids)))
]
fig = px.imshow(
patching_results,
y=list(range(model.config.num_hidden_layers)),
template='simple_white',
color_continuous_scale=[[0, '#FFFFFF'], [1, COLORS.BLUE]],
aspect='auto',
)
fig.update_layout(
xaxis=dict(ticktext=token_strings, tickvals=list(range(len(token_strings)))),
xaxis_title='tokens',
yaxis=dict(autorange='min'),
yaxis_title='layers'
)
fig
We read this plot from the bottom up, starting at the “Toronto” -> “Chicago” token, and trace how the information flows from the input to the output!
Let’s jot down some takeaways.
The input: in the early layers, the “Toronto” -> “Chicago” token accounts for the difference between outputting TRUE vs. FALSE. This makes sense, because it’s the only thing we changed in the input!
The output: in the last layers, the information is stored at the last token. This also makes sense, because LMs are autoregressive: the next token is predicted from the representation at the final position.
Something interesting in the middle… the city token and the last prompt token aren’t the only positions carrying meaningful information about the factuality of the sentence! The final token of the statement and the end-of-sentence punctuation (“.”) mediate truth as well!
In the original paper, the authors investigated the representation at the end-of-sentence punctuation (“.”). We’ll differ slightly and analyze the representation of the final token in the sentence (e.g., “Canada”). From now on, our analysis will focus on this final sentence token at the 10th layer, because it seems to mediate factuality!
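As a quick sanity check (not part of the original notebook, and reusing base_prompt and the rindex helper from above), we can confirm that this index really lands on the statement’s final word:
[ ]:
# find the last "." in the prompt and step back one token;
# decoding that position should give the final word of the statement (" Canada" here)
punctuation_token_id = model.tokenizer('.').input_ids[1]
base_prompt_token_ids = model.tokenizer(base_prompt).input_ids
final_token_index = rindex(base_prompt_token_ids, punctuation_token_id) - 1
print(repr(model.tokenizer.decode(base_prompt_token_ids[final_token_index])))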
2️⃣ Visualizing activations#
Now that we’ve identified where an LM represents the truth of a statement, let’s investigate how it represents this information. To start, we’ll visualize the representations of true and false statements, projecting them to a lower dimension using PCA.
[13]:
# let's load a set of 50 true and 50 false statements
import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/saprmarks/geometry-of-truth/refs/heads/main/datasets/cities.csv')
df = df.iloc[:100]
df.head()
[13]:
statement | label | city | country | correct_country | |
---|---|---|---|---|---|
0 | The city of Krasnodar is in Russia. | 1 | Krasnodar | Russia | Russia |
1 | The city of Krasnodar is in South Africa. | 0 | Krasnodar | South Africa | Russia |
2 | The city of Lodz is in Poland. | 1 | Lodz | Poland | Poland |
3 | The city of Lodz is in the Dominican Republic. | 0 | Lodz | the Dominican Republic | Poland |
4 | The city of Maracay is in Venezuela. | 1 | Maracay | Venezuela | Venezuela |
[ ]:
# let's collect the model's activations over each statement at the same location
from tqdm import trange
import torch
# we'll focus on the 10th layer and the last token before the "."
LAYER = 10
punctuation_token_id = model.tokenizer('.').input_ids[1]
true_activations = []
false_activations = []
for i in trange(df.shape[0]): # loop through dataset
row = df.iloc[i]
prompt = PROMPT_TEMPLATE.format(statement=row.statement)
prompt_token_ids = model.tokenizer(prompt).input_ids
# get index of final token in the sentence (before the ".")
final_token_index = rindex(prompt_token_ids, punctuation_token_id) - 1
with torch.no_grad():
with model.trace(prompt) as trace:
# get the model's activation at our chosen token & layer position
activation = model.model.layers[LAYER].output[0][:, final_token_index, :].save() # (1, hidden_dim)
# add to our false/true activations!
if row.label == 0:
false_activations.append(activation)
else:
true_activations.append(activation)
true_activations = torch.cat(true_activations) # (50, hidden_dim)
false_activations = torch.cat(false_activations) # (50, hidden_dim)
100%|██████████| 100/100 [07:12<00:00, 4.33s/it]
It’s hard to visualize a 3,072-dimensional vector! So let’s try to visualize only two dimensions. We can do this with PCA, which picks out the n dimensions that best explain the variance across our collected activations (here, n = 2).
[15]:
# go from 3072 -> 2 dimensions with PCA!
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
all_activations = torch.cat([true_activations, false_activations]).detach().cpu().numpy() # (100, hidden_dim)
low_dim_activations = pca.fit_transform(all_activations) # (100, 2)
low_dim_activations.shape
[15]:
(100, 2)
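How much of the variance do these two components actually capture? scikit-learn exposes this through the fitted PCA’s explained_variance_ratio_ attribute (a quick optional check, not part of the original notebook):
[ ]:
# fraction of the total variance explained by each of the two principal components
print(pca.explained_variance_ratio_)
print(f"total variance explained: {pca.explained_variance_ratio_.sum():.1%}")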
Let’s visualize our results! We’ll color the model representation of each sentence by whether that sentence is true or false.
[16]:
import plotly.graph_objects as go
fig = go.Figure()
fig.add_trace(go.Scatter(
x=low_dim_activations[:50, 0],
y=low_dim_activations[:50, 1],
mode='markers',
marker=dict(symbol='square', color=hex_to_rgba(COLORS.LIGHT_BLUE, 0.5), size=10),
name='true',
))
fig.add_trace(go.Scatter(
x=low_dim_activations[50:, 0],
y=low_dim_activations[50:, 1],
mode='markers',
marker=dict(symbol='circle', color=hex_to_rgba(COLORS.LIGHT_ORANGE, 0.5), size=10),
name='false',
))
fig.update_layout(
template='simple_white',
width=400,
height=400,
xaxis_title='PCA 1',
yaxis_title='PCA 2',
)
fig.show()
Behold! The model representations form two clusters that correspond to the factuality of their input sentences. This means that the representations of true and false statements in our dataset are linearly separable: a linear classifier over the model’s internals should be able to tell us with high accuracy whether a representation came from a true or false statement.
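To make the linear-separability claim concrete, we can fit a quick logistic regression probe on the full 3,072-dimensional activations and check its cross-validated accuracy. This is a small sketch, not part of the original notebook; it reuses all_activations from the PCA cell above, whose first 50 rows are the true statements:
[ ]:
# a linear probe on the raw activations should tell true from false with high accuracy
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

labels = np.array([1] * 50 + [0] * 50)  # all_activations = [true_activations; false_activations]
probe = LogisticRegression(max_iter=1000)
print(f"5-fold CV accuracy: {cross_val_score(probe, all_activations, labels, cv=5).mean():.2f}")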
HOWEVER, just because the representations are linearly separable doesn’t mean the LM represents the concept linearly! To see why, we’ll go on a short but important tangent exploring the difference between prediction and control.
3️⃣ Choosing the right probe - linear classifiers vs. difference in means#
As we just saw, even in a low-dimensional space, it’s easy to tell activations over true statements apart from activations over false statements! But what’s driving the difference between true and false activations?
Different linear probes can give us directions in the latent space of the LLM that separate true and false activations. But we should be careful, because probes are likely to pick up on correlations instead of causal directions!
To understand why this might happen, let’s consider an example story.
Imagine that Prof. X wants to improve the outcomes of his gifted students. Prof. X keeps track of how many assignments a student turned in and their average grade on those assignments, and wants to use this information to predict whether the student passed the final.
But predicting is not enough. Prof. X wants to use the information from his predictor to improve his students’ performance on the final next year. Can his predictor give Prof. X useful information about the intervention he should design to improve student performance?
Note before we run any code
In this section, we won’t need a GPU! Feel free to change your runtime to CPU and re-run the code in the Setup section to save on compute.
[17]:
# create fake dataset of students
import numpy as np
np.random.seed(12)
N = 100
low_group_effort = np.random.normal(loc=5, scale=1.8, size=N)
high_group_effort = np.random.normal(loc=10, scale=1.8, size=N)
low_group_grade = 0.5 * low_group_effort + np.random.normal(loc=5, scale=0.3, size=N)
high_group_grade = 0.5 * high_group_effort + np.random.normal(loc=7, scale=0.3, size=N)
X = np.stack((
np.concatenate((low_group_effort, high_group_effort)), # x-axis
np.concatenate((low_group_grade, high_group_grade)) # y-axis
), axis=1)
X = X - X.mean(axis=0) # center
y = ['low'] * N + ['high'] * N
[18]:
import plotly.graph_objects as go
fig = go.Figure()
mean_low = X[:N].mean(axis=0)
mean_high = X[N:].mean(axis=0)
fig.add_traces([
go.Scatter(
x=X[:N, 0],
y=X[:N, 1],
mode='markers',
marker=dict(symbol='circle', color=hex_to_rgba(COLORS.LIGHT_BLUE, 0.5), size=10),
name='failed final',
),
go.Scatter(
x=X[N:, 0],
y=X[N:, 1],
mode='markers',
marker=dict(symbol='square', color=hex_to_rgba(COLORS.LIGHT_ORANGE, 0.5), size=10),
name='passed final',
),
go.Scatter(
x=[mean_low[0]],
y=[mean_low[1]],
mode='markers',
marker=dict(symbol='circle', color=hex_to_rgba(COLORS.BLUE, 0.7), size=12),
showlegend=False,
),
go.Scatter(
x=[mean_high[0]],
y=[mean_high[1]],
mode='markers',
marker=dict(symbol='square', color=hex_to_rgba(COLORS.ORANGE, 0.7), size=12),
showlegend=False
)
])
fig.update_layout(
template='simple_white',
width=500,
height=400,
xaxis_title='# of assignments turned in',
yaxis_title='average grade',
xaxis=dict(scaleanchor="y", scaleratio=1), # Anchor x-axis to y-axis with 1:1 ratio
yaxis=dict(scaleanchor="x", scaleratio=1) # Anchor y-axis to x-axis with 1:1 ratio
)
fig.show()
Looking at the data, we can see that the number of assignments turned in correlates with the average grade. Both average grade and number of assignments turned in also seem predictive of the students’ performance on the final. Which feature will our probes pick out?
Logistic regression probe (LR)#
Logistic regression will find a boundary that separates the students who passed the final from those who failed. However, it’s likely to pick up on spurious correlations that don’t drive the difference between the two groups!
[19]:
from sklearn.linear_model import LogisticRegression
lr_probe = LogisticRegression(fit_intercept=False, random_state=12)
lr_probe.fit(X, y)
lr_probe.score(X, y)
[19]:
1.0
[20]:
theta = lr_probe.coef_[0]
theta
[20]:
array([ 0.56890804, -3.07920493])
Let’s visualize the boundary drawn by our logistic regression probe. It neatly separates students who passed the final from students who failed!
[21]:
import plotly.graph_objects as go
# flip x & y -> flip around 45-degree line
# negate x -> flip around y-axis
theta_orthogonal = np.array([-theta[1], theta[0]])
# add trendline
t = np.linspace(-2, 2, 100)
probe_x = t * theta_orthogonal[0]
probe_y = t * theta_orthogonal[1]
fig.add_trace(
go.Scatter(
x=probe_x,
y=probe_y,
mode='lines',
line=dict(color=COLORS.PURPLE, width=2, dash='dash'),
name='LR probe'
)
)
fig
But what is the direction of our probe? What intervention does it suggest to shift students who failed the final towards passing the final next year?
Because average grade and # of assignments are correlated, the LR probe largely ignores # of assignments and relies on the average grade to classify students. But what sort of intervention does this suggest?
[22]:
mean_low_transformed = mean_low - theta
fig.add_annotation(
ax=mean_low[0], # x-coordinate of the arrow's tail (start)
ay=mean_low[1] + 0.2, # y-coordinate of the arrow's tail
x=mean_low_transformed[0], # x-coordinate of the arrow's head (where it points)
y=mean_low_transformed[1] - 0.2, # y-coordinate of the arrow's head
xref='x', yref='y', # Reference coordinates to the plot's x and y axes
axref='x', ayref='y',
showarrow=True,
arrowhead=4, # Style of the arrowhead (e.g., 1, 2, 3, 4, 5)
arrowsize=2,
arrowwidth=1,
arrowcolor=COLORS.PURPLE,
)
fig
Presumably, bumping everyone’s grades up won’t have the intended effect of helping people pass the final! Indeed, if we draw an arrow in the direction of the linear probe, we see that it doesn’t do a great job of bringing the two distributions of students closer together.
In this case, the direction of the LR probe isn’t meaningful - it tells the students apart, but doesn’t help us steer a student who’s on track to fail the class towards passing it.
Difference in means probe (a.k.a. mass mean / MM)#
A simple but effective alternative to LR probes is difference-in-means (also known as mass mean) probing.
To classify between groups of students, we take the difference between the mean student who passed the final and the mean student who didn’t. What we get is a direction vector that roughly tells the two classes apart and, more importantly, tells us how to bridge the gap between the two distributions.
[23]:
mean_mass_probe = mean_high - mean_low
mean_mass_probe_orthogonal = np.array([-mean_mass_probe[1], mean_mass_probe[0]])
# add trendline
t = np.linspace(-1, 1, 100)
probe_x = t * mean_mass_probe_orthogonal[0]
probe_y = t * mean_mass_probe_orthogonal[1]
fig.add_trace(
go.Scatter(
x=probe_x,
y=probe_y,
mode='lines',
line=dict(color=COLORS.GREEN, width=2, dash='dash'),
name='MM probe'
)
)
fig
[24]:
fig.add_annotation(
ax=mean_low[0], # x-coordinate of the arrow's tail (start)
ay=mean_low[1] + 0.2, # y-coordinate of the arrow's tail
x=mean_high[0], # x-coordinate of the arrow's head (where it points)
y=mean_high[1] - 0.2, # y-coordinate of the arrow's head
xref='x', yref='y', # Reference coordinates to the plot's x and y axes
axref='x', ayref='y',
showarrow=True,
arrowhead=4, # Style of the arrowhead (e.g., 1, 2, 3, 4, 5)
arrowsize=2,
arrowwidth=1,
arrowcolor=COLORS.GREEN,
)
fig
The mass mean probe suggests a better intervention: not only should students’ grades go up, but they should also probably turn in more assignments in order to prepare for the final.
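Both probes can tell the two groups apart, yet they point in noticeably different directions. As an optional check (not part of the original notebook, reusing theta, lr_probe, and mean_mass_probe from the cells above), we can measure how aligned the two directions are:
[ ]:
# cosine similarity between the LR direction and the difference-in-means direction
# (sklearn's coef_ points toward classes_[1], so we first orient theta toward the "passed" group)
lr_direction = theta if lr_probe.classes_[1] == 'high' else -theta
cos_sim = np.dot(lr_direction, mean_mass_probe) / (
    np.linalg.norm(lr_direction) * np.linalg.norm(mean_mass_probe)
)
print(f"cosine similarity: {cos_sim:.2f}")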
Even here, causality isn’t guaranteed. For example, there might be other underlying causes behind why students couldn’t turn in assignments. We are limited by the information in our dataset and the complexity of our intervention.
However, unlike Prof. X, who has to wait a whole year to try out ways to improve his students’ final grades, we get to test our interventions on our LM right away! Let’s see what it looks like to construct a “truth” steering vector for our LM and use it to change the model’s opinion about the truth of a sentence.
4️⃣ Difference in means steering#
Now that we’ve explored the effect of different probe methods, let’s return to our investigation of how LMs represent truth.
[25]:
true_activations_mean = true_activations.mean(axis=0)
false_activations_mean = false_activations.mean(axis=0)
difference_in_means = true_activations_mean - false_activations_mean
difference_in_means.shape
[25]:
torch.Size([3072])
Steer from false to true
[26]:
with torch.no_grad():
with model.trace(base_prompt) as trace:
logits = model.output.logits.save()
print(base_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Chicago is in Canada. This statement is:
FALSE
[27]:
base_prompt_token_ids = model.tokenizer(base_prompt).input_ids
punctuation_token_id = model.tokenizer.encode('.')[1]
punctuation_index = rindex(base_prompt_token_ids, punctuation_token_id) - 1
with torch.no_grad():
with model.trace(base_prompt) as trace:
# steer model!
model.model.layers[LAYER].output[0][:, punctuation_index, :] += difference_in_means
logits = model.output.logits.save()
print(base_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Chicago is in Canada. This statement is:
TRUE
Steer from true to false
[28]:
with torch.no_grad():
with model.trace(source_prompt) as trace:
logits = model.output.logits.save()
print(source_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Toronto is in Canada. This statement is:
TRUE
[29]:
source_prompt_token_ids = model.tokenizer(source_prompt).input_ids
punctuation_token_id = model.tokenizer.encode('.')[1]
punctuation_index = rindex(source_prompt_token_ids, punctuation_token_id) - 1
with torch.no_grad():
with model.trace(source_prompt) as trace:
# reverse the steering direction!
model.model.layers[LAYER].output[0][:, punctuation_index, :] -= difference_in_means
logits = model.output.logits.save()
print(source_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Toronto is in Canada. This statement is:
FALSE
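Before scaling up, we can also sweep the strength of the steering vector and watch the model’s preference flip. The sketch below is not part of the original notebook; it reuses base_prompt, LAYER, difference_in_means, and the TRUE/FALSE token ids defined earlier, and tracks P(TRUE) - P(FALSE) as the steering scale grows:
[ ]:
# sweep the steering scale on the false (base) prompt and track P(TRUE) - P(FALSE)
base_prompt_token_ids = model.tokenizer(base_prompt).input_ids
punctuation_index = rindex(base_prompt_token_ids, model.tokenizer('.').input_ids[1]) - 1
for scale in [0.0, 0.5, 1.0, 2.0]:
    with torch.no_grad():
        with model.trace(base_prompt):
            model.model.layers[LAYER].output[0][:, punctuation_index, :] += scale * difference_in_means
            probs = model.output.logits[:, -1].softmax(dim=-1).save()
    diff = probs[0, true_token_id] - probs[0, false_token_id]
    print(f"scale {scale:.1f}: P(TRUE) - P(FALSE) = {diff.item():+.3f}")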
➡️ Let’s scale things up! Steering Llama-70B on NDIF#
If you’re using Colab, we recommend disconnecting and re-connecting to a CPU instance. In this section, we’ll use the models provided by NDIF!
[30]:
from IPython.display import clear_output
if is_colab:
!pip install -U nnsight
clear_output()
[ ]:
from nnsight import CONFIG
CONFIG.API.APIKEY = input("Enter your API key: ")
clear_output()
[32]:
from huggingface_hub import notebook_login
notebook_login()
[33]:
# instantiate the model using the LanguageModel class
from nnsight import LanguageModel
# don't worry, this won't load locally!
model = LanguageModel("meta-llama/Meta-Llama-3.1-70B", device_map="auto")
clear_output()
[34]:
model
[34]:
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 8192)
(layers): ModuleList(
(0-79): 80 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=8192, out_features=8192, bias=False)
(k_proj): Linear(in_features=8192, out_features=1024, bias=False)
(v_proj): Linear(in_features=8192, out_features=1024, bias=False)
(o_proj): Linear(in_features=8192, out_features=8192, bias=False)
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=8192, out_features=28672, bias=False)
(up_proj): Linear(in_features=8192, out_features=28672, bias=False)
(down_proj): Linear(in_features=28672, out_features=8192, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((8192,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=8192, out_features=128256, bias=False)
(generator): Generator(
(streamer): Streamer()
)
)
[35]:
PROMPT_TEMPLATE = """The city of Tokyo is in Japan. This statement is: TRUE
The city of Hanoi is in Poland. This statement is: FALSE
{statement} This statement is:"""
# patching from true statement (source) -> false statement (base)
source_statement = "The city of Toronto is in Canada." # true
source_prompt = PROMPT_TEMPLATE.format(statement=source_statement)
base_statement = "The city of Chicago is in Canada." # false
base_prompt = PROMPT_TEMPLATE.format(statement=base_statement)
print(base_prompt)
The city of Tokyo is in Japan. This statement is: TRUE
The city of Hanoi is in Poland. This statement is: FALSE
The city of Chicago is in Canada. This statement is:
[36]:
with model.trace(source_prompt, remote=True):
source_logits = model.output.logits.save()
with model.trace(base_prompt, remote=True):
base_logits = model.output.logits.save()
clear_output()
print("SOURCE (true statement)")
print(source_prompt.split('\n')[-1])
print(model.tokenizer.decode(source_logits.argmax(dim=-1)[0, -1]))
print('-' * 50)
print("BASE (false statement)")
print(base_prompt.split('\n')[-1])
print(model.tokenizer.decode(base_logits.argmax(dim=-1)[0, -1]))
SOURCE (true statement)
The city of Toronto is in Canada. This statement is:
TRUE
--------------------------------------------------
BASE (false statement)
The city of Chicago is in Canada. This statement is:
FALSE
[37]:
import pandas as pd
num_examples = 50
df = pd.read_csv('https://raw.githubusercontent.com/saprmarks/geometry-of-truth/refs/heads/main/datasets/cities.csv')
df = df.iloc[:num_examples]
df.head()
[37]:
statement | label | city | country | correct_country | |
---|---|---|---|---|---|
0 | The city of Krasnodar is in Russia. | 1 | Krasnodar | Russia | Russia |
1 | The city of Krasnodar is in South Africa. | 0 | Krasnodar | South Africa | Russia |
2 | The city of Lodz is in Poland. | 1 | Lodz | Poland | Poland |
3 | The city of Lodz is in the Dominican Republic. | 0 | Lodz | the Dominican Republic | Poland |
4 | The city of Maracay is in Venezuela. | 1 | Maracay | Venezuela | Venezuela |
[ ]:
import torch
PROMPT_TEMPLATE = """The city of Tokyo is in Japan. This statement is: TRUE
The city of Hanoi is in Poland. This statement is: FALSE
{statement} This statement is:"""
LAYER = 20
punctuation_token_id = model.tokenizer('.').input_ids[1] # used to locate the final token of each statement
true_activations = []
false_activations = []
with model.session(remote=True) as session:
for i in range(df.shape[0]):
row = df.iloc[i]
prompt = PROMPT_TEMPLATE.format(statement=row.statement)
prompt_token_ids = model.tokenizer(prompt).input_ids
punctuation_index = rindex(prompt_token_ids, punctuation_token_id) - 1
with model.trace(prompt):
activation = model.model.layers[LAYER].output[0][:, punctuation_index, :].save()
if row.label == 0:
false_activations.append(activation)
else:
true_activations.append(activation)
clear_output()
true_activations = torch.cat(true_activations)
false_activations = torch.cat(false_activations)
true_activations.shape, false_activations.shape
(torch.Size([25, 8192]), torch.Size([25, 8192]))
[39]:
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
all_activations = torch.cat([true_activations.to(torch.float64), false_activations.to(torch.float64)]).cpu().numpy()
low_dim_activations = pca.fit_transform(all_activations)
low_dim_activations.shape
[39]:
(50, 2)
[40]:
import plotly.graph_objects as go
fig = go.Figure()
num_per_class = num_examples // 2
fig.add_trace(go.Scatter(
x=low_dim_activations[:num_per_class, 0],
y=low_dim_activations[:num_per_class, 1],
mode='markers',
marker=dict(symbol='square', color=hex_to_rgba(COLORS.LIGHT_BLUE, 0.5), size=10),
name='true',
))
fig.add_trace(go.Scatter(
x=low_dim_activations[num_per_class:, 0],
y=low_dim_activations[num_per_class:, 1],
mode='markers',
marker=dict(symbol='circle', color=hex_to_rgba(COLORS.LIGHT_ORANGE, 0.5), size=10),
name='false',
))
fig.update_layout(
template='simple_white',
width=400,
height=400,
xaxis_title='PCA 1',
yaxis_title='PCA 2',
)
fig.show()
[41]:
true_activations_mean = true_activations.mean(axis=0)
false_activations_mean = false_activations.mean(axis=0)
difference_in_means = true_activations_mean - false_activations_mean
difference_in_means.shape
[41]:
torch.Size([8192])
Steer from false to true
[42]:
with model.trace(base_prompt, remote=True) as trace:
logits = model.output.logits.save()
clear_output()
print(base_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Chicago is in Canada. This statement is:
FALSE
[43]:
STEER_FACTOR = 2
base_prompt_token_ids = model.tokenizer(base_prompt).input_ids
punctuation_token_id = model.tokenizer.encode('.')[1]
punctuation_index = rindex(base_prompt_token_ids, punctuation_token_id) - 1
with model.trace(base_prompt, remote=True) as trace:
# steer model!
model.model.layers[LAYER].output[0][:, punctuation_index, :] += difference_in_means * STEER_FACTOR
logits = model.output.logits.save()
clear_output()
print(base_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Chicago is in Canada. This statement is:
TRUE
Steer from true to false
[44]:
with model.trace(source_prompt, remote=True) as trace:
logits = model.output.logits.save()
clear_output()
print(source_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Toronto is in Canada. This statement is:
TRUE
[45]:
source_prompt_token_ids = model.tokenizer(source_prompt).input_ids
punctuation_token_id = model.tokenizer.encode('.')[1]
punctuation_index = rindex(source_prompt_token_ids, punctuation_token_id) - 1
with model.trace(source_prompt, remote=True) as trace:
# reverse the steering direction!
model.model.layers[LAYER].output[0][:, punctuation_index, :] -= difference_in_means * STEER_FACTOR
logits = model.output.logits.save()
clear_output()
print(source_prompt.split('\n')[-1])
print(model.tokenizer.decode(logits.argmax(dim=-1)[0, -1]))
The city of Toronto is in Canada. This statement is:
FALSE
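As a final exercise, you could quantify the steering effect across many statements rather than a single example. The sketch below is not part of the original notebook; it reuses df, PROMPT_TEMPLATE, LAYER, STEER_FACTOR, difference_in_means, and the rindex helper from above, steering the first ten false statements on the remote 70B model and printing the model’s one-token answer for each:
[ ]:
# steer several false statements toward "true" and read off the model's answers
false_rows = df[df.label == 0].head(10)
steered_answers = []
with model.session(remote=True) as session:
    for _, row in false_rows.iterrows():
        prompt = PROMPT_TEMPLATE.format(statement=row.statement)
        prompt_token_ids = model.tokenizer(prompt).input_ids
        token_index = rindex(prompt_token_ids, model.tokenizer('.').input_ids[1]) - 1
        with model.trace(prompt):
            # add the truth direction at the final token of the statement
            model.model.layers[LAYER].output[0][:, token_index, :] += difference_in_means * STEER_FACTOR
            steered_answers.append(model.output.logits[:, -1].argmax(dim=-1).save())
clear_output()
for (_, row), answer in zip(false_rows.iterrows(), steered_answers):
    print(f"{row.statement} -> {model.tokenizer.decode(answer[0]).strip()}")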