Scan and Validate
Have you ever changed tensor values in your intervention code, only to hit an error message that is not very helpful? This is where “scanning” and “validating” can help. As the names suggest, these features scan the shapes of the tensors throughout the model and validate that your intervention's tensor shapes are compatible with it.
We can enable these helpful tools by setting the `scan=True` and `validate=True` flags in the `trace` method.
Here is an example that demonstrates how Scan and Validate can help us debug the model:
[1]:
from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto')

input = "The Eiffel Tower is in the city of"
number_of_tokens = len(model.tokenizer.encode(input))

# turn on scan and validate
with model.trace(input, scan=True, validate=True):

    original_output = model.transformer.h[11].output[0].clone().save()

    # we want to modify the hidden states for the last token
    model.transformer.h[11].output[0][:, number_of_tokens, :] = 0

    modified_output = model.transformer.h[11].output[0].save()

print("\nOriginal Output: ", original_output[0][-1])
print("Modified Output: ", modified_output[0][-1])
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[1], line 9
6 number_of_tokens = len(model.tokenizer.encode(input))
8 # turn on scan and validate
----> 9 with model.trace(input, scan=True, validate=True):
11 original_output = model.transformer.h[11].output[0].clone().save()
13 # we want to modify the hidden states for the last token
File ~/Projects/NDIF/nnsight/src/nnsight/contexts/Tracer.py:103, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
98 self.invoker.__exit__(None, None, None)
100 self.model._envoy._reset()
--> 103 super().__exit__(exc_type, exc_val, exc_tb)
File ~/Projects/NDIF/nnsight/src/nnsight/contexts/GraphBasedContext.py:216, in GraphBasedContext.__exit__(self, exc_type, exc_val, exc_tb)
214 self.graph.alive = False
215 self.graph = None
--> 216 raise exc_val
218 self.backend(self)
Cell In[1], line 14
11 original_output = model.transformer.h[11].output[0].clone().save()
13 # we want to modify the hidden states for the last token
---> 14 model.transformer.h[11].output[0][:, number_of_tokens, :] = 0
16 modified_output = model.transformer.h[11].output[0].save()
18 print("\nOriginal Output: ", original_output[0][-1])
File ~/Projects/NDIF/nnsight/src/nnsight/tracing/Proxy.py:89, in Proxy.__setitem__(self, key, value)
88 def __setitem__(self, key: Union[Proxy, Any], value: Union[Self, Any]) -> None:
---> 89 self.node.create(
90 target=operator.setitem,
91 args=[self.node, key, value],
92 )
File ~/Projects/NDIF/nnsight/src/nnsight/tracing/Node.py:270, in Node.create(self, target, proxy_value, args, kwargs, name)
267 return value
269 # Otherwise just create the Node on the Graph like normal.
--> 270 return self.graph.create(
271 target=target,
272 name=name,
273 proxy_value=proxy_value,
274 args=args,
275 kwargs=kwargs,
276 )
File ~/Projects/NDIF/nnsight/src/nnsight/tracing/Graph.py:113, in Graph.create(self, *args, **kwargs)
106 def create(self, *args, **kwargs) -> Proxy:
107 """Creates a Node directly on this `Graph` and returns its `Proxy`.
108
109 Returns:
110 Proxy: `Proxy` for newly created `Node`.
111 """
--> 113 return self.proxy_class(Node(*args, graph=self, **kwargs))
File ~/Projects/NDIF/nnsight/src/nnsight/tracing/Node.py:96, in Node.__init__(self, target, graph, proxy_value, args, kwargs, name)
93 # If theres an alive Graph, add it.
94 if self.attached():
---> 96 self.graph.add(self)
File ~/Projects/NDIF/nnsight/src/nnsight/tracing/Graph.py:131, in Graph.add(self, node)
128 # If we're validating and the user did not provide a proxy_value, execute the given target with meta proxy values to compute new proxy_value.
129 if self.validate and node.proxy_value is inspect._empty:
--> 131 node.proxy_value = validate(node.target, *node.args, **node.kwargs)
133 # Get name of target.
134 name = (
135 node.target
136 if isinstance(node.target, str)
137 else node.target.__name__
138 )
File ~/Projects/NDIF/nnsight/src/nnsight/tracing/util.py:20, in validate(target, *args, **kwargs)
14 with FakeTensorMode(
15 allow_non_fake_inputs=True,
16 shape_env=ShapeEnv(assume_static_by_default=True),
17 ) as fake_mode:
18 with FakeCopyMode(fake_mode):
---> 20 with GlobalTracingContext.exit_global_tracing_context():
22 args, kwargs = Node.prepare_inputs((args, kwargs), proxy=True)
24 return target(
25 *args,
26 **kwargs,
27 )
File ~/Projects/NDIF/nnsight/src/nnsight/contexts/GraphBasedContext.py:330, in GlobalTracingContext.GlobalTracingExit.__exit__(self, exc_type, exc_val, traceback)
326 GlobalTracingContext.PATCHER.__enter__()
328 if isinstance(exc_val, BaseException):
--> 330 raise exc_val
File ~/Projects/NDIF/nnsight/src/nnsight/tracing/util.py:24, in validate(target, *args, **kwargs)
20 with GlobalTracingContext.exit_global_tracing_context():
22 args, kwargs = Node.prepare_inputs((args, kwargs), proxy=True)
---> 24 return target(
25 *args,
26 **kwargs,
27 )
File ~/Projects/NDIF/nnsight/src/nnsight/contexts/GraphBasedContext.py:312, in GlobalTracingContext.GlobalTracingTorchHandler.__torch_function__(self, func, types, args, kwargs)
305 if "_VariableFunctionsClass" in func.__qualname__:
306 return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(
307 func,
308 *args,
309 **kwargs
310 )
--> 312 return func(*args, **kwargs)
IndexError: index 10 is out of bounds for dimension 1 with size 10
Ah, of course: we needed to index at `number_of_tokens - 1`, not `number_of_tokens`.
How was `nnsight` able to catch this error?
If `scan` and `validate` are enabled, our input is run through the model, but under its own “fake” context. This means the input makes its way through all of the model operations, allowing `nnsight` to record the shapes and data types of module inputs and outputs! The operations are never executed on tensors with real values, so this doesn't incur any memory costs. Then, when creating proxy requests like the setitem one above, `nnsight` also attempts to execute the request on the “fake” values we recorded.
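Under the hood, this builds on PyTorch's fake-tensor machinery, which you can see in the `FakeTensorMode` frames of the traceback above. As a rough illustration of the idea (a minimal sketch, not nnsight's exact internals), shapes and dtypes propagate through operations without any real values being allocated:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Inside FakeTensorMode, factory calls produce "fake" tensors: metadata only,
# so even very large shapes cost essentially no memory.
with FakeTensorMode(allow_non_fake_inputs=True):
    hidden = torch.empty(1, 10, 768)           # no real values allocated
    logits = hidden @ torch.empty(768, 50257)  # shape math still runs
    print(logits.shape)                        # torch.Size([1, 10, 50257])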
“Scanning” is what we call running “fake” inputs through the model to collect information like shapes and types. “Validating” is what we call trying to execute the intervention proxies on those “fake” inputs to see if they work. Validating depends on scanning, so the model must be scanned at least once before validation can be used for debugging.
A word of caution:

- Some PyTorch operations (and related libraries) don't work well with fake tensors.
- If you are doing anything in a loop where efficiency is important, keep scanning and validating off. It's best to use them only when debugging, or when you are unsure whether your intervention will work.
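For example, here is a sketch of the loop case (the prompts are arbitrary placeholders): once your intervention is debugged, pass `scan=False` and `validate=False` explicitly so each trace skips the extra fake-tensor pass.

prompts = ["The Eiffel Tower is in the city of", "Rome is the capital of"]

hidden_states = []
for prompt in prompts:
    # debugging is done, so skip the fake-tensor passes on every iteration
    with model.trace(prompt, scan=False, validate=False):
        hidden_states.append(model.transformer.h[11].output[0].save())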
Let's try again with the correct indexing:
[2]:
with model.trace(input, scan=True, validate=True):

    original_output = model.transformer.h[11].output[0].clone().save()

    # we want to modify the hidden states for the last token
    model.transformer.h[11].output[0][:, number_of_tokens - 1, :] = 0

    modified_output = model.transformer.h[11].output[0].save()

print("\nOriginal Output: ", original_output[0][-1])
print("Modified Output: ", modified_output[0][-1])
Original Output:  tensor([ 6.6286e+00,  1.7258e+00,  4.7969e+00, -3.8255e+00,  1.0698e+00,
        ...,
         7.6714e+00,  3.0683e+00,  2.0481e+00], device='mps:0',
       grad_fn=<SelectBackward0>)
Modified Output:  tensor([0., 0., 0.,  ..., 0., 0., 0.], device='mps:0',
       grad_fn=<SelectBackward0>)
We can also replace proxy inputs and outputs entirely with tensors of the same shape and type. Let's use the shape information we have at our disposal to add noise to the output and replace it with this new noised tensor. First, we grab the hidden dimension with a standalone scan:
[3]:
with model.scan(input):
    dim = model.transformer.h[11].output[0].shape[-1]

print(dim)
768
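With `dim` in hand, the noising step might look like the following sketch. The 0.001 variance is an arbitrary choice, and the in-place `[:]` assignment mirrors the setitem pattern used earlier:

import torch

with model.trace(input, scan=True, validate=True):
    # noise matching the hidden dimension; it broadcasts over the
    # batch and sequence dimensions of the layer output
    noise = (0.001 ** 0.5) * torch.randn(dim)
    model.transformer.h[11].output[0][:] = model.transformer.h[11].output[0] + noise
    noised_output = model.transformer.h[11].output[0].save()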