Multiple Token Generation#
[8]:
from nnsight import LanguageModel
model = LanguageModel('openai-community/gpt2', device_map='auto')
.generate()#
NNsight’s LanguageModel class supports multiple token generation with .generate(). You can control the number of new tokens generated by setting max_new_tokens = N within your call to .generate().
[9]:
prompt = 'The Eiffel Tower is in the city of'
n_new_tokens = 3
with model.generate(prompt, max_new_tokens=n_new_tokens) as tracer:
    out = model.generator.output.save()

# Everything before the last n_new_tokens ids is the prompt; the rest is the generation
decoded_prompt = model.tokenizer.decode(out[0][0:-n_new_tokens].cpu())
decoded_answer = model.tokenizer.decode(out[0][-n_new_tokens:].cpu())

print("Prompt: ", decoded_prompt)
print("Generated Answer: ", decoded_answer)
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Prompt: The Eiffel Tower is in the city of
Generated Answer: Paris, and
.next()#
When generating more than one token, use <module>.next() to denote that the following interventions should be applied to the subsequent token generation for that module. Here we generate three tokens and save the hidden states of the last layer for each one:
[10]:
n_new_tokens = 3
with model.generate('The Eiffel Tower is in the city of', max_new_tokens=n_new_tokens) as tracer:
    hidden_states1 = model.transformer.h[-1].output[0].save()         # 1st generation step (prompt pass)
    hidden_states2 = model.transformer.h[-1].next().output[0].save()  # 2nd generation step
    hidden_states3 = model.transformer.h[-1].next().output[0].save()  # 3rd generation step
    out = model.generator.output.save()
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Note how the hidden state saved before calling .next() spans the initial prompt, while each hidden state saved after a .next() call covers only the corresponding newly generated token.
[11]:
print(hidden_states1.shape) # hidden states across prompt & first generated token
print(hidden_states2.shape) # only hidden states across next token
print(hidden_states3.shape) # only hidden states across next token
print(out) # model output tokens, including prompt
torch.Size([1, 10, 768])
torch.Size([1, 1, 768])
torch.Size([1, 1, 768])
tensor([[ 464, 412, 733, 417, 8765, 318, 287, 262, 1748, 286, 6342, 11,
290]], device='mps:0')
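If you want one tensor covering the prompt and every generated position, you can concatenate the saved states along the sequence dimension. A minimal sketch, assuming the shapes printed above (a 10-token prompt and two single-token steps):

import torch

# Stitch the per-step hidden states back together along the sequence dimension.
all_hidden_states = torch.cat([hidden_states1, hidden_states2, hidden_states3], dim=1)
print(all_hidden_states.shape)  # torch.Size([1, 12, 768]): 10 prompt positions plus the first 2 generated tokens fed back in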
Great, we’ve now successfully stored hidden states across multiple rounds of token generation! However, if you’re generating many tokens while applying interventions, using .next() requires you to write a loop within the tracing context, which can be clunky:
[12]:
# Old approach:
prompt = 'The Eiffel Tower is in the city of'
layers = model.transformer.h
n_new_tokens = 50
hidden_states = []
with model.generate(prompt, max_new_tokens=n_new_tokens) as tracer:
    for i in range(n_new_tokens):
        # Apply intervention - set first layer output to zero
        layers[0].output[0][:] = 0

        # Append desired hidden state post-intervention
        hidden_states.append(layers[-1].output.save())

        # Move to next generated token
        layers[0].next()

print("Hidden state length: ", len(hidden_states))
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Hidden state length: 50
.all() streamlines interventions on many generated tokens#
With nnsight 0.4 you can use .all() to recursively apply interventions to a model. Calling .all() on a module within a model will recursively apply its .input and .output across all iterations. Previously, we needed to loop over each newly generated token, saving the intervened value at every step and calling .next() to move forward, as demonstrated in the previous section.
Let’s try using .all() to streamline the multiple token generation process. We simply call .all() on the module where we are applying the intervention (in this case, GPT-2’s layers), apply our intervention, and append our hidden states (stored in an nnsight.list() object).
[13]:
import nnsight

# using .all():
prompt = 'The Eiffel Tower is in the city of'
layers = model.transformer.h
n_new_tokens = 50
with model.generate(prompt, max_new_tokens=n_new_tokens) as tracer:
    hidden_states = nnsight.list().save()  # Initialize & .save() nnsight list

    # Call .all() to apply intervention to each new token
    layers.all()

    # Apply intervention - set first layer output to zero
    layers[0].output[0][:] = 0

    # Append desired hidden state post-intervention
    hidden_states.append(layers[-1].output)  # no need to call .save()

    # Don't need to loop or call .next()!

print("Hidden state length: ", len(hidden_states))
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Hidden state length: 50
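Each element of hidden_states is the output of the last transformer block for one generation step; for GPT-2 that output is a tuple whose first element is the hidden-state tensor, and the first step covers the whole prompt while later steps cover one token each. A minimal sketch of collapsing the list into one per-token tensor, under those assumptions:

import torch

# Keep only the last position from each step, then concatenate:
# one post-intervention vector per generated token.
per_token_states = torch.cat([h[0][:, -1:, :] for h in hidden_states], dim=1)
print(per_token_states.shape)  # e.g. torch.Size([1, 50, 768])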
Easy! Note that because .all() is recursive, it will only capture values accessed on children of the module that .all() was called on. See the example below for more information. TL;DR: apply .all() to the highest-level accessed module if your interventions and saved outputs sit in different parts of the model hierarchy.
Recursive properties of .all()
.all() recursively acts on model components. In the code example below, only the first token generation is saved, because .all() was applied to layers, while the saved variable hidden_states is produced from model.lm_head, which is not a child of layers.
prompt = 'The Eiffel Tower is in the city of'
layers = model.transformer.h
n_new_tokens = 3
with model.generate(prompt, max_new_tokens=n_new_tokens) as tracer:
    hidden_states = nnsight.list().save()  # Initialize & .save() nnsight list

    # Call .all() on layers
    layers.all()

    # Apply same intervention - set first layer output to zero
    layers[0].output[0][:] = 0

    # Append desired hidden state post-intervention
    hidden_states.append(model.lm_head.output)  # no need to call .save, it's already initialized

print("Hidden state length: ", len(hidden_states))  # length is 1, meaning it only saved the first token generation
If you want to apply an intervention during multiple token generation while saving the state of a model component that isn’t a child of the intervened module, you can instead apply .all() to the full model:
prompt = 'The Eiffel Tower is in the city of'
layers = model.transformer.h
n_new_tokens = 3
with model.generate(prompt, max_new_tokens=n_new_tokens) as tracer:
    hidden_states = nnsight.list().save()  # Initialize & .save() nnsight list

    # Call .all() on model
    model.all()

    # Apply same intervention - set first layer output to zero
    layers[0].output[0][:] = 0

    # Append desired hidden state post-intervention
    hidden_states.append(model.lm_head.output)  # no need to call .save

print("Hidden state length: ", len(hidden_states))  # length is 3, as expected!