Source Code Interventions#

Summary#

You can use NNsight to access the intermediate values of a forward pass! Printing the .source of a module shows its forward pass source code along with the name assigned to each operation.

# .source to print the attention module of the first transformer block
print(model.transformer.h[0].attn.source)

You can also use .source to get and set intermediate values within the tracing context.

with model.trace(prompt):
  # get intermediate value
  attention = model.transformer.h[0].attn.source.attention_interface_0.output.save()

  # set intermediate value (k is a placeholder layer index)
  model.transformer.h[k].attn.source.attention_interface_0.output = attention

When to Use#

One of the most common use cases for .source is accessing attention internals, such as the output of the attention computation before the output projection, which is not exposed as an ordinary module input or output.

How to Use#

To view the operations within a module's forward pass, call .source on the module. Printing .source shows the forward pass source code along with the name assigned to each operation.

[5]:
from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto')

print(model.transformer.h[0].attn.source)
                                       * def forward(
                                       0     self,
                                       1     hidden_states: Optional[tuple[torch.FloatTensor]],
                                       2     past_key_value: Optional[Cache] = None,
                                       3     cache_position: Optional[torch.LongTensor] = None,
                                       4     attention_mask: Optional[torch.FloatTensor] = None,
                                       5     head_mask: Optional[torch.FloatTensor] = None,
                                       6     encoder_hidden_states: Optional[torch.Tensor] = None,
                                       7     encoder_attention_mask: Optional[torch.FloatTensor] = None,
                                       8     output_attentions: Optional[bool] = False,
                                       9     **kwargs,
                                      10 ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
                                      11     is_cross_attention = encoder_hidden_states is not None
                                      12     if past_key_value is not None:
 isinstance_0                      -> 13         if isinstance(past_key_value, EncoderDecoderCache):
 past_key_value_is_updated_get_0   -> 14             is_updated = past_key_value.is_updated.get(self.layer_idx)
                                      15             if is_cross_attention:
                                      16                 # after the first generated id, we can subsequently re-use all key/value_layer from cache
                                      17                 curr_past_key_value = past_key_value.cross_attention_cache
                                      18             else:
                                      19                 curr_past_key_value = past_key_value.self_attention_cache
                                      20         else:
                                      21             curr_past_key_value = past_key_value
                                      22
                                      23     if is_cross_attention:
 hasattr_0                         -> 24         if not hasattr(self, "q_attn"):
 ValueError_0                      -> 25             raise ValueError(
                                      26                 "If class is used as cross attention, the weights `q_attn` have to be defined. "
                                      27                 "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                                      28             )
 self_q_attn_0                     -> 29         query_states = self.q_attn(hidden_states)
                                      30         attention_mask = encoder_attention_mask
                                      31
                                      32         # Try to get key/value states from cache if possible
                                      33         if past_key_value is not None and is_updated:
                                      34             key_states = curr_past_key_value.layers[self.layer_idx].keys
                                      35             value_states = curr_past_key_value.layers[self.layer_idx].values
                                      36         else:
 self_c_attn_0                     -> 37             key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
 split_0                           ->  +             ...
                                      38             shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
 key_states_view_0                 -> 39             key_states = key_states.view(shape_kv).transpose(1, 2)
 transpose_0                       ->  +             ...
 value_states_view_0               -> 40             value_states = value_states.view(shape_kv).transpose(1, 2)
 transpose_1                       ->  +             ...
                                      41     else:
 self_c_attn_1                     -> 42         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
 split_1                           ->  +         ...
                                      43         shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
 key_states_view_1                 -> 44         key_states = key_states.view(shape_kv).transpose(1, 2)
 transpose_2                       ->  +         ...
 value_states_view_1               -> 45         value_states = value_states.view(shape_kv).transpose(1, 2)
 transpose_3                       ->  +         ...
                                      46
                                      47     shape_q = (*query_states.shape[:-1], -1, self.head_dim)
 query_states_view_0               -> 48     query_states = query_states.view(shape_q).transpose(1, 2)
 transpose_4                       ->  +     ...
                                      49
                                      50     if (past_key_value is not None and not is_cross_attention) or (
                                      51         past_key_value is not None and is_cross_attention and not is_updated
                                      52     ):
                                      53         # save all key/value_layer to cache to be re-used for fast auto-regressive generation
                                      54         cache_position = cache_position if not is_cross_attention else None
 curr_past_key_value_update_0      -> 55         key_states, value_states = curr_past_key_value.update(
                                      56             key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                                      57         )
                                      58         # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                                      59         if is_cross_attention:
                                      60             past_key_value.is_updated[self.layer_idx] = True
                                      61
                                      62     is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
                                      63
                                      64     using_eager = self.config._attn_implementation == "eager"
                                      65     attention_interface: Callable = eager_attention_forward
                                      66     if self.config._attn_implementation != "eager":
                                      67         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
                                      68
                                      69     if using_eager and self.reorder_and_upcast_attn:
 self__upcast_and_reordered_attn_0 -> 70         attn_output, attn_weights = self._upcast_and_reordered_attn(
                                      71             query_states, key_states, value_states, attention_mask, head_mask
                                      72         )
                                      73     else:
 attention_interface_0             -> 74         attn_output, attn_weights = attention_interface(
                                      75             self,
                                      76             query_states,
                                      77             key_states,
                                      78             value_states,
                                      79             attention_mask,
                                      80             head_mask=head_mask,
                                      81             dropout=self.attn_dropout.p if self.training else 0.0,
                                      82             is_causal=is_causal,
                                      83             **kwargs,
                                      84         )
                                      85
 attn_output_reshape_0             -> 86     attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
 contiguous_0                      ->  +     ...
 self_c_proj_0                     -> 87     attn_output = self.c_proj(attn_output)
 self_resid_dropout_0              -> 88     attn_output = self.resid_dropout(attn_output)
                                      89
                                      90     return attn_output, attn_weights
                                      91

Okay, so you can now see the operations within the attention module. But how would you actually access and intervene on these operations?

Within the tracing context, you can access and intervene on source operation values just as you get and set the inputs and outputs of modules.

Let’s try to access the output of attention_interface_0 and save it.

[6]:
prompt = "The Eiffel Tower is located in the city of"

with model.trace(prompt):

  attention = model.transformer.h[0].attn.source.attention_interface_0.output.save()

print(attention)
(tensor([[[[-1.3066e-02, -1.4464e-02,  1.2694e-01,  ..., -4.9182e-02,
            1.0464e-01,  2.3067e-02],
          [ 5.9014e-01,  1.0051e-01, -2.0716e-01,  ..., -6.9383e-01,
           -2.7763e-01,  2.0517e-01],
          [-2.8404e-02, -1.1449e-01, -2.1676e-02,  ...,  3.9217e-03,
            7.8844e-02, -3.9936e-03],
          ...,
          [-1.0771e-01, -2.1316e-01, -2.1841e-02,  ..., -2.3210e-01,
            2.1270e-02, -6.6547e-02],
          [ 9.3223e-02, -1.0404e-01, -2.1104e-01,  ...,  1.8502e-01,
            2.2378e-01, -3.1989e-02],
          [-2.4770e-02, -3.7828e-01,  1.1838e-01,  ...,  1.1582e-02,
           -2.4843e-01, -1.1559e-01]],

         [[ 1.6556e-02, -1.4358e-02,  1.2063e-01,  ..., -4.5132e-02,
            9.7916e-02,  2.9936e-02],
          [ 6.3335e-01,  1.1630e-01,  2.4260e-01,  ...,  1.8957e-01,
            8.8069e-02, -5.1060e-02],
          [-3.0525e-02, -8.4757e-02, -3.9247e-02,  ..., -9.0845e-04,
            6.1980e-02, -5.6949e-03],
          ...,
          [-1.4259e-01, -1.0459e-01, -4.5955e-02,  ..., -1.8481e-01,
            5.3141e-03, -2.2560e-02],
          [-4.6628e-02, -6.1292e-02, -1.9690e-01,  ...,  6.9869e-02,
            2.1985e-01, -2.1646e-02],
          [-7.1120e-03, -2.6377e-01,  1.2137e-01,  ...,  6.5153e-02,
           -1.7566e-01, -6.8845e-02]],

         [[ 2.6889e-02, -7.0122e-03,  1.1766e-01,  ..., -4.4820e-02,
            1.0287e-01,  3.1002e-02],
          [ 2.7532e-01, -7.9725e-02,  2.0713e-01,  ...,  2.0177e-01,
            1.4186e-01, -1.3268e-01],
          [-2.7850e-02, -5.9402e-02, -4.4856e-02,  ..., -4.5167e-03,
            3.0954e-02,  2.5686e-03],
          ...,
          [-1.6139e-01, -4.3244e-02, -4.1564e-02,  ..., -1.3405e-01,
            3.9203e-03,  2.7092e-02],
          [ 3.5194e-03, -2.6913e-03, -1.4248e-01,  ..., -5.7047e-02,
            1.9406e-01, -7.4551e-04],
          [-5.7676e-03, -2.9221e-01,  6.9209e-02,  ...,  2.6864e-03,
           -1.3484e-01, -3.9497e-02]],

         ...,

         [[ 3.0232e-02, -2.6076e-02,  5.2120e-02,  ..., -3.7995e-02,
           -1.5027e-02, -4.8896e-02],
          [ 7.8568e-01, -7.6058e-02,  3.2838e-02,  ..., -1.9134e-01,
            1.1803e-01,  1.1202e-01],
          [ 1.5476e-01, -2.2131e-02,  1.2879e-02,  ...,  2.9907e-02,
            5.0672e-02,  3.1175e-03],
          ...,
          [-1.5960e-01, -6.9789e-02, -1.0114e-03,  ...,  4.2446e-03,
           -2.6784e-02,  4.3333e-02],
          [-5.8486e-02,  1.3786e-02, -8.2491e-02,  ...,  3.0148e-02,
            9.8252e-02, -1.6242e-02],
          [ 1.1144e-02, -1.0930e-01, -5.4984e-02,  ...,  4.0812e-02,
            4.5274e-02,  4.5370e-02]],

         [[ 3.2244e-02, -6.4355e-03,  5.8934e-02,  ..., -4.6791e-02,
           -1.0505e-02, -3.3912e-02],
          [ 3.0595e-01,  4.3599e-02,  2.6110e-01,  ...,  7.1029e-02,
            1.9688e-01, -5.6379e-03],
          [ 9.0927e-02, -3.1979e-02, -7.2353e-02,  ..., -7.3233e-03,
            2.8075e-02,  2.2513e-02],
          ...,
          [-1.6328e-01, -4.2869e-02,  5.8174e-03,  ...,  2.7562e-04,
           -2.8850e-02,  4.9361e-02],
          [-1.0240e-01,  2.9412e-02, -3.0656e-02,  ..., -5.3794e-02,
            2.4850e-02,  1.3824e-01],
          [ 2.6551e-02, -1.2249e-01, -4.9407e-02,  ...,  4.5420e-02,
            1.0131e-02,  5.6723e-02]],

         [[ 3.8645e-02, -1.6784e-02,  4.6613e-02,  ..., -6.4094e-02,
           -1.5405e-02, -2.9360e-02],
          [ 6.3216e-01, -8.6515e-02, -4.4017e-01,  ..., -1.6970e-03,
            4.7890e-01,  1.5406e-01],
          [ 1.4776e-01,  5.3006e-04,  3.9563e-02,  ...,  1.2799e-02,
            5.2635e-02, -1.7847e-02],
          ...,
          [-1.6876e-01, -4.3880e-02,  5.3133e-04,  ..., -3.3176e-02,
           -3.3048e-02,  1.0213e-02],
          [-3.7433e-02,  4.7539e-02, -5.1587e-02,  ..., -5.1953e-02,
           -6.5796e-03,  2.5365e-02],
          [ 3.4012e-02, -1.0146e-01, -3.5265e-02,  ...,  4.5592e-02,
            6.8942e-02,  7.0315e-02]]]], grad_fn=<CloneBackward0>), None)
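
Any operation name shown in the printed source can be accessed the same way. As a minimal sketch (split_1 is the name of the .split() call that produces the query/key/value states in the printed source above; operation names may differ across transformers versions), you could grab those states directly:

with model.trace(prompt):

  # split_1 is the .split() call producing (query, key, value) inside h[0].attn
  qkv = model.transformer.h[0].attn.source.split_1.output.save()

query_states, key_states, value_states = qkv
print(query_states.shape)  # (batch, seq_len, hidden_size), before the per-head reshape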

We can also set these values to intervene on intermediate variables. Here, we patch layer 7's attention output with layer 0's:

[11]:
with model.trace(prompt):
  attention_0 = model.transformer.h[0].attn.source.attention_interface_0.output.save()

  # change attention of layer 7 to that of layer 0
  model.transformer.h[7].attn.source.attention_interface_0.output = attention_0
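
To check that the patch actually changes the model's behavior, you can compare next-token predictions from a clean run and a patched run. This is a minimal sketch (it assumes the standard nnsight LanguageModel output access via model.output.logits and the model's built-in tokenizer):

with model.trace(prompt):
  clean_logits = model.output.logits[:, -1].save()

with model.trace(prompt):
  # patch layer 7's attention output with layer 0's, as above
  attention_0 = model.transformer.h[0].attn.source.attention_interface_0.output
  model.transformer.h[7].attn.source.attention_interface_0.output = attention_0
  patched_logits = model.output.logits[:, -1].save()

print("clean:  ", model.tokenizer.decode(clean_logits.argmax(dim=-1)))
print("patched:", model.tokenizer.decode(patched_logits.argmax(dim=-1)))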

Inside the tracing context, .source also works recursively: you can call .source on an operation to inspect the operations inside the function it calls.

[7]:
with model.trace(prompt):

  print(model.transformer.h[0].attn.source.attention_interface_0.source)
                                                        * def sdpa_attention_forward(
                                                        0     module: torch.nn.Module,
                                                        1     query: torch.Tensor,
                                                        2     key: torch.Tensor,
                                                        3     value: torch.Tensor,
                                                        4     attention_mask: Optional[torch.Tensor],
                                                        5     dropout: float = 0.0,
                                                        6     scaling: Optional[float] = None,
                                                        7     is_causal: Optional[bool] = None,
                                                        8     **kwargs,
                                                        9 ) -> tuple[torch.Tensor, None]:
 kwargs_get_0                                       -> 10     if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
 kwargs_get_1                                       ->  +     ...
 logger_warning_once_0                              -> 11         logger.warning_once(
                                                       12             "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
                                                       13             " Please set your attention to `eager` if you want any of these features."
                                                       14         )
                                                       15     sdpa_kwargs = {}
 hasattr_0                                          -> 16     if hasattr(module, "num_key_value_groups"):
 use_gqa_in_sdpa_0                                  -> 17         if not use_gqa_in_sdpa(attention_mask, key):
 repeat_kv_0                                        -> 18             key = repeat_kv(key, module.num_key_value_groups)
 repeat_kv_1                                        -> 19             value = repeat_kv(value, module.num_key_value_groups)
                                                       20         else:
                                                       21             sdpa_kwargs = {"enable_gqa": True}
                                                       22
                                                       23     if attention_mask is not None and attention_mask.ndim == 4:
                                                       24         attention_mask = attention_mask[:, :, :, : key.shape[-2]]
                                                       25
                                                       26     # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
                                                       27     # Reference: https://github.com/pytorch/pytorch/issues/112577.
 query_contiguous_0                                 -> 28     query = query.contiguous()
 key_contiguous_0                                   -> 29     key = key.contiguous()
 value_contiguous_0                                 -> 30     value = value.contiguous()
                                                       31
                                                       32     # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
                                                       33     # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
                                                       34     # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
                                                       35     if is_causal is None:
                                                       36         # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
                                                       37         # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
 getattr_0                                          -> 38         is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
                                                       39
                                                       40     # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
                                                       41     # We convert it to a bool for the SDPA kernel that only accepts bools.
 torch_jit_is_tracing_0                             -> 42     if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
 isinstance_0                                       ->  +     ...
 is_causal_item_0                                   -> 43         is_causal = is_causal.item()
                                                       44
 torch_nn_functional_scaled_dot_product_attention_0 -> 45     attn_output = torch.nn.functional.scaled_dot_product_attention(
                                                       46         query,
                                                       47         key,
                                                       48         value,
                                                       49         attn_mask=attention_mask,
                                                       50         dropout_p=dropout,
                                                       51         scale=scaling,
                                                       52         is_causal=is_causal,
                                                       53         **sdpa_kwargs,
                                                       54     )
 attn_output_transpose_0                            -> 55     attn_output = attn_output.transpose(1, 2).contiguous()
 contiguous_0                                       ->  +     ...
                                                       56
                                                       57     return attn_output, None
                                                       58
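
Recursive source access follows the same pattern as before: any operation name shown in this inner source can be read within the trace, and setting works the same way as in the earlier patching example. As a minimal sketch (torch_nn_functional_scaled_dot_product_attention_0 is the name shown above; it depends on the attention implementation and transformers version), you could save the raw scaled-dot-product-attention result before it is transposed and reshaped:

with model.trace(prompt):

  sdpa = model.transformer.h[0].attn.source.attention_interface_0.source

  # raw SDPA output, before the transpose/contiguous reshape
  raw_attn = sdpa.torch_nn_functional_scaled_dot_product_attention_0.output.save()

print(raw_attn.shape)  # e.g. (batch, num_heads, seq_len, head_dim)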