Getting

Summary

NNsight gets values from model internals by accessing module inputs or outputs and assigning them to a variable.

You can get a model's intermediate values inside the tracing context by accessing .output or .input on the module of interest and assigning the result to a variable. To use those values after exiting the tracing context, call .save() on them.

# capturing interventions
with model.trace(input): # enter the tracing context
    k_layer_output = model.layer[k].output[0].save() # access the output of layer[k] and save for later use

When to Use

Getting values is one of the most fundamental operations in NNsight. Use it whenever you want to access a hidden state from a model’s forward pass, such as a layer's activations.

How to Use

Hidden states are exposed in NNsight through each module's input and output attributes.

Access them as .input or .output on the module. The plural forms, .inputs and .outputs, are also available and return more detailed information.

Note: .input/.output vs .inputs/.outputs

The .inputs and .outputs attributes each return a tuple of the form ((args), {kwargs}): a tuple of the positional arguments followed by a dictionary of the keyword arguments.

The .input and .output attributes, on the other hand, return only the first argument, which is usually positional. If there are only keyword arguments, .input returns the first keyword argument. So, when positional arguments are present, .input is equivalent to .inputs[0][0].
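The rule in this note can be illustrated without a model. The sketch below is plain Python, not NNsight code: first_argument is a hypothetical helper that applies the rule to a value shaped like ((args), {kwargs}). It assumes that "the first keyword" means the value of the first keyword argument in insertion order, and the string values stand in for tensors.

```python
from typing import Any, Dict, Tuple

# Shape described in the note for .inputs: (positional args, keyword args).
Inputs = Tuple[Tuple[Any, ...], Dict[str, Any]]


def first_argument(inputs: Inputs) -> Any:
    """Hypothetical helper (not part of NNsight): pick the first argument.

    Positional arguments take priority. If there are none, fall back to the
    value of the first keyword argument, in insertion order (an assumption).
    """
    args, kwargs = inputs
    if args:
        return args[0]
    if kwargs:
        return next(iter(kwargs.values()))
    raise ValueError("no arguments were passed")


# A call such as module(hidden, mask=m) would be recorded as:
positional_call: Inputs = (("hidden",), {"mask": "m"})
# A call such as module(hidden_states=hidden, mask=m) would be recorded as:
keyword_only_call: Inputs = ((), {"hidden_states": "hidden", "mask": "m"})

print(first_argument(positional_call))    # same element as positional_call[0][0]
print(first_argument(keyword_only_call))  # value of the first keyword argument
```

In the positional case the helper returns the same element as indexing [0][0], which is the equivalence the note describes. The keyword-only case shows why that indexing alone would not be enough.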

Values defined within the tracing context are automatically deleted upon exiting the context, so you need to save variables that you plan to use later with .save(). Calling .save() on a variable ensures its value persists after exiting the tracing context.

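with model.trace("The Eiffel Tower is in the city of") as tracer:
    # Take the first element of the last block's output (the hidden states) and save it for use after the trace
    hidden_states = model.transformer.h[-1].output[0].save()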

After exiting the tracing context, the .value attribute of the hidden_states object will be populated.

print(hidden_states)
tensor([[[ 0.0505, -0.1728, -0.1690,  ..., -1.0096,  0.1280, -1.0687],
         [ 8.7495,  2.9057,  5.3024,  ..., -8.0418,  1.2964, -2.8677],
         [ 0.2960,  4.6686, -3.6642,  ...,  0.2391, -2.6064,  3.2263],
         ...,
         [ 2.1537,  6.8917,  3.8651,  ...,  0.0588, -1.9866,  5.9188],
         [-0.4460,  7.4285, -9.3065,  ...,  2.0528, -2.7946,  0.5556],
         [ 6.6286,  1.7258,  4.7969,  ...,  7.6714,  3.0682,  2.0481]]],
       device='mps:0', grad_fn=<AddBackward0>)