Modules

Summary

Inside a trace, you can also call the model's modules as functions. These calls do not need to follow the model's execution order.

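with model.trace(prompt):
    # l: index of the layer whose hidden states you want to decode;
    # element 0 of the block output is the hidden-state tensor
    hs = model.transformer.h[l].output[0]

    logits = model.lm_head.output.save()

    # call the final layer norm module as a function on the layer-l hidden states
    out = model.transformer.ln_f(hs).save()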
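nnsight handles this inside a trace. As a rough plain-PyTorch analogy (not how nnsight is implemented), you can capture an intermediate output with a forward hook and then call later modules on it directly, outside the model's normal order. The toy model, its sizes, and the names layers, ln_f, lm_head and save_output below are hypothetical and only mirror the idea.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: two layers, a final norm, and an output head.
d_model, vocab_size = 8, 20
layers = nn.Sequential(nn.Linear(d_model, d_model), nn.Linear(d_model, d_model))
ln_f = nn.LayerNorm(d_model)
lm_head = nn.Linear(d_model, vocab_size)

captured = {}


def save_output(module, args, output):
    # Forward hook: keep a detached copy of the first layer's output.
    captured["layer0"] = output.detach()


handle = layers[0].register_forward_hook(save_output)
with torch.no_grad():
    x = torch.randn(1, 3, d_model)
    # Normal order: layer 0 -> layer 1 -> ln_f -> lm_head.
    final_logits = lm_head(ln_f(layers(x)))
    # Out of order: feed layer 0's output straight into ln_f and lm_head,
    # skipping layer 1, simply by calling the modules as functions.
    early_logits = lm_head(ln_f(captured["layer0"]))
handle.remove()

print(final_logits.shape, early_logits.shape)
```

The hook is removed afterwards so later forward passes are not affected.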

When to Use

One use is the logit lens: decoding the hidden states of intermediate layers into vocabulary space by passing them through the model's final layer norm and output head.
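The logit lens idea can be sketched in plain PyTorch on a hypothetical toy decoder (random weights, GPT-2-like layout of blocks, then ln_f, then lm_head). This is an illustration of the technique under those assumptions, not nnsight code: after each block, the intermediate hidden state is decoded with the same ln_f and lm_head the model uses at the end.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy decoder: blocks -> ln_f -> lm_head.
# All sizes and the block definition are illustrative only.
vocab_size, d_model, n_layers, seq_len = 50, 16, 4, 6

embed = nn.Embedding(vocab_size, d_model)
blocks = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])
ln_f = nn.LayerNorm(d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)

input_ids = torch.randint(0, vocab_size, (1, seq_len))

lens_tokens = []  # one [batch, seq] tensor of predicted token IDs per block
with torch.no_grad():
    h = embed(input_ids)              # [batch, seq, d_model]
    for block in blocks:
        h = h + torch.tanh(block(h))  # toy residual update
        # Logit lens: decode the intermediate hidden state with the
        # model's own final norm and output head.
        logits = lm_head(ln_f(h))     # [batch, seq, vocab]
        lens_tokens.append(logits.argmax(dim=-1))

for i, toks in enumerate(lens_tokens):
    print(f"after block {i}: {toks[0].tolist()}")
```

For the last block, the lens output coincides with the model's actual prediction; earlier entries show what each intermediate layer would predict if decoded directly.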

How to Use

Beyond accessing and intervening on module inputs and outputs, you can also call a module yourself, out of order, much like applying your own function to values from the trace.

Let’s try it out. Here, we take the hidden states of the last layer and apply model.transformer.ln_f and then model.lm_head to “decode” them into vocabulary space. Applying softmax over the vocabulary dimension and then argmax gives the most likely token at each position, which we can decode with the tokenizer.

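import torch

# assumes model (an nnsight LanguageModel wrapping a GPT-2-style model) and prompt are defined earlier
with model.trace(prompt) as tracer:

    # output of the last transformer block; element 0 is the hidden-state tensor
    hidden_states = model.transformer.h[-1].output[0]

    # apply the ln_f module and then the lm_head module to hidden_states
    hidden_states = model.lm_head(model.transformer.ln_f(hidden_states)).save()

    # softmax over the vocabulary dimension, then the most likely token ID at each position
    tokens = torch.softmax(hidden_states, dim=2).argmax(dim=2).save()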

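After the trace exits, print the saved values: the logits (saved as hidden_states), the predicted token IDs, and the decoded text.

print(hidden_states)
print(tokens)
print(model.tokenizer.decode(tokens[0]))

Output: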
tensor([[[ -36.2874,  -35.0114,  -38.0794,  ...,  -40.5164,  -41.3760,
           -34.9194],
         [ -68.8886,  -70.1562,  -71.8408,  ...,  -80.4195,  -78.2553,
           -71.1206],
         [ -82.2950,  -81.6519,  -83.9940,  ...,  -94.4878,  -94.5194,
           -85.6998],
         ...,
         [-111.7770, -110.1240, -111.3596,  ..., -114.8074, -113.9788,
          -110.5932],
         [ -76.1197,  -77.0599,  -83.5494,  ...,  -85.6277,  -83.4789,
           -79.1687],
         [-106.5131, -105.2869, -108.0233,  ..., -113.4612, -113.5024,
          -105.6071]]], grad_fn=<UnsafeViewBackward0>)
tensor([[ 198,   12,  417, 8765,  318,  257,  287,  262, 2612,  286, 6342]])

-el Tower is a in the heart of Paris
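Softmax is monotonic, so taking the argmax after softmax selects the same token IDs as taking the argmax of the raw logits; the softmax step matters only if you want the probabilities themselves. A quick plain-PyTorch check on random stand-in logits (illustrative only, not the model's actual output):

```python
import torch

torch.manual_seed(0)

# Random stand-in for lm_head output, shape [batch, seq, vocab].
logits = torch.randn(1, 5, 10)

probs = torch.softmax(logits, dim=2)      # probabilities over the vocabulary axis
tokens_via_softmax = probs.argmax(dim=2)  # [batch, seq]
tokens_direct = logits.argmax(dim=2)      # argmax of the raw logits

print(tokens_via_softmax)
print(torch.equal(tokens_via_softmax, tokens_direct))
```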