Source Code Interventions#
Summary#
You can use NNsight to access the intermediate values of a forward pass! We can print the .source
of a module to see its forward pass code and the name associated with each of its operations.
# print the .source of the first transformer layer's attention module
print(model.transformer.h[0].attn.source)
You can also use source
within the tracing context to get and set intermediate values.
with model.trace(prompt):
    # get an intermediate value
    attention = model.transformer.h[0].attn.source.attention_interface_0.output.save()
    # set the intermediate value at another layer k
    model.transformer.h[k].attn.source.attention_interface_0.output = attention
When to Use#
One of the most common use cases for source
is accessing attention internals, such as the query, key, and value states or the attention weights, which are computed inside the attention module's forward pass.
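For example, GPT-2 computes its query, key, and value projections with a single c_attn call followed by a split, and source lets you read them out directly. A minimal sketch, assuming the standard self-attention path (no cross-attention, no cache) so that the split is the operation named split_1 in the printout shown in the next section:

from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto')

with model.trace("The Eiffel Tower is located in the city of"):
    # split_1 is the .split(...) call on the c_attn projection in the self-attention branch;
    # its output is the (query_states, key_states, value_states) tuple, each (batch, seq, hidden)
    qkv = model.transformer.h[0].attn.source.split_1.output.save()

query_states, key_states, value_states = qkv
print(query_states.shape)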
How to Use#
To view the operations inside a module's forward pass, call .source
on the module. Printing .source
shows the forward pass code along with the name assigned to each operation.
[5]:
from nnsight import LanguageModel
model = LanguageModel('openai-community/gpt2', device_map='auto')
print(model.transformer.h[0].attn.source)
* def forward(
0 self,
1 hidden_states: Optional[tuple[torch.FloatTensor]],
2 past_key_value: Optional[Cache] = None,
3 cache_position: Optional[torch.LongTensor] = None,
4 attention_mask: Optional[torch.FloatTensor] = None,
5 head_mask: Optional[torch.FloatTensor] = None,
6 encoder_hidden_states: Optional[torch.Tensor] = None,
7 encoder_attention_mask: Optional[torch.FloatTensor] = None,
8 output_attentions: Optional[bool] = False,
9 **kwargs,
10 ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
11 is_cross_attention = encoder_hidden_states is not None
12 if past_key_value is not None:
isinstance_0 -> 13 if isinstance(past_key_value, EncoderDecoderCache):
past_key_value_is_updated_get_0 -> 14 is_updated = past_key_value.is_updated.get(self.layer_idx)
15 if is_cross_attention:
16 # after the first generated id, we can subsequently re-use all key/value_layer from cache
17 curr_past_key_value = past_key_value.cross_attention_cache
18 else:
19 curr_past_key_value = past_key_value.self_attention_cache
20 else:
21 curr_past_key_value = past_key_value
22
23 if is_cross_attention:
hasattr_0 -> 24 if not hasattr(self, "q_attn"):
ValueError_0 -> 25 raise ValueError(
26 "If class is used as cross attention, the weights `q_attn` have to be defined. "
27 "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
28 )
self_q_attn_0 -> 29 query_states = self.q_attn(hidden_states)
30 attention_mask = encoder_attention_mask
31
32 # Try to get key/value states from cache if possible
33 if past_key_value is not None and is_updated:
34 key_states = curr_past_key_value.layers[self.layer_idx].keys
35 value_states = curr_past_key_value.layers[self.layer_idx].values
36 else:
self_c_attn_0 -> 37 key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
split_0 -> + ...
38 shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
key_states_view_0 -> 39 key_states = key_states.view(shape_kv).transpose(1, 2)
transpose_0 -> + ...
value_states_view_0 -> 40 value_states = value_states.view(shape_kv).transpose(1, 2)
transpose_1 -> + ...
41 else:
self_c_attn_1 -> 42 query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
split_1 -> + ...
43 shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
key_states_view_1 -> 44 key_states = key_states.view(shape_kv).transpose(1, 2)
transpose_2 -> + ...
value_states_view_1 -> 45 value_states = value_states.view(shape_kv).transpose(1, 2)
transpose_3 -> + ...
46
47 shape_q = (*query_states.shape[:-1], -1, self.head_dim)
query_states_view_0 -> 48 query_states = query_states.view(shape_q).transpose(1, 2)
transpose_4 -> + ...
49
50 if (past_key_value is not None and not is_cross_attention) or (
51 past_key_value is not None and is_cross_attention and not is_updated
52 ):
53 # save all key/value_layer to cache to be re-used for fast auto-regressive generation
54 cache_position = cache_position if not is_cross_attention else None
curr_past_key_value_update_0 -> 55 key_states, value_states = curr_past_key_value.update(
56 key_states, value_states, self.layer_idx, {"cache_position": cache_position}
57 )
58 # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
59 if is_cross_attention:
60 past_key_value.is_updated[self.layer_idx] = True
61
62 is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
63
64 using_eager = self.config._attn_implementation == "eager"
65 attention_interface: Callable = eager_attention_forward
66 if self.config._attn_implementation != "eager":
67 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
68
69 if using_eager and self.reorder_and_upcast_attn:
self__upcast_and_reordered_attn_0 -> 70 attn_output, attn_weights = self._upcast_and_reordered_attn(
71 query_states, key_states, value_states, attention_mask, head_mask
72 )
73 else:
attention_interface_0 -> 74 attn_output, attn_weights = attention_interface(
75 self,
76 query_states,
77 key_states,
78 value_states,
79 attention_mask,
80 head_mask=head_mask,
81 dropout=self.attn_dropout.p if self.training else 0.0,
82 is_causal=is_causal,
83 **kwargs,
84 )
85
attn_output_reshape_0 -> 86 attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
contiguous_0 -> + ...
self_c_proj_0 -> 87 attn_output = self.c_proj(attn_output)
self_resid_dropout_0 -> 88 attn_output = self.resid_dropout(attn_output)
89
90 return attn_output, attn_weights
91
Okay, so you can now see the operations within the attention. But how would you actually access and intervene on these operations?
Within the tracing context, you can access and intervene on source
values just as you get and set the inputs and outputs of modules.
Let's try to access the output of attention_interface_0
and save it.
[6]:
prompt = "The Eiffel Tower is located in the city of"
with model.trace(prompt):
attention = model.transformer.h[0].attn.source.attention_interface_0.output.save()
print(attention)
(tensor([[[[-1.3066e-02, -1.4464e-02, 1.2694e-01, ..., -4.9182e-02,
1.0464e-01, 2.3067e-02],
[ 5.9014e-01, 1.0051e-01, -2.0716e-01, ..., -6.9383e-01,
-2.7763e-01, 2.0517e-01],
[-2.8404e-02, -1.1449e-01, -2.1676e-02, ..., 3.9217e-03,
7.8844e-02, -3.9936e-03],
...,
[-1.0771e-01, -2.1316e-01, -2.1841e-02, ..., -2.3210e-01,
2.1270e-02, -6.6547e-02],
[ 9.3223e-02, -1.0404e-01, -2.1104e-01, ..., 1.8502e-01,
2.2378e-01, -3.1989e-02],
[-2.4770e-02, -3.7828e-01, 1.1838e-01, ..., 1.1582e-02,
-2.4843e-01, -1.1559e-01]],
[[ 1.6556e-02, -1.4358e-02, 1.2063e-01, ..., -4.5132e-02,
9.7916e-02, 2.9936e-02],
[ 6.3335e-01, 1.1630e-01, 2.4260e-01, ..., 1.8957e-01,
8.8069e-02, -5.1060e-02],
[-3.0525e-02, -8.4757e-02, -3.9247e-02, ..., -9.0845e-04,
6.1980e-02, -5.6949e-03],
...,
[-1.4259e-01, -1.0459e-01, -4.5955e-02, ..., -1.8481e-01,
5.3141e-03, -2.2560e-02],
[-4.6628e-02, -6.1292e-02, -1.9690e-01, ..., 6.9869e-02,
2.1985e-01, -2.1646e-02],
[-7.1120e-03, -2.6377e-01, 1.2137e-01, ..., 6.5153e-02,
-1.7566e-01, -6.8845e-02]],
[[ 2.6889e-02, -7.0122e-03, 1.1766e-01, ..., -4.4820e-02,
1.0287e-01, 3.1002e-02],
[ 2.7532e-01, -7.9725e-02, 2.0713e-01, ..., 2.0177e-01,
1.4186e-01, -1.3268e-01],
[-2.7850e-02, -5.9402e-02, -4.4856e-02, ..., -4.5167e-03,
3.0954e-02, 2.5686e-03],
...,
[-1.6139e-01, -4.3244e-02, -4.1564e-02, ..., -1.3405e-01,
3.9203e-03, 2.7092e-02],
[ 3.5194e-03, -2.6913e-03, -1.4248e-01, ..., -5.7047e-02,
1.9406e-01, -7.4551e-04],
[-5.7676e-03, -2.9221e-01, 6.9209e-02, ..., 2.6864e-03,
-1.3484e-01, -3.9497e-02]],
...,
[[ 3.0232e-02, -2.6076e-02, 5.2120e-02, ..., -3.7995e-02,
-1.5027e-02, -4.8896e-02],
[ 7.8568e-01, -7.6058e-02, 3.2838e-02, ..., -1.9134e-01,
1.1803e-01, 1.1202e-01],
[ 1.5476e-01, -2.2131e-02, 1.2879e-02, ..., 2.9907e-02,
5.0672e-02, 3.1175e-03],
...,
[-1.5960e-01, -6.9789e-02, -1.0114e-03, ..., 4.2446e-03,
-2.6784e-02, 4.3333e-02],
[-5.8486e-02, 1.3786e-02, -8.2491e-02, ..., 3.0148e-02,
9.8252e-02, -1.6242e-02],
[ 1.1144e-02, -1.0930e-01, -5.4984e-02, ..., 4.0812e-02,
4.5274e-02, 4.5370e-02]],
[[ 3.2244e-02, -6.4355e-03, 5.8934e-02, ..., -4.6791e-02,
-1.0505e-02, -3.3912e-02],
[ 3.0595e-01, 4.3599e-02, 2.6110e-01, ..., 7.1029e-02,
1.9688e-01, -5.6379e-03],
[ 9.0927e-02, -3.1979e-02, -7.2353e-02, ..., -7.3233e-03,
2.8075e-02, 2.2513e-02],
...,
[-1.6328e-01, -4.2869e-02, 5.8174e-03, ..., 2.7562e-04,
-2.8850e-02, 4.9361e-02],
[-1.0240e-01, 2.9412e-02, -3.0656e-02, ..., -5.3794e-02,
2.4850e-02, 1.3824e-01],
[ 2.6551e-02, -1.2249e-01, -4.9407e-02, ..., 4.5420e-02,
1.0131e-02, 5.6723e-02]],
[[ 3.8645e-02, -1.6784e-02, 4.6613e-02, ..., -6.4094e-02,
-1.5405e-02, -2.9360e-02],
[ 6.3216e-01, -8.6515e-02, -4.4017e-01, ..., -1.6970e-03,
4.7890e-01, 1.5406e-01],
[ 1.4776e-01, 5.3006e-04, 3.9563e-02, ..., 1.2799e-02,
5.2635e-02, -1.7847e-02],
...,
[-1.6876e-01, -4.3880e-02, 5.3133e-04, ..., -3.3176e-02,
-3.3048e-02, 1.0213e-02],
[-3.7433e-02, 4.7539e-02, -5.1587e-02, ..., -5.1953e-02,
-6.5796e-03, 2.5365e-02],
[ 3.4012e-02, -1.0146e-01, -3.5265e-02, ..., 4.5592e-02,
6.8942e-02, 7.0315e-02]]]], grad_fn=<CloneBackward0>), None)
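The saved value is the full return value of attention_interface_0: an (attn_output, attn_weights) tuple, and with the default sdpa implementation the weights come back as None, as shown above. You can index into the proxy to save just the piece you need, for example:

with model.trace(prompt):
    # element 0 of the (attn_output, attn_weights) tuple returned by attention_interface_0
    attn_output = model.transformer.h[0].attn.source.attention_interface_0.output[0].save()

print(attn_output.shape)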
We can also set values through source to intervene on intermediate variables.
[11]:
with model.trace(prompt):
attention_0 = model.transformer.h[0].attn.source.attention_interface_0.output.save()
# change attention of layer 7 to that of layer 0
model.transformer.h[7].attn.source.attention_interface_0.output = attention_0
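Because the trace executes layers in order, the value read at layer 0 is available before layer 7 runs, so the patch takes effect downstream. A minimal sketch of checking that the intervention changes the next-token prediction (assuming GPT-2's lm_head naming):

# clean run
with model.trace(prompt):
    clean_logits = model.lm_head.output[0, -1].save()

# patched run: copy layer 0's attention output into layer 7
with model.trace(prompt):
    attention_0 = model.transformer.h[0].attn.source.attention_interface_0.output
    model.transformer.h[7].attn.source.attention_interface_0.output = attention_0
    patched_logits = model.lm_head.output[0, -1].save()

print("clean:  ", model.tokenizer.decode(clean_logits.argmax()))
print("patched:", model.tokenizer.decode(patched_logits.argmax()))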
Inside the tracing context, source
also works recursively, letting you step into the functions called within the forward pass:
[7]:
with model.trace(prompt):
print(model.transformer.h[0].attn.source.attention_interface_0.source)
* def sdpa_attention_forward(
0 module: torch.nn.Module,
1 query: torch.Tensor,
2 key: torch.Tensor,
3 value: torch.Tensor,
4 attention_mask: Optional[torch.Tensor],
5 dropout: float = 0.0,
6 scaling: Optional[float] = None,
7 is_causal: Optional[bool] = None,
8 **kwargs,
9 ) -> tuple[torch.Tensor, None]:
kwargs_get_0 -> 10 if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
kwargs_get_1 -> + ...
logger_warning_once_0 -> 11 logger.warning_once(
12 "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
13 " Please set your attention to `eager` if you want any of these features."
14 )
15 sdpa_kwargs = {}
hasattr_0 -> 16 if hasattr(module, "num_key_value_groups"):
use_gqa_in_sdpa_0 -> 17 if not use_gqa_in_sdpa(attention_mask, key):
repeat_kv_0 -> 18 key = repeat_kv(key, module.num_key_value_groups)
repeat_kv_1 -> 19 value = repeat_kv(value, module.num_key_value_groups)
20 else:
21 sdpa_kwargs = {"enable_gqa": True}
22
23 if attention_mask is not None and attention_mask.ndim == 4:
24 attention_mask = attention_mask[:, :, :, : key.shape[-2]]
25
26 # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
27 # Reference: https://github.com/pytorch/pytorch/issues/112577.
query_contiguous_0 -> 28 query = query.contiguous()
key_contiguous_0 -> 29 key = key.contiguous()
value_contiguous_0 -> 30 value = value.contiguous()
31
32 # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
33 # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
34 # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
35 if is_causal is None:
36 # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
37 # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
getattr_0 -> 38 is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
39
40 # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
41 # We convert it to a bool for the SDPA kernel that only accepts bools.
torch_jit_is_tracing_0 -> 42 if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
isinstance_0 -> + ...
is_causal_item_0 -> 43 is_causal = is_causal.item()
44
torch_nn_functional_scaled_dot_product_attention_0 -> 45 attn_output = torch.nn.functional.scaled_dot_product_attention(
46 query,
47 key,
48 value,
49 attn_mask=attention_mask,
50 dropout_p=dropout,
51 scale=scaling,
52 is_causal=is_causal,
53 **sdpa_kwargs,
54 )
attn_output_transpose_0 -> 55 attn_output = attn_output.transpose(1, 2).contiguous()
contiguous_0 -> + ...
56
57 return attn_output, None
58
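The operation names in this nested source can be accessed the same way. A minimal sketch that saves the raw output of the scaled_dot_product_attention call inside sdpa_attention_forward (assuming the model dispatches to sdpa, as in the printout above):

with model.trace(prompt):
    attn_call = model.transformer.h[0].attn.source.attention_interface_0
    # torch_nn_functional_scaled_dot_product_attention_0 names the
    # torch.nn.functional.scaled_dot_product_attention call inside sdpa_attention_forward
    sdpa_output = attn_call.source.torch_nn_functional_scaled_dot_product_attention_0.output.save()

print(sdpa_output.shape)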
Related#
Getting
Setting