Modules#
Summary#
Modules of the traced model can also be called directly as functions during tracing. These calls do not need to follow the model’s normal execution order.
with model.trace(prompt):
    # hidden states of transformer block l (l is an integer layer index)
    hs = model.transformer.h[l].output[0]
    logits = model.lm_head.output.save()
    # call a module out of order on the hidden states
    out = model.transformer.ln_f(hs).save()
When to Use#
This functionality can be used to perform the logit lens: decoding the hidden states of intermediate layers directly into vocabulary space.
How to Use#
Beyond accessing and intervening on modules, you can also manually call a module out of order, much like applying your own function to intermediate values.
Let’s try it out. Here, we will get the hidden states of the last layer. We can then apply model.transformer.ln_f and model.lm_head in sequence to “decode” the hidden states into vocabulary space. Applying softmax and then argmax transforms the vocabulary-space hidden states into token ids that we can decode with the tokenizer.
[ ]:
import torch

with model.trace(prompt) as tracer:
    # hidden states of the last transformer layer: (batch, seq, hidden)
    hidden_states = model.transformer.h[-1].output[0]
    # apply the ln_f module and then the lm_head module to decode into vocabulary space
    hidden_states = model.lm_head(model.transformer.ln_f(hidden_states)).save()
    # softmax over the vocabulary, then argmax to get the predicted token id per position
    tokens = torch.softmax(hidden_states, dim=2).argmax(dim=2).save()
The output looks like:
[ ]:
print(hidden_states)
print(tokens)
print(model.tokenizer.decode(tokens[0]))
tensor([[[ -36.2874, -35.0114, -38.0794, ..., -40.5164, -41.3760,
-34.9194],
[ -68.8886, -70.1562, -71.8408, ..., -80.4195, -78.2553,
-71.1206],
[ -82.2950, -81.6519, -83.9940, ..., -94.4878, -94.5194,
-85.6998],
...,
[-111.7770, -110.1240, -111.3596, ..., -114.8074, -113.9788,
-110.5932],
[ -76.1197, -77.0599, -83.5494, ..., -85.6277, -83.4789,
-79.1687],
[-106.5131, -105.2869, -108.0233, ..., -113.4612, -113.5024,
-105.6071]]], grad_fn=<UnsafeViewBackward0>)
tensor([[ 198, 12, 417, 8765, 318, 257, 287, 262, 2612, 286, 6342]])
-el Tower is a in the heart of Paris
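The same pattern extends beyond the final layer. Below is a minimal sketch of a full logit lens sweep (an assumption-based extension of the example above, not part of the original walkthrough), decoding each layer’s hidden state at the last token position and printing the token it predicts:

import torch

layer_probs = []
with model.trace(prompt):
    for layer in model.transformer.h:
        # decode this layer's hidden states into vocabulary space
        layer_logits = model.lm_head(model.transformer.ln_f(layer.output[0]))
        # keep the probability distribution at the last token position
        layer_probs.append(torch.softmax(layer_logits, dim=-1)[0, -1].save())

# after the trace, each saved proxy holds a concrete tensor
for idx, probs in enumerate(layer_probs):
    print(idx, model.tokenizer.decode(probs.argmax().item()))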