Scan and Validate

Have you ever changed tensor values in your intervention code, only to get an error message that isn’t very helpful?

This is where “Scanning” and “Validating” can help. As their names suggest, these features scan the shapes and types of tensors throughout the model and validate that your intervention operations are compatible with them.

We can enable these tools by setting the scan=True and validate=True flags in the trace method.

Here is an example that demonstrates how Scan and Validate can help us debug the model:

[1]:
from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto')

input = "The Eiffel Tower is in the city of"
number_of_tokens = len(model.tokenizer.encode(input))

# turn on scan and validate
with model.trace(input, scan=True, validate=True):

    original_output = model.transformer.h[11].output[0].clone().save()

    # we want to modify the hidden states for the last token
    model.transformer.h[11].output[0][:, number_of_tokens, :] = 0

    modified_output = model.transformer.h[11].output[0].save()

print("\nOriginal Output: ", original_output[0][-1])
print("Modified Output: ", modified_output[0][-1])
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[1], line 9
      6 number_of_tokens = len(model.tokenizer.encode(input))
      8 # turn on scan and validate
----> 9 with model.trace(input, scan=True, validate=True):
     11     original_output = model.transformer.h[11].output[0].clone().save()
     13     # we want to modify the hidden states for the last token

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/contexts/interleaving.py:96, in InterleavingTracer.__exit__(self, exc_type, exc_val, exc_tb)
     92     self.invoker.__exit__(None, None, None)
     94 self._model._envoy._reset()
---> 96 super().__exit__(exc_type, exc_val, exc_tb)

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/contexts/tracer.py:25, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
     21 from .globals import GlobalTracingContext
     23 GlobalTracingContext.try_deregister(self)
---> 25 return super().__exit__(exc_type, exc_val, exc_tb)

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/contexts/base.py:72, in Context.__exit__(self, exc_type, exc_val, exc_tb)
     69 graph = self.graph.stack.pop()
     71 if isinstance(exc_val, BaseException):
---> 72     raise exc_val
     74 self.add(graph.stack[-1], graph, *self.args, **self.kwargs)
     76 if self.backend is not None:

Cell In[1], line 14
     11     original_output = model.transformer.h[11].output[0].clone().save()
     13     # we want to modify the hidden states for the last token
---> 14     model.transformer.h[11].output[0][:, number_of_tokens, :] = 0
     16     modified_output = model.transformer.h[11].output[0].save()
     18 print("\nOriginal Output: ", original_output[0][-1])

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/proxy.py:126, in Proxy.__setitem__(self, key, value)
    125 def __setitem__(self, key: Union[Self, Any], value: Union[Self, Any]) -> None:
--> 126     self.node.create(
    127         operator.setitem,
    128         self.node,
    129         key,
    130         value,
    131     )

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/node.py:250, in Node.create(self, *args, **kwargs)
    247     return value
    249 # Otherwise just create the Node on the Graph like normal.
--> 250 return self.graph.create(
    251     *args,
    252     **kwargs,
    253 )

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/graph.py:131, in Graph.create(self, target, redirect, *args, **kwargs)
    128 # Redirection.
    129 graph = self.stack[-1] if redirect and self.stack else self
--> 131 return self.proxy_class(self.node_class(target, *args, graph=graph, **kwargs))

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/graph/node.py:118, in ValidatingInterventionNode.__init__(self, *args, **kwargs)
    111 super().__init__(*args, **kwargs)
    113 if (
    114     self.attached
    115     and self.fake_value is inspect._empty
    116     and not Protocol.is_protocol(self.target)
    117 ):
--> 118     self.fake_value = validate(self.target, *self.args, **self.kwargs)

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/graph/node.py:147, in validate(target, *args, **kwargs)
    141 with FakeTensorMode(
    142     allow_non_fake_inputs=True,
    143     shape_env=ShapeEnv(assume_static_by_default=True),
    144 ) as fake_mode:
    145     with FakeCopyMode(fake_mode):
--> 147         with GlobalTracingContext.exit_global_tracing_context():
    149             if backwards_check(target, *args):
    150                 return None

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/contexts/globals.py:100, in GlobalTracingContext.GlobalTracingExit.__exit__(self, exc_type, exc_val, traceback)
     96 GlobalTracingContext.PATCHER.__enter__()
     98 if isinstance(exc_val, BaseException):
--> 100     raise exc_val

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/graph/node.py:156, in validate(target, *args, **kwargs)
    150     return None
    152 args, kwargs = InterventionNode.prepare_inputs(
    153     (args, kwargs), fake=True
    154 )
--> 156 return target(
    157     *args,
    158     **kwargs,
    159 )

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:2350, in FakeCopyMode.__torch_function__(self, func, types, args, kwargs)
   2348 else:
   2349     with torch._C.DisableTorchFunctionSubclass():
-> 2350         return func(*args, **kwargs)

IndexError: index 10 is out of bounds for dimension 1 with size 10

Ah, of course: since indexing is zero-based, we needed to index at number_of_tokens - 1, not number_of_tokens.

How was nnsight able to catch this error?

If scan and validate are enabled, our input is run through the model, but within its own “fake” context. The input makes its way through all of the model’s operations, allowing nnsight to record the shapes and data types of module inputs and outputs. The operations are never executed on tensors with real values, so no real memory is consumed. Then, when creating proxy requests like the item assignment above, nnsight also attempts to execute the request on the “fake” values it recorded.
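
Under the hood, this relies on PyTorch’s fake tensor machinery, which you can see in the traceback above (FakeTensorMode). As a rough, illustrative sketch of the idea, here is fake execution used directly. Note that torch._subclasses.fake_tensor is an internal PyTorch API, so the exact import path may vary across versions:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Inside FakeTensorMode, tensors carry only metadata (shape, dtype, device).
with FakeTensorMode(allow_non_fake_inputs=True):
    hidden = torch.empty(1, 10, 768)           # no real memory is allocated
    logits = hidden @ torch.empty(768, 50257)  # shape inference, no real matmul
    print(logits.shape)                        # torch.Size([1, 10, 50257])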

“Scanning” is what we call running “fake” inputs through the model to collect information like shapes and types. “Validating” is what we call attempting to execute the intervention proxies on those “fake” values to see whether they work. Validating depends on scanning, so the model must be scanned at least once for validation to catch errors.

A word of caution

Some PyTorch operations, and some libraries built on top of them, do not work well with fake tensors.

If you are doing anything in a loop where efficiency is important, keep scanning and validating off. It is best to use them only while debugging, or when you are unsure whether your intervention will work.
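
For example, one convenient pattern is to gate the flags behind a debug switch, so exploratory runs get the shape checks while production loops stay fast. This is just a sketch reusing the trace flags shown above; the debug variable and prompts are illustrative:

debug = False  # flip to True while developing the intervention

prompts = ["The Eiffel Tower is in the city of", "The capital of France is"]

for prompt in prompts:
    # scan/validate add overhead, so keep them disabled unless debugging
    with model.trace(prompt, scan=debug, validate=debug):
        hidden = model.transformer.h[11].output[0].save()
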
Let’s try again with the correct indexing:

[2]:
with model.trace(input, scan=True, validate=True):

    original_output = model.transformer.h[11].output[0].clone().save()

    # we want to modify the hidden states for the last token
    model.transformer.h[11].output[0][:, number_of_tokens-1, :] = 0

    modified_output = model.transformer.h[11].output[0].save()

print("\nOriginal Output: ", original_output[0][-1])
print("Modified Output: ", modified_output[0][-1])

Original Output:  tensor([ 6.6286e+00,  1.7258e+00,  4.7969e+00, -3.8255e+00,  1.0698e+00,
        -1.4242e+00,  9.2749e+00,  6.0404e+00, -3.2988e+00, -2.7030e+00,
        -3.9210e-01, -5.5507e-01,  6.4831e+00,  1.4496e+00, -4.2496e-01,
        -9.4764e+00, -8.5587e-01,  4.8417e+00,  1.7383e+00, -1.9535e+01,
        -2.1625e+00, -5.4659e+00,  7.9303e-02, -1.2014e+00,  7.6165e-01,
         1.3293e+00, -8.1798e-01, -6.6870e+00,  2.9512e+00, -3.3648e+00,
        -3.5362e+00, -3.2539e+00,  3.4988e+00,  2.2232e+00,  3.4482e+00,
        -4.8883e+00,  6.5206e+01,  8.3218e-01, -3.5059e-01, -1.2041e+00,
        -1.8520e+00,  2.3440e+00,  3.0114e+00,  6.7492e+00,  4.4499e+00,
         6.5314e+00, -3.0311e+00,  4.3609e+00,  6.4800e-01,  1.6725e+00,
         3.0538e+00,  2.5054e+00, -1.9737e+00, -1.5169e+00, -3.9845e+00,
         3.4548e+00, -1.7004e+00, -2.8162e+00, -2.5650e-01, -1.8362e+00,
        -7.5023e+00,  1.9528e+00,  6.4437e-01,  8.7817e-01, -1.0992e+02,
         8.8575e+00,  2.1479e-01, -6.6564e+00, -3.2904e-01, -2.3263e-01,
        -4.9211e+00, -6.6106e-02, -3.6601e+00, -1.5331e+00, -1.2056e+00,
         5.6827e+00,  6.7709e+00, -5.1224e-01, -3.9829e+00,  2.9273e+00,
        -5.6971e+00,  1.6272e+00,  3.2709e+00, -2.0965e+00, -1.4081e+00,
        -2.4534e+00, -4.4642e+00, -1.0931e+02, -3.8111e+00,  2.1470e-01,
        -6.9690e-01, -1.1457e+00, -1.1235e+01,  5.2517e+00, -4.2227e+00,
        -3.2003e+00, -7.1090e+00,  1.8102e+00, -2.4567e+00,  7.7881e-01,
        -6.2834e+00,  3.7080e+00, -1.6301e+00,  6.9053e-01, -2.9357e-02,
        -5.7842e-01, -2.0679e+00, -2.0271e+00, -2.0579e+00, -5.7195e+00,
         5.1443e-01,  1.7420e+00,  1.3746e+00,  3.5129e+00,  8.9945e-01,
        -4.1596e-01, -1.5102e+00, -1.2280e+00,  3.4264e+00, -7.5586e+00,
         2.8480e+00,  6.8030e+00, -1.3625e+00,  1.1234e+01, -1.5630e+00,
         2.7383e+00,  3.4384e+00,  1.0834e+01,  5.3671e-01,  8.9106e-01,
        -6.9593e+00,  1.0443e+00, -3.2028e-01,  1.1285e+01,  3.6665e+00,
         3.1522e+00,  2.0780e+00, -1.5473e+00, -2.9861e+01, -3.5902e+00,
        -4.2159e+00,  7.0031e-03,  8.4291e+00, -3.5786e+00,  3.3005e+00,
         1.3246e+00,  2.4886e+00, -3.1515e+00,  2.5345e+00, -3.1293e+00,
         5.2794e+00, -6.1508e+00,  9.1936e+00, -1.7968e+00, -3.2526e+00,
        -1.0222e+00,  7.4691e+00, -3.1648e+00, -1.6389e+00, -9.7188e-01,
        -2.1339e+00, -5.8595e+00, -4.7614e+00, -4.3966e+00, -1.0889e+00,
        -1.5380e+00,  7.0410e+00, -1.5772e+01, -2.4344e-01,  3.2804e+00,
         5.1078e+00, -4.2193e+00, -2.4413e+00, -1.2236e-01, -9.5395e+00,
        -3.3000e+00,  4.7484e+00,  2.1003e+00,  2.5656e+00, -4.1450e+00,
         1.1324e+01, -9.2751e-01, -8.7090e-03, -1.3499e+00, -1.0883e+00,
         1.2036e+00,  6.4078e-01, -2.0958e+00, -9.4459e-01, -4.6134e+00,
        -4.8703e+00,  3.2674e+00, -3.2317e+00,  5.0362e+00, -3.1834e+00,
        -1.3516e+01,  7.0807e+00,  2.0336e+00,  2.0479e+00, -1.9521e+00,
        -8.9104e+00,  3.0803e+00, -3.2048e+00, -9.7706e-01,  3.8135e+00,
         2.4048e-01,  1.3258e+00, -7.1607e-01,  1.2787e+00, -9.8548e-02,
         5.1077e+00, -4.0518e+00, -2.6827e-03, -7.2934e-01, -5.4432e+00,
         3.5619e+00,  6.1031e+00, -7.2877e-01, -4.0819e+00,  2.9329e+00,
         3.8585e+00, -2.9784e+00,  1.1124e+00, -8.2287e+00,  2.7348e+00,
         6.0236e-01, -2.4054e+00, -3.5393e+00, -1.5170e+00, -7.5092e-01,
        -4.3856e+00,  5.0673e+00,  1.6784e+01,  3.3701e+00, -6.5000e+00,
         7.3039e+00,  7.5358e+00, -1.9126e+00,  2.1336e+00, -3.2421e+00,
        -4.4454e+00,  2.0309e+00,  1.3034e+00, -5.0879e+00,  4.4193e+00,
         1.1515e+01,  7.5886e-01,  4.5374e+00, -1.0041e+01,  2.4802e+00,
        -1.9640e+00,  8.6382e+00, -2.9520e-01, -2.5199e+00, -3.1697e+00,
        -4.1011e+00,  2.9947e+00,  2.5317e-01,  3.3526e+00, -8.4460e-01,
        -1.6096e+00,  4.6977e+00, -2.5488e+00, -3.9472e+00, -2.5825e+00,
         1.1430e+00, -9.7997e+01, -6.4164e+00, -1.7173e+00,  4.5707e+00,
         2.2899e+00, -4.0544e+00,  8.1279e+00,  4.0761e+00,  3.0572e+00,
         1.4155e+00,  3.1557e+00,  2.0821e+00,  1.4200e+00,  2.1832e-01,
         4.3272e+00, -3.4302e+00, -1.4085e+00,  2.8077e-01, -8.1994e-01,
        -6.6751e+00,  2.7346e+00,  8.0669e+00, -8.9311e-01,  5.0664e-01,
         2.6838e+00, -1.5756e+00,  1.8478e+00,  1.4335e+00, -3.1085e+00,
        -3.5366e+00, -2.3190e+00, -2.4223e+00,  4.6886e-01,  2.7125e+00,
        -4.6714e-01,  1.9403e+00,  1.6051e+00,  4.9405e+00,  5.9342e+00,
         6.0049e+00,  1.0645e+00,  3.4900e+00,  4.0834e+00,  7.7648e-01,
         2.0025e+00,  2.8585e+00,  4.6693e-01, -4.4440e-01, -5.1792e+01,
        -3.7743e+00,  1.0210e+00, -4.8536e-01, -3.3734e+00,  1.2150e+00,
         2.1420e+00,  6.3238e-01, -2.1061e-01, -1.8375e-01, -5.4864e-01,
         6.0235e+00,  8.7014e+00, -7.0699e+00,  3.2431e+00,  1.7484e+00,
        -7.2326e+00,  3.2734e+00,  3.1882e+00,  6.0536e+00,  9.2431e-01,
        -5.6773e+00, -4.9881e+00,  1.4860e+00, -3.0162e+00, -3.5636e+00,
        -8.4084e-01, -1.0074e+00, -1.6067e+01, -1.3144e+01,  1.1817e-01,
         6.4546e+00, -1.0912e+00,  1.1206e+01, -1.3724e+00,  2.7155e+00,
        -7.4177e+00, -1.2955e+00, -1.6902e+00,  9.5223e+00,  2.4952e+00,
        -3.7531e+00,  5.1116e+00,  3.8512e+00,  3.7123e+00,  7.6354e-01,
        -3.8020e+00,  1.8983e+00,  4.3574e+01, -7.3467e+00, -2.8905e+00,
        -4.3620e+00,  3.8599e+00,  1.2109e+00,  2.5844e+00, -4.4734e+00,
        -2.9198e+00,  2.9884e+00, -1.7416e+00, -1.6009e+02, -1.6512e+01,
        -9.0769e-02,  6.8859e+00, -3.4945e+00,  6.9834e+00,  6.2670e+00,
         5.6965e-01,  3.0963e-02,  2.9006e+00,  8.9759e-01, -2.3006e+00,
         8.3767e-01, -2.9463e+00, -4.3662e+00, -2.5841e+00,  9.1415e+00,
        -5.5953e+00, -4.7413e+00, -3.7179e+00,  2.1987e+01,  1.8265e+00,
        -8.9532e-01,  4.1599e+00, -3.9978e+00, -3.8641e+00, -2.0673e+00,
         3.0306e+00, -2.1554e+00, -2.0086e+00,  4.8663e+00,  2.8927e-01,
        -4.1999e+00,  8.5062e+00,  4.2103e+00,  1.7206e+01, -3.5376e+00,
         2.5672e-01,  2.2229e-01,  6.6561e+00, -4.8118e+00, -3.6561e+00,
        -6.8825e-01, -4.4886e+00, -9.2805e+00, -3.8699e-01, -1.3017e+00,
         2.3756e+00,  4.2634e+00, -3.6821e+00, -6.7862e-01, -3.6553e+00,
        -3.7938e+00, -3.6224e+00,  1.7574e+00,  3.0822e+00, -3.9100e-01,
         7.3470e+01, -5.4397e-01,  9.8533e-01, -3.3989e+00, -1.3025e+00,
        -2.0689e+00,  2.5526e+00, -2.0326e+00,  3.0746e+00, -2.0950e+00,
         4.1649e+00,  6.4935e-01, -7.1394e+00, -1.0906e+00, -7.1690e+00,
         3.2226e+00, -1.2760e+00,  1.4431e+02,  6.0973e+00, -1.3528e+00,
        -4.9426e+00, -8.5969e-01, -8.5405e+00,  5.7662e+00, -2.4922e+00,
        -7.4357e+00, -7.8989e+00,  4.2059e+00, -3.9071e+00, -2.3410e+00,
        -1.0605e+00, -1.5846e+00, -6.7736e+00,  5.0952e-01,  1.5078e+01,
        -3.7115e-02,  3.4718e+00, -2.8174e+00, -3.7045e+00,  1.0058e+01,
        -2.1485e+00,  3.8376e+00, -8.5942e-01,  1.8444e+00,  3.8978e+00,
         3.0776e+00, -1.0642e+00,  4.3514e+00,  2.2874e+00,  5.2097e+00,
        -9.6646e+01,  1.6328e+02, -1.2562e-01,  1.4275e+00,  1.6121e+00,
        -8.4209e+00,  1.1070e+00, -4.9542e+00,  1.7836e+00, -1.0035e+01,
        -2.0094e+00,  2.7393e+00,  5.4365e-01, -7.8556e-01, -3.2903e-01,
         5.5690e+00,  1.6298e+02,  1.2884e+00,  6.0525e+00,  2.3517e+00,
        -5.1991e+00, -8.9318e+00,  3.2929e+00, -1.2494e+01,  4.4547e+00,
        -4.6448e-01,  1.0015e+00,  3.5933e-01,  2.1539e+00,  8.7484e-01,
        -1.9571e+00,  4.6167e+00, -4.1041e+00,  4.1621e+00, -3.2981e+00,
         5.2514e-02,  6.9099e+00,  1.9714e-01,  8.3853e+00,  1.2267e+00,
        -3.2992e+00,  4.8448e+00,  2.9776e+00, -2.0985e-01,  1.9363e+00,
        -3.0710e+00, -4.9856e+00,  9.9332e-01,  1.8208e+00,  6.0879e+00,
         3.6372e+00, -2.4278e-01,  2.3095e+00, -1.1847e+00,  5.4024e+00,
         4.2026e+00,  2.0059e-01, -6.7174e+00, -6.0502e+00, -7.2487e+00,
         4.4279e+00, -2.2406e+00,  5.9508e+00,  6.9944e+00,  1.2061e+00,
        -3.2928e+00, -1.5006e+00,  2.5605e+00, -1.6327e+00,  6.6576e+00,
         4.1315e+00, -2.5557e+00,  2.7006e+00,  6.6053e+00, -1.9167e+00,
        -1.5669e+00, -2.4749e+00,  4.2183e-01,  2.4309e+00,  1.5208e-01,
        -9.1254e+00, -3.7077e+00,  4.5885e+00, -1.9774e+00,  1.3449e+01,
         3.4506e+00, -7.2922e+00, -5.2382e-01,  1.1032e+00,  2.7615e+00,
        -6.3168e+00, -3.8679e-02,  6.5301e+00, -3.0306e-01,  1.4317e+00,
        -3.4285e+00, -4.6842e+00, -3.0345e+00,  5.0907e+00, -2.3043e+00,
         4.8652e+00, -3.4368e+00, -8.0064e+00, -6.0210e+00,  5.9621e+00,
         1.6881e+00, -1.6543e+00, -6.9820e+00,  7.2362e-01, -7.4828e+00,
        -1.5182e+00, -1.0165e+00, -4.8026e+00, -2.3627e+00,  2.1291e+00,
        -6.0510e+00, -4.2262e+00, -2.6994e+00,  2.1767e-01, -3.2933e+00,
         6.8151e+00, -2.8563e+00, -1.1988e+00,  1.6883e+00,  1.8934e+00,
         9.4241e+00,  9.5912e-01,  2.1459e+00,  1.3060e+01, -4.9042e+00,
        -5.1736e+00,  8.4661e+00,  7.1147e-01, -1.0368e+01,  3.1779e+00,
         1.8113e+00,  7.0269e+00,  2.9996e+00, -2.8506e+00, -7.3495e-01,
        -8.3824e+00, -2.6062e+00,  5.2471e+00,  6.6041e+00, -8.8814e-01,
         1.3165e-01,  7.5560e-01,  6.3661e+00,  6.7891e-01, -1.6838e+00,
         3.9716e-02,  4.6225e-01, -5.2465e+00,  5.9410e+00, -4.3000e+00,
        -4.3974e+00, -6.8152e+00, -4.7801e+00, -4.4493e+00, -4.0902e-01,
         1.5374e+01,  2.9918e+00,  1.8610e+00, -3.2574e+00, -3.8129e+00,
        -5.0761e+00,  9.2315e+00, -1.0814e+00,  2.8385e-01, -3.2490e+00,
         3.4800e+00, -5.8052e+00,  1.3061e+00, -2.8434e+00,  6.5940e-02,
         3.2905e-02, -1.9720e+00, -3.3967e+00, -7.4260e+00,  6.1623e+00,
        -4.1470e+00, -4.2324e+00,  9.5378e-01, -1.0738e+00, -1.3569e-01,
        -2.2882e+00,  4.0768e+00,  2.7376e-01, -1.5838e+00,  1.1146e+01,
         6.9866e+00,  1.3019e+01, -1.8275e+00, -6.4068e+00,  1.7844e+00,
        -8.6919e-01, -2.9575e+00,  4.8201e-01, -6.8251e+00, -2.5988e+00,
        -8.7482e+00,  8.0484e-01,  2.6259e+00, -4.5771e+00, -5.8153e-01,
         4.7844e+00, -3.3178e+00,  9.6059e+00,  4.9834e+00,  5.5188e-01,
        -3.0826e-01, -2.0525e+00, -4.3834e+00,  5.7229e+00,  3.6664e+00,
         4.6417e-02,  1.0723e+00, -5.3191e+00,  1.4131e+00, -9.6074e-01,
        -3.2504e+00, -2.3741e+00,  2.7716e+00, -1.4227e+00, -9.3925e-01,
        -7.5700e+00, -3.7153e+00,  1.3690e+00,  1.1079e+01, -8.3346e+00,
        -4.9114e+00, -2.0677e+00, -1.0595e+00, -1.8981e+00,  6.1321e+00,
         1.9002e+00, -1.9220e-02, -1.1126e+00, -1.0669e+01, -3.0803e+00,
        -3.5232e+00,  1.5203e+00, -4.8918e+00,  2.8023e+00,  7.7897e+00,
         3.6398e+00, -6.3831e-01, -5.3203e+00,  1.7896e+00, -4.1591e+00,
        -4.2744e-01,  3.0674e+00,  4.1897e+00, -2.5547e+00,  3.0816e-01,
        -3.4222e-02, -3.2085e+00,  3.0132e+00, -5.1880e+00,  8.2855e-01,
        -9.3510e-01, -2.3891e-01,  3.3661e+00,  2.3398e+00, -7.5823e+00,
        -1.1914e+01,  5.7218e+00, -2.5562e+00, -2.9274e+00,  3.1782e+00,
        -1.9518e+00,  4.3836e+00,  3.5140e+00, -2.8808e+00, -3.9965e+00,
         1.6073e+00, -5.4745e+00, -1.4041e+00, -2.8542e+00, -2.0609e+00,
         8.0075e-02, -2.6370e+00, -4.4448e+00, -6.0635e+00, -4.2056e+00,
         7.6714e+00,  3.0682e+00,  2.0481e+00], device='mps:0',
       grad_fn=<SelectBackward0>)
Modified Output:  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='mps:0', grad_fn=<SelectBackward0>)

We can also replace proxy inputs and outputs entirely with tensors of the same shape and type. The scan method runs just the scanning pass, so we can look up shapes without executing the model for real. Let’s grab the hidden dimension of the layer output; we will then use it to add noise to the output, as shown in the sketch after this cell:

[3]:
with model.scan(input):

    dim = model.transformer.h[11].output[0].shape[-1]

print(dim)
768
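
With the hidden dimension in hand, we can build a noise tensor of the right size and add it to the layer output, as promised above. Here is a sketch, assuming the model, input, and dim from the cells above; the 0.001 variance and the device handling are illustrative choices, not part of the original example:

import torch

with model.trace(input):
    hidden = model.transformer.h[11].output[0]

    # Create noise matching the scanned hidden dimension, on the same device
    # as the hidden states so the addition is valid.
    noise = (0.001 ** 0.5) * torch.randn(dim, device=hidden.device)

    model.transformer.h[11].output[0][:] = hidden + noise
    noised_output = model.transformer.h[11].output[0].save()

print(noised_output.shape)  # expect torch.Size([1, 10, 768]) for our prompt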