Working With Gradients

There are a couple of ways we can interact with the gradients during and after a backward pass.

In the following example, we save the hidden states of the last layer and do a backward pass on the sum of the logits.

Note two things:

  1. requires_grad=True by default.

  2. We can call .backward() on a value within the tracing context just as you normally would.

[1]:
from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="cuda")
[2]:
with model.trace("Hello World") as tracer:

    hidden_states = model.transformer.h[-1].output[0].save()

    logits = model.output.logits

    logits.sum().backward()

print(hidden_states)
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
tensor([[[ 0.5216, -1.1755, -0.4617,  ..., -1.1919,  0.0204, -2.0075],
         [ 0.9841,  2.2175,  3.5851,  ...,  0.5212, -2.2286,  5.7334]]],
       device='cuda:0', grad_fn=<AddBackward0>)

If we want to see the gradients for the hidden_states, we can save the layer output's .grad attribute inside the trace and access it after execution, much as we would call .retain_grad() on a tensor and read its .grad in plain PyTorch.

[3]:
with model.trace("Hello World") as tracer:

    hidden_states = model.transformer.h[-1].output[0].save()
    hidden_states_grad = model.transformer.h[-1].output[0].grad.save()

    model.output.logits.sum().backward()

print(hidden_states)
print(hidden_states_grad)
tensor([[[ 0.5216, -1.1755, -0.4617,  ..., -1.1919,  0.0204, -2.0075],
         [ 0.9841,  2.2175,  3.5851,  ...,  0.5212, -2.2286,  5.7334]]],
       device='cuda:0', grad_fn=<AddBackward0>)
tensor([[[  28.7976, -282.5981,  868.7354,  ...,  120.1743,   52.2264,
           168.6449],
         [  79.4181, -253.6225, 1322.1293,  ...,  208.3980,  -19.5545,
           509.9857]]], device='cuda:0')
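The same pattern extends to more than one location in the network. The following is a small sketch of our own (not part of the cells above), reusing the same gpt2 model and prompt; the first_layer_grad and last_layer_grad names are illustrative. It saves the gradient at two different depths from a single backward pass.

with model.trace("Hello World") as tracer:
    # Save gradients at two different depths in the same backward pass.
    first_layer_grad = model.transformer.h[0].output[0].grad.save()
    last_layer_grad = model.transformer.h[-1].output[0].grad.save()

    model.output.logits.sum().backward()

print(first_layer_grad)
print(last_layer_grad)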

Under the hood, nnsight provides proxy access into the backward pass via the .grad attribute on proxies. This works just like .input and .output: operations, including getting and setting, are traced and performed on the model at runtime (assuming it's a proxy of a Tensor, as this calls .register_hook(…) on it).

The following examples demonstrate ablating (setting to zero) the gradients for a hidden state in gpt2. The first example edits the gradient in place and the second swaps it out for a new tensor of zeros; a variation that scales the gradient instead of zeroing it is sketched after the outputs.

[5]:
from nnsight import LanguageModel
import torch

model = LanguageModel('gpt2', device_map='cuda')

with model.trace("Hello World") as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()

    # Zero the gradient in place during the backward pass.
    hidden_states_grad_before = hidden_states.grad.clone().save()
    hidden_states.grad[:] = 0
    hidden_states_grad_after = hidden_states.grad.save()

    logits = model.output.logits

    logits.sum().backward()

print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)

with model.trace("Hello World") as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()

    # Swap the gradient out for a new tensor of zeros.
    hidden_states_grad_before = hidden_states.grad.clone().save()
    hidden_states.grad = torch.zeros(hidden_states.grad.shape)
    hidden_states_grad_after = hidden_states.grad.save()

    logits = model.output.logits

    logits.sum().backward()

print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)

Before tensor([[[  28.7976, -282.5981,  868.7354,  ...,  120.1743,   52.2264,
           168.6449],
         [  79.4181, -253.6225, 1322.1293,  ...,  208.3980,  -19.5545,
           509.9857]]], device='cuda:0')
After tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')
Before tensor([[[  28.7976, -282.5981,  868.7354,  ...,  120.1743,   52.2264,
           168.6449],
         [  79.4181, -253.6225, 1322.1293,  ...,  208.3980,  -19.5545,
           509.9857]]], device='cuda:0')
After tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')
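Setting .grad is not limited to zero-ablation; any traced tensor operation can produce the replacement. The following is a small sketch of our own (same model and prompt as above, with illustrative variable names) that scales the gradient by 0.5 instead of removing it entirely.

with model.trace("Hello World") as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()

    # Halve the gradient flowing through the last layer's hidden state
    # rather than zeroing it out.
    grad_before = hidden_states.grad.clone().save()
    hidden_states.grad = hidden_states.grad * 0.5
    grad_after = hidden_states.grad.save()

    model.output.logits.sum().backward()

print("Before", grad_before)
print("After", grad_after)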