Early Stopping#
If we are only interested in a model’s intermediate computations, we can halt a forward pass at any module, reducing runtime and conserving compute resources. One example where this is particularly useful is when working with sparse autoencoders (SAEs): we can train an SAE on the output of one layer and stop execution right after that layer.
[1]:
from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto')

with model.trace("The Eiffel Tower is in the city of"):
    # Save the output of the first transformer block, then halt the forward
    # pass at that same module so the later layers are not executed.
    l1_out = model.transformer.h[0].output.save()
    model.transformer.h[0].output.stop()

# The saved first-layer output is available once the tracing context exits.
print("L1 - Output: ", l1_out)
/opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
L1 - Output: (tensor([[[ 0.1559, -0.7946, 0.3943, ..., -1.7854, -0.4689, 0.1554],
[-0.0624, -0.3135, 0.7256, ..., -1.5666, -0.4111, -0.5625],
[-0.3308, -0.3006, -0.9636, ..., -1.3112, -0.5972, 0.3754],
...,
[-1.3608, 0.6902, 0.2851, ..., 0.1985, 0.1965, -0.5839],
[-2.4975, 0.6122, -2.2314, ..., 0.7751, 0.3189, -0.2851],
[ 1.6383, 0.9466, 0.6050, ..., -0.2360, 0.1267, 0.2722]]],
device='mps:0', grad_fn=<AddBackward0>), (tensor([[[[-0.9420, 1.9023, 0.8722, ..., -1.2703, -0.4792, 1.2469],
[-1.9590, 2.7141, 2.8423, ..., -1.1633, -1.6173, 2.1507],
[-2.6123, 2.0937, 0.9679, ..., -0.9763, -1.2243, 2.0279],
...,
[-2.4282, 2.4462, 2.1550, ..., -0.5916, -1.6641, 2.1119],
[-3.5624, 3.6804, 2.5053, ..., -0.3572, -2.5960, 0.9592],
[-2.6021, 2.8035, 1.7291, ..., -0.8557, -2.1589, 2.6881]],
[[ 0.1103, 0.6967, -1.1409, ..., -0.1243, 1.8249, -0.1592],
[ 0.3364, -2.3421, -3.0033, ..., -0.9075, 3.9665, 0.2082],
[-1.2822, -2.8345, 0.1537, ..., 0.6516, 2.4424, 0.7518],
...,
[-0.1554, -1.0321, -2.5109, ..., -0.9747, 4.8222, -1.8171],
[-1.3993, -2.2428, -0.0644, ..., -0.9444, 3.5096, 0.4326],
[ 0.5759, -0.8102, -1.8774, ..., -1.4308, 3.0181, -2.2393]],
[[-0.0985, -0.0323, 0.7536, ..., -1.1902, -1.6401, 0.6545],
[ 1.1513, -0.7019, 0.2992, ..., -1.8075, -0.1072, 2.0486],
[-0.1089, -1.0244, 0.4639, ..., -1.8416, -0.2348, 1.0322],
...,
[ 0.3554, 0.3485, 0.0083, ..., -3.3077, 0.8817, 1.4423],
[ 0.3027, 0.2488, -0.2483, ..., -2.8617, 0.7589, 0.7380],
[ 0.5146, -0.1207, 0.6076, ..., -2.7679, 1.1289, 1.7932]],
...,
[[ 0.6009, -0.0877, -0.2693, ..., 0.1756, 0.7995, 0.5978],
[ 0.0456, 0.2891, -0.1535, ..., 1.0184, 1.0627, 0.3627],
[-0.0681, -0.5138, 0.5735, ..., 0.7821, 0.8516, 0.4657],
...,
[-0.0383, -0.2532, 0.0525, ..., 1.2245, 0.5464, 0.4056],
[ 0.2111, 0.9947, 0.0403, ..., 1.1817, -0.7079, 0.5290],
[ 0.1136, -0.0611, 0.1199, ..., 1.2025, 0.4589, 0.6644]],
[[ 1.4709, 1.5225, -0.4336, ..., -0.1837, 1.0947, -1.6615],
[ 0.7999, 0.0324, -1.5696, ..., -0.7550, 1.4671, -1.1099],
[ 0.6255, 0.4108, 0.0984, ..., -1.2564, 1.9016, -0.0603],
...,
[ 1.1553, 0.5795, -0.6220, ..., -0.7993, 0.4428, -0.6729],
[ 1.5863, -0.0730, 0.1822, ..., -0.5310, 0.4560, -0.4558],
[ 0.9170, -0.7168, -0.4214, ..., -0.8926, 0.4736, -0.0411]],
[[ 0.6260, 0.2122, 0.2527, ..., -0.6377, 0.2275, 1.5142],
[-0.3332, 1.5151, -0.3315, ..., 1.2160, 0.2653, 2.6735],
[ 0.1930, 0.0467, -0.3682, ..., -0.1827, 0.1576, 0.5612],
...,
[-0.1787, 1.2580, -0.2565, ..., -0.6601, 1.2289, 0.2853],
[ 0.9067, 0.6444, 0.2020, ..., 0.1291, 0.2002, 0.8276],
[-0.5779, 0.4654, -0.8867, ..., 1.4954, 1.3435, -0.6073]]]],
device='mps:0', grad_fn=<PermuteBackward0>), tensor([[[[-1.3066e-02, -1.4464e-02, 1.2694e-01, ..., -4.9182e-02,
1.0464e-01, 2.3067e-02],
[ 2.0469e-01, -1.3684e-02, 8.0588e-02, ..., -1.9410e-02,
5.5186e-02, 7.3562e-02],
[ 7.2232e-03, 1.8508e-01, 1.0139e-01, ..., -7.6448e-02,
2.9932e-01, -8.5229e-03],
...,
[-1.6686e-01, 1.9638e-02, 1.2153e-01, ..., 6.1965e-02,
9.3590e-02, -1.0460e-01],
[ 1.5657e-01, 1.5053e-01, 5.7654e-02, ..., -4.2498e-01,
-5.2136e-02, 3.0045e-02],
[-1.0558e-02, -8.6992e-02, -7.6297e-02, ..., -6.3531e-02,
-5.0926e-02, 1.9987e-01]],
[[ 5.9014e-01, 1.0051e-01, -2.0716e-01, ..., -6.9383e-01,
-2.7763e-01, 2.0517e-01],
[ 6.3339e-01, 1.1631e-01, 2.4300e-01, ..., 1.9035e-01,
8.8391e-02, -5.1286e-02],
[ 2.7510e-01, -7.9842e-02, 2.0712e-01, ..., 2.0180e-01,
1.4190e-01, -1.3274e-01],
...,
[ 8.3555e-01, -9.4205e-02, 7.4023e-02, ..., -1.7617e-01,
1.3164e-01, 1.1117e-01],
[ 3.2692e-01, 4.5032e-02, 2.5904e-01, ..., 7.9349e-02,
2.0154e-01, -5.9559e-03],
[ 6.5303e-01, -8.9489e-02, -4.5211e-01, ..., 7.0380e-04,
4.9327e-01, 1.5887e-01]],
[[-2.8404e-02, -1.1449e-01, -2.1676e-02, ..., 3.9217e-03,
7.8844e-02, -3.9935e-03],
[-4.9779e-02, 1.8518e-01, -1.9874e-01, ..., -4.4753e-02,
-9.1101e-02, -2.1138e-02],
[ 4.4283e-02, 5.9255e-02, 5.3522e-02, ..., -1.4617e-02,
-3.3558e-01, 1.7041e-01],
...,
[-5.3011e-01, -6.1408e-04, -6.0240e-01, ..., -1.9026e-01,
-7.1861e-02, 3.2019e-01],
[ 4.0913e-01, -1.1379e-01, -1.6436e-01, ..., 8.2217e-02,
-1.0437e-01, -7.6691e-02],
[ 3.1223e-01, 3.6828e-01, 6.0183e-01, ..., -2.8972e-02,
1.1367e-01, -2.5661e-01]],
...,
[[-1.0771e-01, -2.1316e-01, -2.1841e-02, ..., -2.3210e-01,
2.1270e-02, -6.6547e-02],
[-3.1855e-01, 4.4327e-01, -1.6764e-01, ..., 5.3822e-02,
-7.5202e-02, 1.9941e-01],
[-1.4129e-01, -6.2872e-02, 2.3989e-01, ..., 2.2710e-01,
1.2402e-01, 4.0053e-01],
...,
[-1.0040e-01, -4.9095e-01, 2.2476e-02, ..., 5.5608e-02,
-1.4735e-01, 2.3780e-01],
[-3.0879e-01, 7.1592e-01, 1.2739e-01, ..., 2.9476e-02,
-1.5573e-01, -1.7634e-02],
[-1.6235e-01, -2.5231e-01, -6.0719e-02, ..., -3.7746e-01,
-6.9727e-03, -2.2533e-01]],
[[ 9.3223e-02, -1.0404e-01, -2.1104e-01, ..., 1.8502e-01,
2.2378e-01, -3.1989e-02],
[-4.5714e-01, 6.4180e-02, -1.5538e-01, ..., -2.6815e-01,
2.0829e-01, 8.7156e-03],
[ 2.4635e-03, 1.8372e-01, 7.3724e-03, ..., -4.8131e-01,
1.2558e-01, 6.1276e-02],
...,
[-2.2317e-01, 1.2418e-01, -3.6774e-02, ..., 2.8985e-01,
5.9641e-02, -8.6952e-03],
[-1.8944e-01, -1.7414e-02, -2.9084e-02, ..., -4.5319e-02,
-5.7796e-02, 4.7680e-01],
[ 9.2804e-02, 9.9442e-02, -6.0471e-02, ..., -7.9065e-02,
-1.6836e-01, 7.1764e-02]],
[[-2.4770e-02, -3.7828e-01, 1.1838e-01, ..., 1.1582e-02,
-2.4843e-01, -1.1559e-01],
[ 5.8631e-02, 1.6256e-01, 1.3249e-01, ..., 2.6460e-01,
9.5267e-02, 1.0518e-01],
[ 5.0756e-02, -1.4601e-01, -2.3191e-01, ..., -2.2047e-01,
3.0730e-01, 2.6307e-01],
...,
[ 2.8192e-02, 1.9202e-01, -8.7550e-02, ..., 6.9838e-02,
-4.0262e-02, -5.9196e-03],
[ 2.6708e-01, 1.3450e-01, -8.2224e-02, ..., -6.0217e-04,
-1.4364e-01, 1.5347e-01],
[ 1.5456e-01, -1.1916e-01, 2.8118e-01, ..., 1.1415e-01,
2.5977e-01, 1.8767e-01]]]], device='mps:0',
grad_fn=<PermuteBackward0>)), None)
Interventions within the tracing context do not necessarily execute in the order they are defined. Instead, their execution is tied to the module they are associated with.
As a result, if the forward pass is terminated early, any interventions linked to modules beyond that point will be skipped, even if they were defined earlier in the context.
In the example below, the output of layer 2 cannot be accessed since the model’s execution was stopped at layer 1.
[2]:
with model.trace("The Eiffel Tower is in the city of"):
    l2_out = model.transformer.h[1].output.save()  # skipped: h[1] lies beyond the stop point
    model.transformer.h[0].output.stop()  # halt the forward pass at h[0]

print("L2 - Output: ", l2_out)  # intentionally fails: l2_out was never set, so this raises ValueError
L2 - Output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[2], line 5
2 l2_out = model.transformer.h[1].output.save()
3 model.transformer.h[0].output.stop()
----> 5 print("L2 - Output: ", l2_out)
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/proxy.py:70, in Proxy.__str__(self)
66 def __str__(self) -> str:
68 if not self.node.attached:
---> 70 return str(self.value)
72 return f"{type(self).__name__} ({self.node.target.__name__})"
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/proxy.py:64, in Proxy.value(self)
56 @property
57 def value(self) -> Any:
58 """Property to return the value of this proxy's node.
59
60 Returns:
61 Any: The stored value of the proxy, populated during execution of the model.
62 """
---> 64 return self.node.value
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/node.py:143, in Node.value(self)
133 """Property to return the value of this node.
134
135 Returns:
(...)
139 ValueError: If the underlying ._value is inspect._empty (therefore never set or was destroyed).
140 """
142 if not self.done:
--> 143 raise ValueError("Accessing value before it's been set.")
145 return self._value
ValueError: Accessing value before it's been set.