Gradients#
Summary#
We can access and intervene on gradients in NNsight using the .backward() context.
with model.trace(input):
    # access module inputs/outputs during the forward pass
    hidden_states = model.layer[k].output[0].save()
    logits = model.output.logits.save()

    # enter the .backward() context
    with logits.sum().backward():
        # you can access gradients within the backward context
        grad = logits.grad.save()

    # backward operations must be defined in order of backpropagation
    # (i.e. opposite to the model's execution order)
    with hidden_states.sum().backward():
        # you can intervene on gradients within the backward context
        hidden_states.grad[:] = 0  # in-place setting operation
When to Use#
Gradients are used in input attribution methods like integrated gradients and also in causal tracing methods like attribution patching.
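As a quick illustration, here is a minimal sketch of a first-order gradient-times-activation attribution score, the kind of quantity these methods build on. The module path model.transformer.h[-1] and the logits.sum() metric simply follow the GPT-2 examples later on this page; a real attribution setup would use a task-specific metric and baseline.

with model.trace(prompt):
    # capture a hidden state and the logits during the forward pass
    hidden_states = model.transformer.h[-1].output[0].save()
    logits = model.output.logits

    # backpropagate a scalar metric (here simply the summed logits)
    with logits.sum().backward():
        grad = hidden_states.grad.save()

# elementwise gradient x activation, summed over the hidden dimension
attribution = (hidden_states.detach() * grad).sum(dim=-1)
print(attribution)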
How to Use#
To access and alter gradients in nnsight, you work inside a backward context. The backward context re-implements the tracing functionality established by the core library, with .grad as the entrypoint (instead of .inputs/.output in forward-pass tracing) for intervening on the gradient values of intermediate computations fetched in the tracing context. It also enforces an order of module access that is the reverse of the model's execution order, intuitively following the backpropagation order.
To create a .backward() context, use a with block. Inside it, you can get and set gradient values just as you would use a module's .output and .input during the forward pass.
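For instance, if you capture activations from more than one layer, their gradients must be read deepest-first inside the backward context. A minimal sketch, assuming a GPT-2-style model with blocks at model.transformer.h[i] as in the examples below (the layer indices are arbitrary):

with model.trace(prompt):
    early = model.transformer.h[2].output[0].save()
    late = model.transformer.h[10].output[0].save()
    logits = model.output.logits

    with logits.sum().backward():
        # backpropagation reaches the later block first...
        late_grad = late.grad.save()
        # ...and the earlier block afterwards
        early_grad = early.grad.save()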
Note: Syntax for gradients prior to v0.5
Before NNsight 0.5, you would access gradients by calling .grad on a value and then calling .backward() to initiate the backward pass. This syntax is deprecated in NNsight 0.5 and later versions.
with model.trace("Hello World") as tracer:
hidden_states = model.transformer.h[-1].output[0].save()
hidden_states_grad = model.transformer.h[-1].output[0].grad.save()
model.output.logits.sum().backward()
print(hidden_states)
print(hidden_states_grad)
Accessing Gradients#
To access the gradient of some activation, we first capture it during the forward pass tracing. In this example we get the model’s logit outputs.
During the backward pass, we access the gradients of our logits variable and save them with .save(), just as we would with forward-pass variables.
[4]:
with model.trace(prompt):
    logits = model.output.logits

    # .backward() context
    with logits.sum().backward():
        # access .grad within the backward context
        grad = logits.grad.save()

print(grad)
tensor([[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]]], device='cuda:0')
You can also backpropagate without a backward tracing context by calling .backward() directly, without a with block. This can make your code cleaner if you only need to access gradients without intervening on them: calling .retain_grad() on the value of interest yields the same result as the example above.
[5]:
with model.trace(prompt):
    hidden_states = model.transformer.h[-1].output[0].save()
    hidden_states.retain_grad()

    logits = model.output.logits.save()
    logits.sum().backward()

print(hidden_states.grad)
tensor([[[ 28.7977, -282.5986, 868.7366, ..., 120.1745, 52.2265,
168.6451],
[ 79.4183, -253.6228, 1322.1305, ..., 208.3983, -19.5544,
509.9862]]], device='cuda:0')
You can also backpropagate multiple times during the same tracing context. This requires passing retain_graph=True to every .backward() call except the last, so that the computation graph is kept between passes. Note that gradients accumulate in .grad across successive backward calls, as the output below shows.
[15]:
with model.trace(prompt):
    hidden_states = model.transformer.h[-1].output[0].save()
    hidden_states.retain_grad()

    logits = model.output.logits

    logits.sum().backward(retain_graph=True)
    print(hidden_states.grad)

    logits.sum().backward(retain_graph=True)
    print(hidden_states.grad)

    logits.sum().backward()
    print(hidden_states.grad)
tensor([[[ 28.7977, -282.5986, 868.7366, ..., 120.1745, 52.2265,
168.6451],
[ 79.4183, -253.6228, 1322.1305, ..., 208.3983, -19.5544,
509.9862]]], device='cuda:0')
tensor([[[ 57.5953, -565.1971, 1737.4731, ..., 240.3490, 104.4530,
337.2901],
[ 158.8366, -507.2456, 2644.2610, ..., 416.7967, -39.1088,
1019.9724]]], device='cuda:0')
tensor([[[ 86.3930, -847.7957, 2606.2097, ..., 360.5235, 156.6795,
505.9352],
[ 238.2549, -760.8683, 3966.3916, ..., 625.1951, -58.6633,
1529.9586]]], device='cuda:0')
Intervening on Gradients#
The following examples demonstrate ablating (setting to zero) the gradients for a hidden state in GPT-2. The first example uses an in-place operation; the second, sketched after the output below, swaps the gradient out for a new tensor of zeros.
[14]:
with model.trace(prompt) as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()
    logits = model.output.logits

    with logits.sum().backward():
        logits_grad = logits.grad.save()  # <- reverse order
        hidden_states_grad_before = hidden_states.grad.clone().save()
        hidden_states.grad[:] = 0
        hidden_states_grad_after = hidden_states.grad.save()

print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)
Before tensor([[[ 28.7977, -282.5986, 868.7366, ..., 120.1745, 52.2265,
168.6451],
[ 79.4183, -253.6228, 1322.1305, ..., 208.3983, -19.5544,
509.9862]]], device='cuda:0')
After tensor([[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]], device='cuda:0')
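The second variant replaces the gradient with a brand-new tensor of zeros instead of zeroing it in place. A sketch along the lines of the example above, assuming .grad supports direct assignment the same way .output does (torch.zeros_like is used here to build the replacement tensor):

import torch

with model.trace(prompt):
    hidden_states = model.transformer.h[-1].output[0].save()
    logits = model.output.logits

    with logits.sum().backward():
        hidden_states_grad_before = hidden_states.grad.clone().save()
        # swap the gradient out for a new tensor of zeros
        hidden_states.grad = torch.zeros_like(hidden_states.grad)
        hidden_states_grad_after = hidden_states.grad.save()

print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)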
Notes#
- Gradients are disabled by default on remote models; pass requires_grad=True during tracing to enable them.
- You need to call .grad on intermediate values accessed in the forward pass.
Related#
Attribution Patching
Mini-Paper: Do Language Models Use Their Depth Efficiently?