Multiple Token Generation¶
nnsight supports multi-token generation using .generate(). You can intervene on specific generation steps, iterate over all steps, or apply interventions globally.
Setup¶
from nnsight import LanguageModel
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
Basic Generation¶
Use .generate() instead of .trace() to run multi-token generation. Access the generated tokens via model.generator.output.
with model.generate("The Eiffel Tower is in the city of", max_new_tokens=5) as tracer:
output = model.generator.output.save()
print(model.tokenizer.decode(output[0]))
The Eiffel Tower is in the city of Paris, and the E
Iterating Over Generation Steps¶
Use for step in tracer.iter[:] to loop over every generation step. This is the preferred way to intervene on or collect values from individual steps.
with model.generate("The Eiffel Tower is in the city of", max_new_tokens=5) as tracer:
tokens = list().save()
for step in tracer.iter[:]:
token = model.lm_head.output[0, -1].argmax(dim=-1)
tokens.append(token)
for i, t in enumerate(tokens):
print(f"Step {i}: {model.tokenizer.decode(t)}")
Step 0: Paris Step 1: , Step 2: and Step 3: the Step 4: E
Bounded Iteration¶
Use a slice to iterate over specific steps only. tracer.iter[1:3] runs for steps 1 and 2.
with model.generate("The Eiffel Tower is in the city of", max_new_tokens=5) as tracer:
tokens = list().save()
for step in tracer.iter[1:3]:
tokens.append(model.lm_head.output[0, -1].argmax(dim=-1))
output = model.generator.output.save()
print(f"Full output: {model.tokenizer.decode(output[0])}")
print(f"Collected {len(tokens)} tokens from steps 1-2")
Full output: The Eiffel Tower is in the city of Paris, and the E Collected 2 tokens from steps 1-2
You can also use a single index (tracer.iter[0]) or a list of indices (tracer.iter[[0, 2, 4]]).
Conditional Interventions Per Step¶
The step variable gives you the current step index, letting you apply different interventions at different points in generation.
with model.generate("The Eiffel Tower is in the city of", max_new_tokens=5) as tracer:
for step in tracer.iter[:]:
if step == 0:
# Only intervene on the first generation step
model.transformer.h[0].output[0][:] = 0
output = model.generator.output.save()
print(model.tokenizer.decode(output[0]))
The Eiffel Tower is in the city of, but I'm not
Code after unbounded iteration never runs
When using tracer.iter[:] (unbounded), the iterator doesn't know when generation will end, so it waits forever for the next step. All code after the loop is skipped. Use bounded iteration (tracer.iter[:3]), or put post-loop code in a separate invoker.
Manual Stepping with tracer.next()¶
For explicit control over which step you're on, use tracer.next() to advance to the next generation step.
with model.generate("The Eiffel Tower is in the city of", max_new_tokens=3) as tracer:
# First step (default)
hs0 = model.transformer.h[-1].output[0].save()
# Advance to second step
tracer.next()
hs1 = model.transformer.h[-1].output[0].save()
# Advance to third step
tracer.next()
hs2 = model.transformer.h[-1].output[0].save()
print(f"Step 0 shape: {hs0.shape}")
print(f"Step 1 shape: {hs1.shape}")
print(f"Step 2 shape: {hs2.shape}")
Step 0 shape: torch.Size([1, 10, 768]) Step 1 shape: torch.Size([1, 1, 768]) Step 2 shape: torch.Size([1, 1, 768])
Why step 0 is larger
Step 0 processes the full prompt (10 tokens), so its hidden state shape is [1, 10, 768]. Subsequent steps process one new token at a time, giving shape [1, 1, 768].
Applying Interventions to All Steps with tracer.all()¶
tracer.all() applies the same intervention to every generation step, including recursively into any nested invokes.
with model.generate("The Eiffel Tower is in the city of", max_new_tokens=5) as tracer:
hidden_states = list().save()
for step in tracer.all():
# This runs on every step
model.transformer.h[0].output[0][:] = 0
hidden_states.append(model.transformer.h[-1].output[0])
output = model.generator.output.save()
print(f"Collected {len(hidden_states)} hidden states")
print(f"Output: {model.tokenizer.decode(output[0])}")
Collected 5 hidden states Output: The Eiffel Tower is in the city of,,,,,