Gradients

Summary

We can access and intervene on gradients in NNsight using the .backward() context.

with model.trace(input):
    # access module inputs/outputs during the forward pass
    hidden_states = model.layer[k].output[0].save()
    logits = model.output.logits.save()

    # enter the .backward() context
    with logits.sum().backward():
        # gradients must be accessed in backpropagation order,
        # i.e. the reverse of the model's execution order
        # (logits are computed last, so their gradient comes first)
        grad = logits.grad.save()

        # you can also intervene on gradients within the backward context
        hidden_states.grad[:] = 0  # in-place setting operation

When to Use

Gradients are used in input attribution methods such as integrated gradients, and in gradient-based approximations of causal interventions such as attribution patching.
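To show how an attribution method consumes gradients, here is a minimal integrated-gradients sketch in plain PyTorch (not NNsight). It assumes a scalar-valued function `f` and an all-zeros baseline. The helper `integrated_gradients` and the toy function are hypothetical illustrations. A midpoint Riemann sum approximates the path integral, so the attributions should sum to roughly `f(x) - f(baseline)`.

```python
import torch


def integrated_gradients(f, x, baseline, steps=50):
    """Approximate integrated gradients of a scalar-valued f at x.

    Hypothetical helper: averages gradients along the straight-line
    path from `baseline` to `x` using a midpoint Riemann sum.
    """
    # Midpoints of `steps` equal sub-intervals of [0, 1]
    alphas = (torch.arange(steps, dtype=x.dtype) + 0.5) / steps
    total_grad = torch.zeros_like(x)
    for alpha in alphas:
        # Interpolated input, made a leaf so that .grad is populated
        point = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
        f(point).backward()
        total_grad += point.grad
    # Average path gradient, scaled by the input difference
    return (x - baseline) * total_grad / steps


def f(t):
    # Toy scalar function standing in for a model's output metric
    return (t ** 2).sum()


x = torch.tensor([1.0, -2.0, 3.0])
baseline = torch.zeros_like(x)
attributions = integrated_gradients(f, x, baseline)
print(attributions)  # close to x**2 for this toy function
```

For this quadratic toy function the midpoint rule is exact up to floating-point error. For a real model, `f` would run a forward pass and return a scalar such as a target logit.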

How to Use

To access and alter gradients in NNsight, create a .backward() context with a with block. Inside the backward context, you get and set gradient values through .grad, the same way you get and set a module's .output and .input in the forward pass. The backward context re-implements the core tracing functionality with .grad as the entry point (instead of .input/.output), so you can intervene on the gradients of intermediate values captured in the tracing context.

The backward context also enforces an order of access. That order is the reverse of the model's execution order, because it follows backpropagation.
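The reverse ordering comes from backpropagation itself. The following plain PyTorch example (not NNsight's API) registers tensor hooks with `Tensor.register_hook`. Each hook fires when that tensor's gradient is computed, so a value produced later in the forward pass receives its gradient earlier. The weight shapes and names are arbitrary.

```python
import torch

torch.manual_seed(0)
order = []

w1 = torch.randn(4, 4, requires_grad=True)
w2 = torch.randn(4, 2, requires_grad=True)
x = torch.randn(3, 4)

# Forward pass: `hidden` is computed before `logits`
hidden = x @ w1
logits = hidden @ w2

# Hooks run as each gradient is computed during backward;
# returning None (list.append's result) leaves the gradient unchanged
hidden.register_hook(lambda grad: order.append("hidden"))
logits.register_hook(lambda grad: order.append("logits"))

logits.sum().backward()
print(order)  # ['logits', 'hidden']
```

This ordering is why, inside a `.backward()` context, `logits.grad` is accessed before the gradient of an earlier hidden state.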

Note: Syntax for gradients prior to v0.5

Before NNsight 0.5, you accessed gradients by calling .grad on a value and then calling .backward() to start the backward pass, as shown below. This syntax is deprecated as of NNsight 0.5.

with model.trace("Hello World") as tracer:

    hidden_states = model.transformer.h[-1].output[0].save()
    hidden_states_grad = model.transformer.h[-1].output[0].grad.save()

    model.output.logits.sum().backward()

print(hidden_states)
print(hidden_states_grad)

Accessing Gradients

To access the gradient of an activation, we first capture the activation during forward-pass tracing. In this example, we capture the model's output logits.

Inside the backward context, we access the gradient of the logits variable and save it with .save(), just as we would save a forward-pass value.

with model.trace(prompt):

    logits = model.output.logits

    # .backward() context
    with logits.sum().backward():
        # access .grad within backward context
        grad = logits.grad.save()

print(grad)
tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]], device='cuda:0')
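The gradient above is all ones because the derivative of a sum with respect to each of its terms is 1. The following plain PyTorch example reproduces this result. A small linear layer stands in for the language model head as a hypothetical substitute, and the input shape mimics one batch of two tokens.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(8, 5)   # hypothetical stand-in for an LM head
inputs = torch.randn(1, 2, 8)   # (batch, tokens, features)

logits = layer(inputs)
logits.retain_grad()            # keep .grad on this non-leaf tensor
logits.sum().backward()

# d(sum of all elements) / d(each element) = 1
print(logits.grad)  # all ones, shape (1, 2, 5)
```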

You can also call .backward() without a with block. This runs backpropagation without entering the backward tracing context, which can make your code cleaner when you only need to read gradients rather than intervene on them.

In this case, call .retain_grad() on an intermediate value before backpropagating so that its .grad is populated, as in standard PyTorch.

with model.trace(prompt):

    hidden_states = model.transformer.h[-1].output[0].save()
    hidden_states.retain_grad()

    logits = model.output.logits.save()

    logits.sum().backward()

print(hidden_states.grad)
tensor([[[  28.7977, -282.5986,  868.7366,  ...,  120.1745,   52.2265,
           168.6451],
         [  79.4183, -253.6228, 1322.1305,  ...,  208.3983,  -19.5544,
           509.9862]]], device='cuda:0')
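`.retain_grad()` is standard PyTorch behavior. By default, autograd keeps `.grad` only on leaf tensors, so an intermediate activation needs `retain_grad()` before the backward pass. The following plain PyTorch example compares an intermediate tensor with and without `retain_grad()`. The shapes are arbitrary.

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4)

# Without retain_grad(): autograd does not store .grad on a non-leaf tensor
hidden_plain = x @ w
hidden_plain.sum().backward()
print(hidden_plain.grad)  # None (PyTorch may also emit a UserWarning)

# With retain_grad(): .grad is populated during backward
hidden_kept = x @ w
hidden_kept.retain_grad()
hidden_kept.sum().backward()
print(hidden_kept.grad)  # all ones, shape (2, 4)
```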

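You can also backpropagate multiple times within the same tracing context. Every .backward() call except the last must pass retain_graph=True so that the computation graph is not freed after the pass. Because gradients accumulate, hidden_states.grad grows with each pass: the second and third printed tensors are two and three times the first.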

with model.trace(prompt):

    hidden_states = model.transformer.h[-1].output[0].save()
    hidden_states.retain_grad()

    logits = model.output.logits

    logits.sum().backward(retain_graph=True)

    print(hidden_states.grad)

    logits.sum().backward(retain_graph=True)

    print(hidden_states.grad)

    logits.sum().backward()

print(hidden_states.grad)
tensor([[[  28.7977, -282.5986,  868.7366,  ...,  120.1745,   52.2265,
           168.6451],
         [  79.4183, -253.6228, 1322.1305,  ...,  208.3983,  -19.5544,
           509.9862]]], device='cuda:0')
tensor([[[  57.5953, -565.1971, 1737.4731,  ...,  240.3490,  104.4530,
           337.2901],
         [ 158.8366, -507.2456, 2644.2610,  ...,  416.7967,  -39.1088,
          1019.9724]]], device='cuda:0')
tensor([[[  86.3930, -847.7957, 2606.2097,  ...,  360.5235,  156.6795,
           505.9352],
         [ 238.2549, -760.8683, 3966.3916,  ...,  625.1951,  -58.6633,
          1529.9586]]], device='cuda:0')
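The growth across the three printed tensors (28.7977, 57.5953, 86.3930) reflects PyTorch's gradient accumulation: each backward pass adds to `.grad` instead of overwriting it. The following plain PyTorch example reproduces the same behavior. It assumes a toy squared-sum objective, chosen so that the graph holds saved tensors that the final pass frees. A fourth pass then fails.

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4)

hidden = x @ w
hidden.retain_grad()          # keep .grad on this intermediate tensor
out = (hidden ** 2).sum()     # pow saves `hidden` for the backward pass

grads = []
out.backward(retain_graph=True)   # keep the graph for another pass
grads.append(hidden.grad.clone())
out.backward(retain_graph=True)
grads.append(hidden.grad.clone())
out.backward()                    # last pass: saved tensors are freed
grads.append(hidden.grad.clone())

# .grad accumulates: the k-th reading is k times the first
print(torch.allclose(grads[1], 2 * grads[0]))  # True
print(torch.allclose(grads[2], 3 * grads[0]))  # True

# Another backward pass through the freed graph raises an error
graph_was_freed = False
try:
    out.backward()
except RuntimeError:
    graph_was_freed = True
print(graph_was_freed)  # True
```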

Intervening on Gradients

The following example demonstrates ablating (setting to zero) the gradient of a hidden state in GPT-2 with an in-place operation. Because gradients follow the reverse execution order, logits.grad is accessed before hidden_states.grad.

with model.trace(prompt) as tracer:
    hidden_states = model.transformer.h[-1].output[0].save()

    logits = model.output.logits

    with logits.sum().backward():
        # access logits.grad first: gradients follow the reverse execution order
        logits_grad = logits.grad.save()

        # clone so the "before" copy is not changed by the in-place edit below
        hidden_states_grad_before = hidden_states.grad.clone().save()
        hidden_states.grad[:] = 0  # in-place ablation
        hidden_states_grad_after = hidden_states.grad.save()

print("Before", hidden_states_grad_before)
print("After", hidden_states_grad_after)
Before tensor([[[  28.7977, -282.5986,  868.7366,  ...,  120.1745,   52.2265,
           168.6451],
         [  79.4183, -253.6228, 1322.1305,  ...,  208.3983,  -19.5544,
           509.9862]]], device='cuda:0')
After tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')
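The next example shows the effect of ablating a gradient using plain PyTorch (not NNsight's API). It uses `Tensor.register_hook`, whose hook may return a replacement gradient. PyTorch's documentation asks hooks not to modify their argument in place, so this example returns a fresh zero tensor instead. Every tensor upstream of the ablated hidden state receives zero gradient. The shapes and names are arbitrary.

```python
import torch

torch.manual_seed(0)
w1 = torch.randn(4, 4, requires_grad=True)
w2 = torch.randn(4, 2, requires_grad=True)
x = torch.randn(3, 4)

hidden = x @ w1
# The hook's return value replaces the gradient flowing into `hidden`
hidden.register_hook(lambda grad: torch.zeros_like(grad))
logits = hidden @ w2
logits.sum().backward()

print(w1.grad)  # all zeros: nothing flows back past the ablated gradient
print(w2.grad)  # non-zero: the weight applied after `hidden` is unaffected
```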

Notes

  • Gradients are disabled by default on remote models; set requires_grad=True during tracing to enable them.

  • You can only access .grad on intermediate values that were captured during the forward pass.