Scan and Validate
Have you ever changed tensor values in your intervention code, only to get an error message that is not very helpful? This is where “Scanning” and “Validating” can help. As the names suggest, these features scan the shapes of tensors throughout the model and validate that your intervention operations are compatible with those shapes.
We can enable these helpful tools by setting the scan=True and validate=True flags in the trace method.
Here is an example that demonstrates how Scan and Validate can help us debug the model:
[1]:
from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto')

input = "The Eiffel Tower is in the city of"
number_of_tokens = len(model.tokenizer.encode(input))

# turn on scan and validate
with model.trace(input, scan=True, validate=True):

    original_output = model.transformer.h[11].output[0].clone().save()

    # we want to modify the hidden states for the last token
    model.transformer.h[11].output[0][:, number_of_tokens, :] = 0

    modified_output = model.transformer.h[11].output[0].save()

print("\nOriginal Output: ", original_output[0][-1])
print("Modified Output: ", modified_output[0][-1])
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[1], line 9
6 number_of_tokens = len(model.tokenizer.encode(input))
8 # turn on scan and validate
----> 9 with model.trace(input, scan=True, validate=True):
11 original_output = model.transformer.h[11].output[0].clone().save()
13 # we want to modify the hidden states for the last token
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/contexts/interleaving.py:96, in InterleavingTracer.__exit__(self, exc_type, exc_val, exc_tb)
92 self.invoker.__exit__(None, None, None)
94 self._model._envoy._reset()
---> 96 super().__exit__(exc_type, exc_val, exc_tb)
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/contexts/tracer.py:25, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
21 from .globals import GlobalTracingContext
23 GlobalTracingContext.try_deregister(self)
---> 25 return super().__exit__(exc_type, exc_val, exc_tb)
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/contexts/base.py:72, in Context.__exit__(self, exc_type, exc_val, exc_tb)
69 graph = self.graph.stack.pop()
71 if isinstance(exc_val, BaseException):
---> 72 raise exc_val
74 self.add(graph.stack[-1], graph, *self.args, **self.kwargs)
76 if self.backend is not None:
Cell In[1], line 14
11 original_output = model.transformer.h[11].output[0].clone().save()
13 # we want to modify the hidden states for the last token
---> 14 model.transformer.h[11].output[0][:, number_of_tokens, :] = 0
16 modified_output = model.transformer.h[11].output[0].save()
18 print("\nOriginal Output: ", original_output[0][-1])
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/proxy.py:126, in Proxy.__setitem__(self, key, value)
125 def __setitem__(self, key: Union[Self, Any], value: Union[Self, Any]) -> None:
--> 126 self.node.create(
127 operator.setitem,
128 self.node,
129 key,
130 value,
131 )
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/node.py:250, in Node.create(self, *args, **kwargs)
247 return value
249 # Otherwise just create the Node on the Graph like normal.
--> 250 return self.graph.create(
251 *args,
252 **kwargs,
253 )
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/graph.py:131, in Graph.create(self, target, redirect, *args, **kwargs)
128 # Redirection.
129 graph = self.stack[-1] if redirect and self.stack else self
--> 131 return self.proxy_class(self.node_class(target, *args, graph=graph, **kwargs))
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/graph/node.py:118, in ValidatingInterventionNode.__init__(self, *args, **kwargs)
111 super().__init__(*args, **kwargs)
113 if (
114 self.attached
115 and self.fake_value is inspect._empty
116 and not Protocol.is_protocol(self.target)
117 ):
--> 118 self.fake_value = validate(self.target, *self.args, **self.kwargs)
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/graph/node.py:147, in validate(target, *args, **kwargs)
141 with FakeTensorMode(
142 allow_non_fake_inputs=True,
143 shape_env=ShapeEnv(assume_static_by_default=True),
144 ) as fake_mode:
145 with FakeCopyMode(fake_mode):
--> 147 with GlobalTracingContext.exit_global_tracing_context():
149 if backwards_check(target, *args):
150 return None
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/contexts/globals.py:100, in GlobalTracingContext.GlobalTracingExit.__exit__(self, exc_type, exc_val, traceback)
96 GlobalTracingContext.PATCHER.__enter__()
98 if isinstance(exc_val, BaseException):
--> 100 raise exc_val
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/graph/node.py:156, in validate(target, *args, **kwargs)
150 return None
152 args, kwargs = InterventionNode.prepare_inputs(
153 (args, kwargs), fake=True
154 )
--> 156 return target(
157 *args,
158 **kwargs,
159 )
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:2350, in FakeCopyMode.__torch_function__(self, func, types, args, kwargs)
2348 else:
2349 with torch._C.DisableTorchFunctionSubclass():
-> 2350 return func(*args, **kwargs)
IndexError: index 10 is out of bounds for dimension 1 with size 10
Ah, of course: we needed to index at number_of_tokens - 1, not number_of_tokens.
How was nnsight able to catch this error?

If scan and validate are enabled, our input is run through the model, but under its own “fake” context. This means the input makes its way through all of the model’s operations, allowing nnsight to record the shapes and data types of module inputs and outputs! The operations are never executed on tensors with real values, so this incurs no memory cost. Then, when creating proxy requests like the assignment above, nnsight also attempts to execute the request on the “fake” values it recorded.
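Under the hood, this relies on PyTorch’s fake-tensor machinery. Here is a minimal sketch of that mechanism in plain PyTorch, illustrating the idea rather than nnsight’s exact internals (the allow_non_fake_inputs flag mirrors what appears in the traceback above):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode(allow_non_fake_inputs=True):
    # a "fake" tensor carries shape and dtype but allocates no real storage
    hidden = torch.empty(1, 10, 768)
    print(hidden.shape)  # torch.Size([1, 10, 768])

    hidden[:, 9, :] = 0  # in bounds: passes the shape check
    try:
        hidden[:, 10, :] = 0  # out of bounds: caught without running the model
    except IndexError as e:
        print(e)  # index 10 is out of bounds for dimension 1 with size 10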
“Scanning” is what we call running “fake” inputs through the model to collect information like shapes and types. “Validating” is what we call trying to execute the intervention proxies with “fake” inputs to see if they work. Validating depends on scanning to work correctly, so the model must be scanned at least once before you can debug with validate.
A word of caution:

- Some PyTorch operations and related libraries don’t work well with fake tensors.
- If you are doing anything in a loop where efficiency is important, keep scanning and validating off (see the sketch below). It’s best to use them only when debugging or when you are unsure whether your intervention will work.
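In practice, this suggests debugging an intervention once with the checks on, then running your experiment loop with them off. A sketch of that pattern (prompts is a hypothetical list of inputs; the flags are simply omitted in the loop, since they are off unless enabled):

# debug the intervention once, with shape checks on
with model.trace(input, scan=True, validate=True):
    model.transformer.h[11].output[0][:, -1, :] = 0

# then run the hot loop with the checks off
prompts = ["The Louvre is in the city of", "Big Ben is in the city of"]  # hypothetical
for prompt in prompts:
    with model.trace(prompt):
        model.transformer.h[11].output[0][:, -1, :] = 0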
Let’s try again with the correct indexing, and view the shape of the output before leaving the tracing context:
[2]:
with model.trace(input, scan=True, validate=True):

    original_output = model.transformer.h[11].output[0].clone().save()

    # we want to modify the hidden states for the last token
    model.transformer.h[11].output[0][:, number_of_tokens - 1, :] = 0

    modified_output = model.transformer.h[11].output[0].save()

print("\nOriginal Output: ", original_output[0][-1])
print("Modified Output: ", modified_output[0][-1])
Original Output:  tensor([ 6.6286e+00,  1.7258e+00,  4.7969e+00, -3.8255e+00,  1.0698e+00,
         -1.4242e+00,  9.2749e+00,  6.0404e+00, -3.2988e+00, -2.7030e+00,
         ...
          7.6714e+00,  3.0682e+00,  2.0481e+00], device='mps:0',
       grad_fn=<SelectBackward0>)
Modified Output:  tensor([0., 0., 0.,  ..., 0., 0., 0.], device='mps:0',
       grad_fn=<SelectBackward0>)
We can also just replace proxy inputs and outputs with tensors of the same shape and type. Let’s use the shape information we have at our disposal to add noise to the output, and replace it with this new noised tensor:
[3]:
with model.scan(input):
    dim = model.transformer.h[11].output[0].shape[-1]

print(dim)
768
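With dim in hand, one way to finish the noise example might look like the following sketch (the 0.001 variance is an arbitrary choice, and torch must be imported):

import torch

with model.trace(input):
    # sample noise matching the hidden dimension and add it to the
    # last token's hidden state
    noise = (0.001 ** 0.5) * torch.randn(dim)
    model.transformer.h[11].output[0][:, -1, :] = (
        model.transformer.h[11].output[0][:, -1, :] + noise
    )
    noised_output = model.transformer.h[11].output[0].save()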