Getting#
Summary#
NNsight gets values from model internals by accessing module inputs or outputs and assigning them to a variable.
Within the tracing context, call .output or .input on the module of interest and assign the result to a variable. To use that value after exiting the tracing context, call .save() on it.
# capturing interventions
with model.trace(input):  # enter the tracing context
    k_layer_output = model.layer[k].output[0].save()  # access the output of layer[k] and save it for later use
When to Use#
Getting values is one of the most fundamental operations in NNsight. Use it whenever you want to read a hidden state from a model’s forward pass, such as an activation.
How to Use#
Hidden states are exposed in NNsight by accessing the desired module and reading its .input or .output attribute. The .inputs and .outputs attributes are also available and return more detailed information.
Note: .input/.output vs .inputs/.outputs
The .inputs and .outputs attributes return a tuple holding a tuple of positional arguments and a dictionary of keyword arguments: ((args), {kwargs}).
The .input and .output attributes, on the other hand, return the first argument, which is usually positional. If there are only keyword arguments, .input returns the first keyword argument. So, when at least one positional argument is present, .input is equivalent to .inputs[0][0].
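The rule in the note above can be sketched in plain Python. The helper name first_argument and the sample values are hypothetical and not part of NNsight; the sketch only assumes the ((args), {kwargs}) layout described in the note, and it reads "the first keyword" as the value of the first keyword argument.

```python
from typing import Any, Dict, Tuple


def first_argument(inputs: Tuple[Tuple[Any, ...], Dict[str, Any]]) -> Any:
    """Hypothetical helper (not an NNsight API): pick the 'first argument'
    out of an ((args), {kwargs}) pair, following the rule in the note."""
    args, kwargs = inputs
    if args:
        # At least one positional argument: same as inputs[0][0].
        return args[0]
    if kwargs:
        # Keyword-only call: dicts keep insertion order, so this is the
        # value of the first keyword argument that was passed.
        return next(iter(kwargs.values()))
    raise ValueError("the call recorded no arguments at all")


# Made-up stand-ins for what a module call might have received.
positional_call = (("hidden", "mask"), {"use_cache": True})
keyword_only_call = ((), {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]})

first_positional = first_argument(positional_call)
first_keyword = first_argument(keyword_only_call)
```

With a positional argument present, the helper returns the same object as indexing with [0][0]; in the keyword-only case it falls back to the first keyword's value.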
Values defined within the tracing context are automatically deleted upon exiting the context, so call .save() on any variable you plan to use later. This ensures its value persists after the context exits.
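from nnsight import LanguageModel
model = LanguageModel("openai-community/gpt2", device_map="auto")  # assumption: a GPT-2 model, whose blocks live under model.transformer.h
with model.trace("The Eiffel Tower is in the city of") as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()  # output of the last block, saved for use after the context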
After exiting the tracing context, the .value attribute of the hidden_states object will be populated.
print(hidden_states)
tensor([[[ 0.0505, -0.1728, -0.1690, ..., -1.0096, 0.1280, -1.0687],
[ 8.7495, 2.9057, 5.3024, ..., -8.0418, 1.2964, -2.8677],
[ 0.2960, 4.6686, -3.6642, ..., 0.2391, -2.6064, 3.2263],
...,
[ 2.1537, 6.8917, 3.8651, ..., 0.0588, -1.9866, 5.9188],
[-0.4460, 7.4285, -9.3065, ..., 2.0528, -2.7946, 0.5556],
[ 6.6286, 1.7258, 4.7969, ..., 7.6714, 3.0682, 2.0481]]],
device='mps:0', grad_fn=<AddBackward0>)
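As a mental model only, the save-or-discard behaviour described above can be pictured with a toy context manager in plain Python. ToyTrace, define and available are hypothetical names invented for this sketch; this is not how NNsight is implemented, only an illustration of values being dropped on exit unless they were marked to be kept.

```python
class ToyTrace:
    """Toy stand-in for a tracing context; NOT NNsight's implementation."""

    def __init__(self):
        self._values = {}    # every value defined inside the context
        self._saved = set()  # names explicitly marked to outlive the context

    def __enter__(self):
        return self

    def define(self, name, value, save=False):
        # Record a value; save=True plays the role of calling .save().
        self._values[name] = value
        if save:
            self._saved.add(name)

    def __exit__(self, exc_type, exc, tb):
        # On exit, drop everything that was not explicitly saved.
        self._values = {k: v for k, v in self._values.items() if k in self._saved}
        return False  # do not swallow exceptions

    def get(self, name):
        return self._values[name]

    def available(self):
        return sorted(self._values)


with ToyTrace() as trace:
    trace.define("hidden", [0.1, 0.2], save=True)  # kept after exit
    trace.define("scratch", [9.9])                 # discarded on exit

kept = trace.get("hidden")
remaining = trace.available()
```

After the with block, only the value that was marked for saving is still reachable; looking up the unsaved name raises KeyError.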
Related#
Setting