Gradients
There are a couple of ways we can interact with the gradients during and after a backward pass.
In the following example, we save the hidden states of the last layer and do a backward pass on the sum of the logits.
Note two things:

1. Tensors within the tracing context have requires_grad=True by default.
2. We can call .backward() on a value within the tracing context just like you normally would.
[1]:
from nnsight import LanguageModel
import torch
model = LanguageModel("openai-community/gpt2", device_map="auto")
[2]:
with model.trace("Hello World") as tracer:
hidden_states = model.transformer.h[-1].output[0].save()
logits = model.output.logits
logits.sum().backward()
print(hidden_states)
tensor([[[ 0.5216, -1.1755, -0.4617, ..., -1.1919, 0.0204, -2.0075],
[ 0.9841, 2.2175, 3.5851, ..., 0.5212, -2.2286, 5.7334]]],
device='mps:0', grad_fn=<AddBackward0>)
If we want to see the gradients for hidden_states, we can call .retain_grad() on it and access its .grad attribute after execution.
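For reference, here is a minimal sketch of that .retain_grad() route. It is hedged: it assumes method calls such as .retain_grad() on a proxy are traced and forwarded to the underlying tensor at runtime.

with model.trace("Hello World") as tracer:

    hidden_states = model.transformer.h[-1].output[0].save()

    # Assumption: .retain_grad() is forwarded to the underlying tensor,
    # so PyTorch keeps a .grad on this non-leaf tensor after backward.
    hidden_states.retain_grad()

    model.output.logits.sum().backward()

# After execution, the saved tensor should carry its gradient.
print(hidden_states.grad)

The next cell accesses the same gradients through nnsight's .grad proxy instead.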
[3]:
with model.trace("Hello World") as tracer:
hidden_states = model.transformer.h[-1].output[0].save()
hidden_states_grad = model.transformer.h[-1].output[0].grad.save()
model.output.logits.sum().backward()
print(hidden_states)
print(hidden_states_grad)
tensor([[[ 0.5216, -1.1755, -0.4617, ..., -1.1919, 0.0204, -2.0075],
[ 0.9841, 2.2175, 3.5851, ..., 0.5212, -2.2286, 5.7334]]],
device='mps:0', grad_fn=<AddBackward0>)
tensor([[[ 28.7976, -282.5975, 868.7336, ..., 120.1741, 52.2263,
168.6446],
[ 79.4183, -253.6228, 1322.1298, ..., 208.3982, -19.5544,
509.9858]]], device='mps:0')
Even better, nnsight also provides proxy access into the backward process via the .grad attribute on proxies. This works just like .input and .output: operations on it, including getting and setting, are traced and performed on the model at runtime (assuming it's a proxy of a Tensor, as this calls .register_hook(...) on it!).
The following examples demonstrate ablating (setting to zero) the gradients for a hidden state in GPT-2. The first example is an in-place operation and the second swaps the gradient out for a new tensor of zeroes.
[4]:
with model.trace("Hello World") as tracer:
hidden_states = model.transformer.h[-1].output[0].save()
hidden_states_grad_before = hidden_states.grad.clone().save()
hidden_states.grad[:] = 0
hidden_states_grad_after = hidden_states.grad.save()
logits = model.output.logits
logits.sum().backward()
print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)
Before tensor([[[ 28.7976, -282.5975, 868.7336, ..., 120.1741, 52.2263,
168.6446],
[ 79.4183, -253.6228, 1322.1298, ..., 208.3982, -19.5544,
509.9858]]], device='mps:0')
After tensor([[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]], device='mps:0')
[5]:
with model.trace("Hello World") as tracer:
hidden_states = model.transformer.h[-1].output[0].save()
hidden_states_grad_before = hidden_states.grad.clone().save()
hidden_states.grad = torch.zeros(hidden_states.grad.shape).to(hidden_states.grad.device)
hidden_states_grad_after = hidden_states.grad.save()
logits = model.output.logits
logits.sum().backward()
print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)
Before tensor([[[ 28.7976, -282.5975, 868.7336, ..., 120.1741, 52.2263,
168.6446],
[ 79.4183, -253.6228, 1322.1298, ..., 208.3982, -19.5544,
509.9858]]], device='mps:0')
After tensor([[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]], device='mps:0')
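Because getting and setting .grad are both traced, the same pattern extends beyond ablation. As a rough sketch under the same assumptions as above (not part of the original examples), any tensor expression can be assigned to the gradient, for example scaling it, and the modified value is what continues to flow back through earlier layers:

with model.trace("Hello World") as tracer:

    hidden_states = model.transformer.h[-1].output[0].save()

    # Sketch: double the incoming gradient instead of zeroing it.
    hidden_states.grad = hidden_states.grad * 2

    hidden_states_grad = hidden_states.grad.save()

    model.output.logits.sum().backward()

print(hidden_states_grad)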