Gradients
There are a couple of ways we can interact with the gradients during and after a backward pass.
In the following example, we save the hidden states of the last layer and do a backward pass on the sum of the logits.
Note two things:

1. requires_grad=True by default.
2. We can call .backward() on a value within the tracing context just like you normally would.
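As a quick check of the first note, here is a minimal sketch; it assumes that attribute access on a proxy, such as requires_grad here, is traced and can be saved like any other proxy operation (that detail is not shown above).

[ ]:
from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="cuda")

with model.trace("Hello World"):
    # Attribute access is traced, so this records the requires_grad flag
    # of the real hidden-state tensor at runtime.
    needs_grad = model.transformer.h[-1].output[0].requires_grad.save()

print(needs_grad)  # expected: True, since gradients are enabled by default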
[1]:
from nnsight import LanguageModel
model = LanguageModel("openai-community/gpt2", device_map="cuda")
[2]:
with model.trace("Hello World") as tracer:
hidden_states = model.transformer.h[-1].output[0].save()
logits = model.output.logits
logits.sum().backward()
print(hidden_states)
tensor([[[ 0.5216, -1.1755, -0.4617, ..., -1.1919, 0.0204, -2.0075],
[ 0.9841, 2.2175, 3.5851, ..., 0.5212, -2.2286, 5.7334]]],
device='cuda:0', grad_fn=<AddBackward0>)
If we want to see the gradients for the hidden_states, we can call .retain_grad() on it and read its .grad attribute after execution, or, as in the next cell, save the .grad of the output proxy directly inside the trace. (A sketch of the .retain_grad() route follows the cell's output.)
[3]:
with model.trace("Hello World") as tracer:
hidden_states = model.transformer.h[-1].output[0].save()
hidden_states_grad = model.transformer.h[-1].output[0].grad.save()
model.output.logits.sum().backward()
print(hidden_states)
print(hidden_states_grad)
tensor([[[ 0.5216, -1.1755, -0.4617, ..., -1.1919, 0.0204, -2.0075],
[ 0.9841, 2.2175, 3.5851, ..., 0.5212, -2.2286, 5.7334]]],
device='cuda:0', grad_fn=<AddBackward0>)
tensor([[[ 28.7976, -282.5981, 868.7354, ..., 120.1743, 52.2264,
168.6449],
[ 79.4181, -253.6225, 1322.1293, ..., 208.3980, -19.5545,
509.9857]]], device='cuda:0')
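The .retain_grad() route mentioned above looks like the following. This is a sketch, assuming save() hands back the live tensor so that its retained .grad is readable after the trace.

[ ]:
from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="cuda")

with model.trace("Hello World"):
    hidden_states = model.transformer.h[-1].output[0].save()
    # .retain_grad() is traced like any other tensor method, so the real
    # hidden-state tensor keeps its gradient through the backward pass.
    hidden_states.retain_grad()

    model.output.logits.sum().backward()

print(hidden_states.grad)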
Even better, nnsight also provides proxy access into the backward pass via the .grad attribute on proxies. This works just like .input and .output: operations on it, including getting and setting, are traced and performed on the model at runtime (assuming it's a proxy of a Tensor, as this calls .register_hook(…) on it!).
The following examples demonstrate ablating (setting to zero) the gradients for a hidden state in gpt2. The first example is an in-place operation and the second swaps the gradient out for a new tensor of zeroes.
[5]:
from nnsight import LanguageModel
import torch

model = LanguageModel('gpt2', device_map='cuda')

with model.trace("Hello World") as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()

    hidden_states_grad_before = hidden_states.grad.clone().save()
    # Ablate the gradient in place.
    hidden_states.grad[:] = 0
    hidden_states_grad_after = hidden_states.grad.save()

    logits = model.output.logits
    logits.sum().backward()

print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)

with model.trace("Hello World") as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()

    hidden_states_grad_before = hidden_states.grad.clone().save()
    # Swap the gradient out for a new tensor of zeroes.
    hidden_states.grad = torch.zeros(hidden_states.grad.shape)
    hidden_states_grad_after = hidden_states.grad.save()

    logits = model.output.logits
    logits.sum().backward()

print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)
Before tensor([[[ 28.7976, -282.5981, 868.7354, ..., 120.1743, 52.2264,
168.6449],
[ 79.4181, -253.6225, 1322.1293, ..., 208.3980, -19.5545,
509.9857]]], device='cuda:0')
After tensor([[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]], device='cuda:0')
Before tensor([[[ 28.7976, -282.5981, 868.7354, ..., 120.1743, 52.2264,
168.6449],
[ 79.4181, -253.6225, 1322.1293, ..., 208.3980, -19.5545,
509.9857]]], device='cuda:0')
After tensor([[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]], device='cuda:0')
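The same get-and-set mechanism is not limited to zeroing. As a variation on the ablation above, here is a sketch that halves the gradient instead; the setup mirrors the cells above, and the 0.5 scale factor is just an illustrative choice.

[ ]:
from nnsight import LanguageModel

model = LanguageModel('gpt2', device_map='cuda')

with model.trace("Hello World") as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()

    hidden_states_grad_before = hidden_states.grad.clone().save()
    # Swap the gradient for a scaled copy; layers earlier in the forward
    # pass receive the halved gradient during backward.
    hidden_states.grad = hidden_states.grad * 0.5
    hidden_states_grad_after = hidden_states.grad.save()

    model.output.logits.sum().backward()

print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)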