Demystifying Verbatim Memorization in Large Language Models#
Recent LLMs exhibit verbatim memorization - reciting text directly from their training data. Verbatim memorization may lead to unreliable evaluations of LLM capabilities, might leak sensitive information from the training data, and has already resulted in lawsuits over copyright.
How do LLMs memorize long passages? Can we edit out passages that LLMs have memorized without harming their ability to generate fluent and informative text?
In this tutorial, we’ll follow Huang et al. (2024), which looks for tokens that trigger memorization in LLMs.
Along the way, we’ll cover:
Which tokens trigger memorization? Applying causal interventions, we can detect tokens that have a causal effect on the LLM’s memorized completion.
Is memorization a single mechanism? Memorized information is stored in different states, and might not be separable from the LLM’s language modeling capabilities.
Is causal analysis predictive of model behavior? We can test the predictions made by causal intervention experiments on whether or not the LLM will respond with a memorized answer.
Our code will roughly recreate Figure 4 in the original paper. We also strongly encourage you to read the paper’s blog post!
📗 Prefer to use Colab? Follow the tutorial here!
0️⃣ Setup#
Device
To speed things up, switch to GPU by going to Runtime -> Change Runtime Type.
[ ]:
# install nnsight (if running on Colab)
from IPython.display import clear_output
try:
import google.colab
is_colab = True
except ImportError:
is_colab = False
if is_colab:
!pip install nnsight==0.5.0.dev
clear_output()
[ ]:
# load model
import nnsight
model = nnsight.LanguageModel("EleutherAI/pythia-2.8b", device_map="auto")
clear_output()
1️⃣ What triggers memorization?#
If you prompt an LLM to complete the sentence
Mr. and Mrs. Dursley of number four, Privet Drive, were proud to say that they were perfectly ____
It will know to complete it with
Mr. and Mrs. Dursley of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.
(if you’re a Harry Potter fan, maybe you already completed it in your mind the same way!)
This is because the LLM has memorized the opening to Harry Potter. But where does the model store its memorized information? Can we hope to extract this memorized passage out of the model’s neural representations?
To find out where a model stores memorized text, we can look at which tokens in the prompt have a causal effect on the model’s behavior. Does switching out the token for “Dursley” with a different token change the model’s completion of the sentence?
We can systematically intervene on the model’s representation of each token in the first sentence of Harry Potter to ascertain whether that representation contributes to the model’s generation of the memorized passage, hence triggering memorization.
In the original paper, the authors patched from n = 10 random examples from the Pile into the memorized text.
To simplify things, we’ll patch from a single “random” example made up of punctuation (i.e., a list of “!” tokens). This serves as a way to remove information from the memorized text.
This means our signal for memorization is as follows: if the “removal” patch reduces the probability of outputting the memorized text, then that token plays a causal role in triggering the memorization.
In other words, for memorized text, we should see a few tokens making a big difference. For text that the model didn’t memorize, we should not see any one token making a difference.
Note: if you’re unfamiliar with causal mediation analysis, we strongly encourage you to check out the tutorial on activation patching!
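Before we start patching, it’s worth confirming that the model really has memorized the passage. Here’s a minimal check (a sketch using the same greedy generation call we’ll use later in the tutorial): prompt the model with the opening of the sentence and see whether it recites the memorized ending.
[ ]:
# sanity check: does greedy decoding reproduce the memorized continuation?
prefix = "Mr. and Mrs. Dursley of number four, Privet Drive, were proud to say that they were perfectly"
with model.generate(prefix, max_new_tokens=8, pad_token_id=model.tokenizer.eos_token_id):
    out = model.generator.output.save()
# if the passage is memorized, this should end with "normal, thank you very much."
print(model.tokenizer.decode(out[0]))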
[ ]:
# try out other memorized lines!
harry = "Mr. and Mrs. Dursley of number four, Privet Drive, were proud to say that they were perfectly normal, thank you"
seuss = "Congratulations!\nToday is your day.\nYou\'re off to Great Places!\nYou\'re off and away!"
yoda = "Do or do not. There is no try"
example = harry
example_tokens = model.tokenizer(example).input_ids
[ ]:
# this is a bit hacky, but it will serve as our "random" example
random = " ".join(["!"] * len(example_tokens))
random
'! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !'
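If you’d like something closer to the paper’s setup, you could instead corrupt with random token ids rather than “!” (a sketch we won’t use in the rest of the tutorial; passing ids directly, as we do with example_tokens below, keeps the corrupted sequence the same length as the memorized one):
[ ]:
# optional alternative: random token ids as the corruption source
import torch
torch.manual_seed(0)
random_token_ids = torch.randint(0, model.config.vocab_size, (len(example_tokens),)).tolist()
print(model.tokenizer.decode(random_token_ids))
# you could pass random_token_ids to model.trace() in the next cell instead of `random`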
[ ]:
# collect activations from our random source
import torch
random_activations = []
with torch.no_grad():
with model.trace(random) as trace:
for layer in model.gpt_neox.layers:
random_activations.append(layer.output[0].save())
random_activations[0].shape
torch.Size([1, 28, 2560])
[ ]:
# patch from the random activations into the base prompt
from tqdm import trange
patching_results = []
# iterate through layers (every other layer to save time)
for layer_idx in trange(0, model.config.num_hidden_layers, 2):
patching_results_per_layer = []
# iterate through all tokens
for token_idx in range(len(example_tokens)):
with torch.no_grad():
with model.trace(example_tokens) as tracer:
# apply the patch from "random" hidden states to current base run
model.gpt_neox.layers[layer_idx].output[0][:, token_idx, :] = \
random_activations[layer_idx][:, token_idx, :]
# get logits
patched_logits = model.output.logits[0] # (num_tokens, vocab_size)
patched_probs = torch.softmax(patched_logits, dim=-1)
# line up token probabilities with continuations
patched_probs = patched_probs[:-1, :] # (num_tokens - 1, vocab_size)
continuation_tokens = example_tokens[1:] # (num_tokens - 1)
patching_results_per_layer.append(patched_probs[range(len(continuation_tokens)), continuation_tokens].cpu().save())
patching_results.append(torch.stack(patching_results_per_layer))
patching_results = torch.stack(patching_results)
100%|██████████| 16/16 [00:47<00:00, 2.98s/it]
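To put the patched probabilities in context, it helps to have the unpatched baseline: the probability the model assigns to each continuation token when nothing is patched. A minimal sketch, reusing the same recipe as the patching loop above:
[ ]:
# baseline (no patch): probability of each memorized continuation token
with torch.no_grad():
    with model.trace(example_tokens):
        clean_logits = model.output.logits[0]
        clean_probs = torch.softmax(clean_logits, dim=-1)[:-1, :]
        continuation_tokens = example_tokens[1:]
        clean_continuation_probs = clean_probs[range(len(continuation_tokens)), continuation_tokens].cpu().save()
# e.g. the clean probabilities for "thank" and "you"
clean_continuation_probs[-2:]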
[ ]:
# visualize results for different token positions
import plotly.express as px
def plot_patching_results(
patching_results,
x_labels,
index=-1
):
    # rows = layers, columns = patched token positions; we look at the probability
    # of the continuation token at position `index`
    patching_results = patching_results[:, :index, index]
    fig = px.imshow(
        patching_results,
        color_continuous_midpoint=0.5,
        color_continuous_scale="BuPu_r",
        labels={"x": "Position", "y": "Layer", "color": "Counterfactual probability"},
        x=x_labels[:index]
    )
fig.update_layout(
yaxis=dict(autorange="min")
)
return fig
Interpreting our patching results#
Although we collected causal effects across the entire sentence, we’ll only look at two tokens:
What causes the model to output “thank” after “normal,”?
What causes the model to output “you” after “thank”?
For “thank”, we see that a few specific tokens have a big impact on the model’s continuation.
Specifically, the Dursleys’ name and Harry’s famous residence on Privet Drive have a causal effect on the model’s output. This suggests that the model memorized ending the sentence with “thank”, and that the Dursleys’ name triggered the memorization.
[ ]:
clean_decoded_tokens = [model.tokenizer.decode(token) for token in example_tokens]
token_labels = [f"({index}) {token}" for index, token in enumerate(clean_decoded_tokens)]
fig = plot_patching_results(patching_results, token_labels, index=-2)
fig.show()
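If you’d rather not eyeball the heatmap, you can also pull out which (layer, position) patch most reduces the probability of “thank” programmatically (a small sketch over the results tensor we just built):
[ ]:
# probability of "thank" (continuation index -2) under each patch,
# excluding the last two positions, just like the plot above
effect_on_thank = patching_results[:, :-2, -2]
flat_idx = int(effect_on_thank.argmin())
layer_row, pos_idx = divmod(flat_idx, effect_on_thank.shape[1])
print("Layer", layer_row * 2, "| position", pos_idx, "| token:", repr(model.tokenizer.decode(example_tokens[pos_idx])))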
How about following “thank” with “you”?
Unlike the first plot, we don’t see any significant effect from Harry Potter-specific tokens. Instead, almost all of the causal effect can be attributed to the token “thank”, which leads the model to continue with “you”.
This means that the model doesn’t memorize the token “you” in this sentence! Instead, it knows to fill it in thanks to its language modeling capabilities (knowing that “thank” is usually followed by “you”).
[ ]:
clean_decoded_tokens = [model.tokenizer.decode(token) for token in example_tokens]
token_labels = [f"({index}) {token}" for index, token in enumerate(clean_decoded_tokens)]
fig = plot_patching_results(patching_results, token_labels, index=-1)
fig.show()
This may not be surprising, but it’s an important takeaway. Our LLM did not memorize each token verbatim. Although it memorized that “thank” came after “normal,”, the LLM simply already knew that “thank” and “you” tend to go together.
If we were to try to naively reduce the LLM’s ability to output the memorized sentence, we might accidentally harm the LLM’s general language modeling ability. Memorization is distributed across different states, and can be entangled with the way our LLM learned to model language!
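We can sanity-check that last claim directly: even with no Harry Potter context at all, the model should assign a high probability to “ you” after “thank”. A quick sketch (the neutral prompt here is just an illustrative choice):
[ ]:
# how likely is " you" after "thank" in a completely unrelated context?
neutral = "We would like to thank"
with torch.no_grad():
    with model.trace(neutral):
        next_token_probs = torch.softmax(model.output.logits[0, -1], dim=-1).save()
you_id = model.tokenizer(" you").input_ids[0]
print(f"P(' you' | '{neutral}') = {next_token_probs[you_id].item():.3f}")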
2️⃣ How does causal mediation analysis inform behavior?#
Our causal mediation analysis showed us when our LLM might be reciting memorized text, and when it might be completing the text based on other cues (e.g., “thank” -> “you”). It also told us which tokens played a causal role in triggering the model’s memorization. How well do these predictions hold up in the model’s actual behavior? Let’s play around and see what happens when we remove the trigger tokens!
The “thank” plot shows that the “Dursley” token has the largest causal effect in triggering the model’s memorization. This suggests we can edit other parts of the sentence, like Harry Potter’s street address, without affecting the model’s memorization.
[ ]:
# changing harry's address doesn't seem to affect memorization!
different_harry = "Mr. and Mrs. Dursley of number 221B, Baker Street, were proud to say that they were perfectly"
with model.generate(different_harry, max_new_tokens=8, pad_token_id=model.tokenizer.eos_token_id):
out = model.generator.output.save()
print(model.tokenizer.decode(out[0]))
Mr. and Mrs. Dursley of number 221B, Baker Street, were proud to say that they were perfectly normal, thank you very much.
[ ]:
# talking about different harry potter characters still seems to trigger memorization
different_harry = "Mr. and Mrs. Hagrid of number four, Privet Drive, were proud to say that they were perfectly"
with model.generate(different_harry, max_new_tokens=8, pad_token_id=model.tokenizer.eos_token_id):
out = model.generator.output.save()
print(model.tokenizer.decode(out[0]))
Mr. and Mrs. Hagrid of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.
We can also remove the memorization trigger entirely if we change the “Dursley” token to a different one - even one that’s just slightly misspelled!
[ ]:
# but a simple misspelling throws the model off completely!
different_harry = "Mr. and Mrs. Dursney of number four, Privet Drive, were proud to say that they were perfectly"
with model.generate(different_harry, max_new_tokens=8, pad_token_id=model.tokenizer.eos_token_id):
out = model.generator.output.save()
print(model.tokenizer.decode(out[0]))
Mr. and Mrs. Dursney of number four, Privet Drive, were proud to say that they were perfectly happy.
"We have a
➡️ Let’s scale things up! Memorized MMLU examples#
Now that we’ve investigated memorization in a relatively small language model (2.8B parameters), let’s scale things up and analyze a much bigger model in the cloud (70B)!
It’s been hypothesized that recent LLMs have seen MMLU in their training data. Have these LLMs memorized their answers to some of the questions in MMLU? Let’s find out!
Running on NDIF
Follow these instructions to get set up with an API key for NDIF. We’ll also downgrade to nnsight v0.4 to use NDIF. Restart your session before running this code.
Device
We’ll be running everything on the NDIF server, so we recommend switching your runtime to CPU to save on compute.
[ ]:
from IPython.display import clear_output
try:
import google.colab
is_colab = True
except ImportError:
is_colab = False
if is_colab:
!pip install -U nnsight
clear_output()
[ ]:
from nnsight import CONFIG
CONFIG.API.APIKEY = input("Enter your API key: ")
clear_output()
[ ]:
from huggingface_hub import notebook_login
notebook_login()
[ ]:
import nnsight
model = nnsight.LanguageModel("meta-llama/Llama-3.3-70B-Instruct", device_map="auto")
clear_output()
Testing our model on MMLU#
Let’s look at a math question from MMLU. The answer to this question is B = 4, for reasons that the creator of this tutorial doesn’t quite understand. Nevertheless, the LLM seems to know what the right answer is!
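(For the curious: sqrt(18) = 3·sqrt(2), so Q(sqrt(2), sqrt(3), sqrt(18)) = Q(sqrt(2), sqrt(3)); adjoining sqrt(2) to Q gives a degree-2 extension and adjoining sqrt(3) gives another factor of 2, so the total degree is 2 × 2 = 4, i.e. answer B.)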
[ ]:
mmlu_question = """Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.
A. 0
B. 4
C. 2
D. 6"""
answer = " B"
# index [1] skips the BOS token that the Llama tokenizer prepends
answer_token_id = model.tokenizer(answer).input_ids[1]
print(mmlu_question)
print("Answer:", answer.strip())
Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.
A. 0
B. 4
C. 2
D. 6
Answer: B
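The question above is hard-coded, but if you’d like to try other MMLU questions, one option (a sketch; it assumes the cais/mmlu dataset on the Hugging Face Hub and its standard question/choices/answer columns) is to load them with the datasets library:
[ ]:
# optional: pull questions from MMLU instead of hard-coding one
from datasets import load_dataset
mmlu = load_dataset("cais/mmlu", "abstract_algebra", split="test")
letters = ["A", "B", "C", "D"]
row = mmlu[0]
question = row["question"] + "\n" + "\n".join(f"{l}. {c}" for l, c in zip(letters, row["choices"]))
print(question)
print("Answer:", letters[row["answer"]])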
[ ]:
import torch
prompt_template = [
{'role': 'user', 'content': mmlu_question},
{'role': 'assistant', 'content': 'Answer:'},
]
prompt = model.tokenizer.apply_chat_template(prompt_template, continue_final_message=True, tokenize=False)
with model.trace(prompt, remote=True):
logits = model.output.logits
model_answer_id = logits.argmax(dim=-1)[0, -1].save()
clear_output()
print("Model answer:", model.tokenizer.decode(model_answer_id).strip())
2025-07-10 18:25:26,094 23b7adde-c820-4695-9016-3fe9876d0974 - RECEIVED: Your job has been received and is waiting approval.
2025-07-10 18:25:26,341 23b7adde-c820-4695-9016-3fe9876d0974 - APPROVED: Your job was approved and is waiting to be run.
2025-07-10 18:25:26,636 23b7adde-c820-4695-9016-3fe9876d0974 - RUNNING: Your job has started running.
2025-07-10 18:25:27,144 23b7adde-c820-4695-9016-3fe9876d0974 - COMPLETED: Your job has been completed.
Model answer: B
Our LLM seems to do more than just remember “B” - reordering the answer choices doesn’t fool it!
[ ]:
import torch
mmlu_question_reordered = """Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.
A. 4
B. 0
C. 2
D. 6"""
prompt_template = [
{'role': 'user', 'content': mmlu_question_reordered},
{'role': 'assistant', 'content': 'The answer is letter'},
]
prompt = model.tokenizer.apply_chat_template(prompt_template, continue_final_message=True, tokenize=False)
with model.trace(prompt, remote=True):
logits = model.output.logits
model_answer_id = logits.argmax(dim=-1)[0, -1].save()
print("Model answer:", model.tokenizer.decode(model_answer_id).strip())
2025-07-10 18:26:27,256 90acf29a-256f-4796-ace2-e092d2e58585 - RECEIVED: Your job has been received and is waiting approval.
2025-07-10 18:26:27,441 90acf29a-256f-4796-ace2-e092d2e58585 - APPROVED: Your job was approved and is waiting to be run.
2025-07-10 18:26:27,717 90acf29a-256f-4796-ace2-e092d2e58585 - RUNNING: Your job has started running.
2025-07-10 18:26:28,173 90acf29a-256f-4796-ace2-e092d2e58585 - COMPLETED: Your job has been completed.
Model answer: A
Causal mediation analysis#
Let’s replicate our causal mediation analysis on our LLM’s response to the MMLU question.
[ ]:
# get only tokens inside question to make patching shorter
question_premise = "Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q"
question_premise_ids = model.tokenizer(question_premise).input_ids[1:] # ignore bos token
prompt_ids = model.tokenizer(prompt).input_ids
# slide over the full prompt to find where the question premise starts
for i in range(len(prompt_ids)):
    if prompt_ids[i:i+len(question_premise_ids)] == question_premise_ids:
        patch_start = i
        patch_end = i + len(question_premise_ids)
        break
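It’s worth double-checking that the matched indices actually cover the question premise:
[ ]:
# sanity check: decode the span we're about to patch
print("patching token positions", patch_start, "to", patch_end)
print(repr(model.tokenizer.decode(prompt_ids[patch_start:patch_end])))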
[ ]:
random = " ".join(["!"] * (patch_end - patch_start + 1))
random_activations = []
with model.trace(random, remote=True) as trace:
for layer in model.model.layers:
random_activations.append(layer.output[0].save())
random_activations[0].shape
2025-07-10 18:26:54,207 1840cc1a-be93-43b4-9359-444c7a53ba5a - RECEIVED: Your job has been received and is waiting approval.
2025-07-10 18:26:54,520 1840cc1a-be93-43b4-9359-444c7a53ba5a - APPROVED: Your job was approved and is waiting to be run.
2025-07-10 18:26:54,748 1840cc1a-be93-43b4-9359-444c7a53ba5a - RUNNING: Your job has started running.
2025-07-10 18:26:56,366 1840cc1a-be93-43b4-9359-444c7a53ba5a - COMPLETED: Your job has been completed.
torch.Size([1, 25, 8192])
[ ]:
patching_results = []
token_indices = list(range(patch_start, patch_end)) + [-1] # intervene on final token
# iterate through layers (every other layer to save time)
with model.session(remote=True) as session:
for layer_idx in range(0, model.config.num_hidden_layers, 2):
patching_results_per_layer = []
# iterate through all tokens
for i, token_idx in enumerate(token_indices):
with model.trace(prompt) as tracer:
# apply the patch from "random" hidden states to current base run
model.model.layers[layer_idx].output[0][:, token_idx, :] = \
random_activations[layer_idx][:, i, :]
# get logits
patched_logits = model.output.logits[0]
patched_probs = torch.softmax(patched_logits, dim=-1)
# save only logit over answer token id
patching_results_per_layer.append(patched_probs[-1, answer_token_id].item().save())
patching_results.append(patching_results_per_layer)
2025-07-10 18:41:43,125 856f2794-eb3b-4b30-b3d2-dc59df682635 - RECEIVED: Your job has been received and is waiting approval.
2025-07-10 18:41:44,488 856f2794-eb3b-4b30-b3d2-dc59df682635 - APPROVED: Your job was approved and is waiting to be run.
2025-07-10 18:41:47,580 856f2794-eb3b-4b30-b3d2-dc59df682635 - RUNNING: Your job has started running.
2025-07-10 18:46:46,564 856f2794-eb3b-4b30-b3d2-dc59df682635 - COMPLETED: Your job has been completed.
[ ]:
import plotly.express as px
tokens = [model.tokenizer.decode(prompt_ids[i]) for i in token_indices]
tokens[-1] = "Answer:"
# pad repeated tokens with spaces so each x-axis label is unique
deduplicate_tokens = [t + " " * i for i, t in enumerate(tokens)]
fig = px.imshow(
[[p.value for p in per_layer_results] for per_layer_results in patching_results],
color_continuous_scale="BuPu_r",
labels={"x": "Position", "y": "Layer","color":"Counterfactual logit"},
x=deduplicate_tokens
)
fig.update_layout(
yaxis=dict(autorange="min")
)
The plot isn’t as clean as the previous analysis, but we can still draw some takeaways:
It looks like every token in the prompt has an effect on the model’s answer, even into the later layers! This suggests that the model is closely attending to the whole prompt, and even slight changes will change its answer.
Somewhat surprisingly, number tokens such as 3 do not have as strong a causal effect. In a sense, the model is paying more attention to the format of the question than to its actual content!
Exposing memorization of MMLU#
Let’s apply our findings from the causal mediation analysis to rigorously test our LLM’s understanding of field extensions (or whatever it is this MMLU question is asking about).
As we saw, the number tokens don’t have a particularly large causal effect on the model’s final answer. What if we changed them around? For reasons unknown to the author of the tutorial but known to ChatGPT, replacing sqrt(2) with a second sqrt(3) should change our answer from 4 to 2. Is our LLM aware of this?
[ ]:
# per ChatGPT, the new answer should be C (2), since sqrt(3) is repeated!
mmlu_question_different_numbers = """Find the degree for the given field extension Q(sqrt(3), sqrt(3), sqrt(18)) over Q.
A. 0
B. 4
C. 2
D. 6"""
prompt_template = [
{'role': 'user', 'content': mmlu_question_different_numbers},
{'role': 'assistant', 'content': 'Answer:'},
]
prompt = model.tokenizer.apply_chat_template(prompt_template, continue_final_message=True, tokenize=False)
with model.trace(prompt, remote=True):
logits = model.output.logits
model_answer_id = logits.argmax(dim=-1)[0, -1].save()
clear_output()
print("Model answer:", model.tokenizer.decode(model_answer_id).strip())
2025-07-10 18:55:04,587 5cd091cc-df66-4ceb-ba25-bec33e406bf6 - RECEIVED: Your job has been received and is waiting approval.
2025-07-10 18:55:04,757 5cd091cc-df66-4ceb-ba25-bec33e406bf6 - APPROVED: Your job was approved and is waiting to be run.
2025-07-10 18:55:05,019 5cd091cc-df66-4ceb-ba25-bec33e406bf6 - RUNNING: Your job has started running.
2025-07-10 18:55:05,493 5cd091cc-df66-4ceb-ba25-bec33e406bf6 - COMPLETED: Your job has been completed.
Model answer: B
We can get even more creative! What if we change our minds about what we’re asking?
[ ]:
# you might have an easier time answering this question than our LLM!
completely_different_question = """Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q. Actually, just give me the answer to 1 + 1.
A. 0
B. 4
C. 2
D. 6"""
prompt_template = [
{'role': 'user', 'content': completely_different_question},
{'role': 'assistant', 'content': 'Answer:'},
]
prompt = model.tokenizer.apply_chat_template(prompt_template, continue_final_message=True, tokenize=False)
with model.trace(prompt, remote=True):
logits = model.output.logits
model_answer_id = logits.argmax(dim=-1)[0, -1].save()
clear_output()
print("Model answer:", model.tokenizer.decode(model_answer_id).strip())
2025-07-10 18:50:19,053 79d61f6b-6f64-425a-a928-1a8dbeb167c0 - RECEIVED: Your job has been received and is waiting approval.
2025-07-10 18:50:19,224 79d61f6b-6f64-425a-a928-1a8dbeb167c0 - APPROVED: Your job was approved and is waiting to be run.
2025-07-10 18:50:19,490 79d61f6b-6f64-425a-a928-1a8dbeb167c0 - RUNNING: Your job has started running.
2025-07-10 18:50:19,960 79d61f6b-6f64-425a-a928-1a8dbeb167c0 - COMPLETED: Your job has been completed.
Model answer: B