Gradients

There are a couple of ways we can interact with the gradients during and after a backward pass.

In the following example, we save the hidden states of the last layer and do a backward pass on the sum of the logits.

Note two things:

  1. requires_grad=True by default.

  2. We can call .backward() on a value within the tracing context, just as you normally would.

[1]:
from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto")
[2]:
with model.trace("Hello World") as tracer:

    hidden_states = model.transformer.h[-1].output[0].save()

    logits = model.output.logits

    logits.sum().backward()

print(hidden_states)
tensor([[[ 0.5216, -1.1755, -0.4617,  ..., -1.1919,  0.0204, -2.0075],
         [ 0.9841,  2.2175,  3.5851,  ...,  0.5212, -2.2286,  5.7334]]],
       device='mps:0', grad_fn=<AddBackward0>)
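For comparison, here is a rough plain-PyTorch analogue of the cell above. It is a minimal sketch, not how nnsight works internally: it assumes a forward hook is a fair stand-in for .save(), and the toy model (toy), the hook (save_output) and the sizes are hypothetical. The hidden layer's output is captured during the forward pass, and then a backward pass is run on the sum of the logits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-in for a language model: one hidden "layer" and an output head.
toy = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))

saved = {}

def save_output(module, inputs, output):
    # Keep a reference to the layer's output, loosely analogous to .output.save() in a trace.
    saved["hidden"] = output

handle = toy[1].register_forward_hook(save_output)

x = torch.randn(2, 4)
logits = toy(x)
logits.sum().backward()  # backward pass on the sum of the logits
handle.remove()          # detach the hook once we are done

toy_hidden = saved["hidden"]
print(toy_hidden.requires_grad)   # the saved hidden state is part of the autograd graph
print(toy[0].weight.grad.shape)   # parameter gradients were populated by .backward()
```

As in the nnsight output, the saved hidden state carries a grad_fn because gradients are enabled by default.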

To see the gradients of hidden_states, we can access the .grad attribute of the layer output inside the trace, call .save() on it, and read the saved value after execution.

[3]:
with model.trace("Hello World") as tracer:

    hidden_states = model.transformer.h[-1].output[0].save()
    hidden_states_grad = model.transformer.h[-1].output[0].grad.save()

    model.output.logits.sum().backward()

print(hidden_states)
print(hidden_states_grad)
tensor([[[ 0.5216, -1.1755, -0.4617,  ..., -1.1919,  0.0204, -2.0075],
         [ 0.9841,  2.2175,  3.5851,  ...,  0.5212, -2.2286,  5.7334]]],
       device='mps:0', grad_fn=<AddBackward0>)
tensor([[[  28.7976, -282.5975,  868.7336,  ...,  120.1741,   52.2263,
           168.6446],
         [  79.4183, -253.6228, 1322.1298,  ...,  208.3982,  -19.5544,
           509.9858]]], device='mps:0')
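In plain PyTorch, an intermediate (non-leaf) tensor does not keep its .grad after a backward pass unless .retain_grad() is called on it first. The minimal sketch below shows this outside of nnsight; the tensors W, x and hidden are hypothetical. Because logits = hidden @ W.T, the gradient of logits.sum() with respect to hidden is W summed over its rows, repeated for each input row, which gives a value to check against.

```python
import torch

torch.manual_seed(0)

W = torch.randn(3, 5, requires_grad=True)  # hypothetical output-head weights
x = torch.randn(2, 5, requires_grad=True)

hidden = torch.tanh(x)   # non-leaf tensor: produced by an operation
hidden.retain_grad()     # ask autograd to keep hidden.grad after backward
logits = hidden @ W.T    # shape (2, 3)
logits.sum().backward()

# d(sum(hidden @ W.T)) / d(hidden) = ones(2, 3) @ W, i.e. W summed over its rows
expected = W.detach().sum(dim=0).expand(2, 5)
print(torch.allclose(hidden.grad, expected))  # True
```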

nnsight provides proxy access into the backward pass through the .grad attribute on proxies. It works just like .input and .output: operations on it, including getting and setting, are traced and performed on the model at runtime. This assumes the proxy is of a Tensor, since accessing .grad calls .register_hook(...) on it.

The following examples demonstrate ablating (setting to zero) the gradients of a hidden state in GPT-2. The first example modifies the gradient in place; the second replaces it with a new tensor of zeros.

[4]:
with model.trace("Hello World") as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()

    hidden_states_grad_before = hidden_states.grad.clone().save()
    hidden_states.grad[:] = 0
    hidden_states_grad_after = hidden_states.grad.save()

    logits = model.output.logits

    logits.sum().backward()

print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)
Before tensor([[[  28.7976, -282.5975,  868.7336,  ...,  120.1741,   52.2263,
           168.6446],
         [  79.4183, -253.6228, 1322.1298,  ...,  208.3982,  -19.5544,
           509.9858]]], device='mps:0')
After tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='mps:0')
[5]:
with model.trace("Hello World") as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()

    hidden_states_grad_before = hidden_states.grad.clone().save()
    hidden_states.grad = torch.zeros(hidden_states.grad.shape).to(hidden_states.grad.device)
    hidden_states_grad_after = hidden_states.grad.save()

    logits = model.output.logits

    logits.sum().backward()

print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)
Before tensor([[[  28.7976, -282.5975,  868.7336,  ...,  120.1741,   52.2263,
           168.6446],
         [  79.4183, -253.6228, 1322.1298,  ...,  208.3982,  -19.5544,
           509.9858]]], device='mps:0')
After tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='mps:0')
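The .register_hook(...) mechanism mentioned above can be seen in isolation in plain PyTorch. In this minimal sketch (the layer, head and ablate names and the sizes are hypothetical), a tensor hook returns a new tensor of zeros, which replaces the gradient flowing through hidden, so the parameters upstream of it receive zero gradient while the head still gets a normal one. PyTorch's documentation asks tensor hooks not to modify their argument in place, so this sketch only mirrors the replacement style of [5], not the in-place edit of [4].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-layer stand-in; we ablate the gradient flowing through its hidden state.
layer = nn.Linear(4, 6)
head = nn.Linear(6, 3)

x = torch.randn(2, 4)
hidden = torch.tanh(layer(x))

seen = {}

def ablate(grad):
    seen["before"] = grad.clone()  # gradient arriving at hidden, before ablation
    return torch.zeros_like(grad)  # the returned tensor replaces the gradient

hidden.register_hook(ablate)
logits = head(hidden)
logits.sum().backward()

print(seen["before"].abs().sum())  # nonzero: the original gradient
print(layer.weight.grad.abs().sum())  # zero: nothing flows past the ablated hidden state
print(head.weight.grad.abs().sum())   # nonzero: the head is downstream of the hook
```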