Scan and Validate

Have you ever changed tensor values in your intervention code, only to be met with an error message that isn’t very helpful?

This is where “Scanning” and “Validating” can help. As the names suggest, these features scan the shapes of tensors throughout the model and validate that your intervention operations are compatible with those shapes.

We can enable these helpful tools by setting the scan=True and validate=True flags in the trace method.

Here is an example that demonstrates how Scan and Validate can help us debug the model:

[1]:
from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto')

input = "The Eiffel Tower is in the city of"
number_of_tokens = len(model.tokenizer.encode(input))

# turn on scan and validate
with model.trace(input, scan=True, validate=True):

    original_output = model.transformer.h[11].output[0].clone().save()

    # we want to modify the hidden states for the last token
    model.transformer.h[11].output[0][:, number_of_tokens, :] = 0

    modified_output = model.transformer.h[11].output[0].save()

print("\nOriginal Output: ", original_output[0][-1])
print("Modified Output: ", modified_output[0][-1])
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[1], line 9
      6 number_of_tokens = len(model.tokenizer.encode(input))
      8 # turn on scan and validate
----> 9 with model.trace(input, scan=True, validate=True):
     11     original_output = model.transformer.h[11].output[0].clone().save()
     13     # we want to modify the hidden states for the last token

File ~/Projects/NDIF/nnsight/src/nnsight/contexts/Tracer.py:103, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
     98     self.invoker.__exit__(None, None, None)
    100 self.model._envoy._reset()
--> 103 super().__exit__(exc_type, exc_val, exc_tb)

File ~/Projects/NDIF/nnsight/src/nnsight/contexts/GraphBasedContext.py:216, in GraphBasedContext.__exit__(self, exc_type, exc_val, exc_tb)
    214     self.graph.alive = False
    215     self.graph = None
--> 216     raise exc_val
    218 self.backend(self)

Cell In[1], line 14
     11     original_output = model.transformer.h[11].output[0].clone().save()
     13     # we want to modify the hidden states for the last token
---> 14     model.transformer.h[11].output[0][:, number_of_tokens, :] = 0
     16     modified_output = model.transformer.h[11].output[0].save()
     18 print("\nOriginal Output: ", original_output[0][-1])

File ~/Projects/NDIF/nnsight/src/nnsight/tracing/Proxy.py:89, in Proxy.__setitem__(self, key, value)
     88 def __setitem__(self, key: Union[Proxy, Any], value: Union[Self, Any]) -> None:
---> 89     self.node.create(
     90         target=operator.setitem,
     91         args=[self.node, key, value],
     92     )

File ~/Projects/NDIF/nnsight/src/nnsight/tracing/Node.py:270, in Node.create(self, target, proxy_value, args, kwargs, name)
    267     return value
    269 # Otherwise just create the Node on the Graph like normal.
--> 270 return self.graph.create(
    271     target=target,
    272     name=name,
    273     proxy_value=proxy_value,
    274     args=args,
    275     kwargs=kwargs,
    276 )

File ~/Projects/NDIF/nnsight/src/nnsight/tracing/Graph.py:113, in Graph.create(self, *args, **kwargs)
    106 def create(self, *args, **kwargs) -> Proxy:
    107     """Creates a Node directly on this `Graph` and returns its `Proxy`.
    108
    109     Returns:
    110         Proxy: `Proxy` for newly created `Node`.
    111     """
--> 113     return self.proxy_class(Node(*args, graph=self, **kwargs))

File ~/Projects/NDIF/nnsight/src/nnsight/tracing/Node.py:96, in Node.__init__(self, target, graph, proxy_value, args, kwargs, name)
     93 # If theres an alive Graph, add it.
     94 if self.attached():
---> 96     self.graph.add(self)

File ~/Projects/NDIF/nnsight/src/nnsight/tracing/Graph.py:131, in Graph.add(self, node)
    128 # If we're validating and the user did not provide a proxy_value, execute the given target with meta proxy values to compute new proxy_value.
    129 if self.validate and node.proxy_value is inspect._empty:
--> 131     node.proxy_value = validate(node.target, *node.args, **node.kwargs)
    133 # Get name of target.
    134 name = (
    135     node.target
    136     if isinstance(node.target, str)
    137     else node.target.__name__
    138 )

File ~/Projects/NDIF/nnsight/src/nnsight/tracing/util.py:20, in validate(target, *args, **kwargs)
     14 with FakeTensorMode(
     15     allow_non_fake_inputs=True,
     16     shape_env=ShapeEnv(assume_static_by_default=True),
     17 ) as fake_mode:
     18     with FakeCopyMode(fake_mode):
---> 20         with GlobalTracingContext.exit_global_tracing_context():
     22             args, kwargs = Node.prepare_inputs((args, kwargs), proxy=True)
     24             return target(
     25                 *args,
     26                 **kwargs,
     27             )

File ~/Projects/NDIF/nnsight/src/nnsight/contexts/GraphBasedContext.py:330, in GlobalTracingContext.GlobalTracingExit.__exit__(self, exc_type, exc_val, traceback)
    326 GlobalTracingContext.PATCHER.__enter__()
    328 if isinstance(exc_val, BaseException):
--> 330     raise exc_val

File ~/Projects/NDIF/nnsight/src/nnsight/tracing/util.py:24, in validate(target, *args, **kwargs)
     20 with GlobalTracingContext.exit_global_tracing_context():
     22     args, kwargs = Node.prepare_inputs((args, kwargs), proxy=True)
---> 24     return target(
     25         *args,
     26         **kwargs,
     27     )

File ~/Projects/NDIF/nnsight/src/nnsight/contexts/GraphBasedContext.py:312, in GlobalTracingContext.GlobalTracingTorchHandler.__torch_function__(self, func, types, args, kwargs)
    305 if "_VariableFunctionsClass" in func.__qualname__:
    306     return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(
    307         func,
    308         *args,
    309         **kwargs
    310     )
--> 312 return func(*args, **kwargs)

IndexError: index 10 is out of bounds for dimension 1 with size 10

Ah, of course: since indexing is zero-based, we needed to index at number_of_tokens - 1, not number_of_tokens.

How was nnsight able to catch this error?

If scan and validate are enabled, our input is first run through the model under its own “fake” context. This means the input makes its way through all of the model operations, allowing nnsight to record the shapes and data types of module inputs and outputs! The operations are never executed using tensors with real values, so this pass doesn’t incur any memory costs. Then, when building intervention proxies like the item assignment above, nnsight also attempts to execute the request on the “fake” values we recorded.
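
Under the hood, this relies on PyTorch’s fake tensors (you can see FakeTensorMode in the traceback above). Here is a minimal sketch of the idea, using FakeTensorMode directly; nnsight wraps this for you, and the shapes below are illustrative:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Inside FakeTensorMode, tensor factory functions produce "fake" tensors:
# shapes and dtypes propagate through operations, but no real values are
# ever allocated, so there is no memory cost.
with FakeTensorMode(allow_non_fake_inputs=True):
    x = torch.empty(1, 10, 768)       # no real data is materialized
    y = x @ torch.empty(768, 768)     # shape arithmetic still runs
    print(y.shape)                    # torch.Size([1, 10, 768])
    # an out-of-bounds index like x[:, 10, :] raises an IndexError here
    # too, just like in the example above, without running the model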

“Scanning” is what we call running “fake” inputs through the model to collect information like shapes and types. “Validating” is what we call trying to execute the intervention proxies with those “fake” values to see if they work. Validating depends on scanning to work correctly, so we need to run a scan of the model at least once before we can debug with validate.

A word of caution:

Some PyTorch operations (and operations from related libraries) don’t work well with fake tensors, so scanning and validating can occasionally fail even when the real execution would succeed.

If you are doing anything in a loop where efficiency is important, keep scanning and validating off, as in the sketch below. It’s best to use them only when debugging or when you are unsure whether your intervention will work.
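
For example, you might debug the intervention once with the flags on, then run the sweep with them off. A minimal sketch, assuming scan and validate default to False in trace; the second prompt is illustrative:

# Debug the intervention once with scanning and validating on...
with model.trace(input, scan=True, validate=True):
    model.transformer.h[11].output[0][:, -1, :] = 0

# ...then run the efficiency-sensitive loop with both flags off
# (their defaults), avoiding the extra "fake" pass on every iteration.
prompts = [
    "The Eiffel Tower is in the city of",
    "The Colosseum is in the city of",
]
hidden_states = []
for prompt in prompts:
    with model.trace(prompt):
        model.transformer.h[11].output[0][:, -1, :] = 0
        hidden_states.append(model.transformer.h[11].output[0].save())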


Let’s try again with the correct indexing:

[2]:
with model.trace(input, scan=True, validate=True):

    original_output = model.transformer.h[11].output[0].clone().save()

    # we want to modify the hidden states for the last token
    model.transformer.h[11].output[0][:, number_of_tokens-1, :] = 0

    modified_output = model.transformer.h[11].output[0].save()

print("\nOriginal Output: ", original_output[0][-1])
print("Modified Output: ", modified_output[0][-1])

Original Output:  tensor([ 6.6286e+00,  1.7258e+00,  4.7969e+00, -3.8255e+00,  1.0698e+00,
        -1.4242e+00,  9.2749e+00,  6.0404e+00, -3.2988e+00, -2.7030e+00,
        -3.9210e-01, -5.5507e-01,  6.4831e+00,  1.4496e+00, -4.2496e-01,
        -9.4764e+00, -8.5587e-01,  4.8417e+00,  1.7383e+00, -1.9535e+01,
        -2.1625e+00, -5.4659e+00,  7.9305e-02, -1.2014e+00,  7.6166e-01,
         1.3293e+00, -8.1797e-01, -6.6870e+00,  2.9511e+00, -3.3648e+00,
        -3.5362e+00, -3.2539e+00,  3.4988e+00,  2.2232e+00,  3.4482e+00,
        -4.8883e+00,  6.5206e+01,  8.3218e-01, -3.5060e-01, -1.2041e+00,
        -1.8520e+00,  2.3440e+00,  3.0114e+00,  6.7492e+00,  4.4499e+00,
         6.5314e+00, -3.0311e+00,  4.3609e+00,  6.4801e-01,  1.6725e+00,
         3.0538e+00,  2.5054e+00, -1.9737e+00, -1.5169e+00, -3.9845e+00,
         3.4548e+00, -1.7004e+00, -2.8162e+00, -2.5651e-01, -1.8362e+00,
        -7.5023e+00,  1.9528e+00,  6.4438e-01,  8.7818e-01, -1.0992e+02,
         8.8575e+00,  2.1478e-01, -6.6564e+00, -3.2905e-01, -2.3264e-01,
        -4.9211e+00, -6.6120e-02, -3.6601e+00, -1.5331e+00, -1.2056e+00,
         5.6827e+00,  6.7709e+00, -5.1225e-01, -3.9829e+00,  2.9273e+00,
        -5.6971e+00,  1.6272e+00,  3.2709e+00, -2.0965e+00, -1.4081e+00,
        -2.4534e+00, -4.4642e+00, -1.0931e+02, -3.8111e+00,  2.1471e-01,
        -6.9692e-01, -1.1457e+00, -1.1235e+01,  5.2517e+00, -4.2227e+00,
        -3.2003e+00, -7.1090e+00,  1.8102e+00, -2.4567e+00,  7.7879e-01,
        -6.2834e+00,  3.7080e+00, -1.6301e+00,  6.9053e-01, -2.9357e-02,
        -5.7841e-01, -2.0679e+00, -2.0271e+00, -2.0579e+00, -5.7195e+00,
         5.1443e-01,  1.7420e+00,  1.3746e+00,  3.5129e+00,  8.9945e-01,
        -4.1595e-01, -1.5102e+00, -1.2280e+00,  3.4264e+00, -7.5586e+00,
         2.8480e+00,  6.8030e+00, -1.3625e+00,  1.1234e+01, -1.5630e+00,
         2.7383e+00,  3.4384e+00,  1.0834e+01,  5.3671e-01,  8.9106e-01,
        -6.9593e+00,  1.0443e+00, -3.2028e-01,  1.1285e+01,  3.6665e+00,
         3.1522e+00,  2.0780e+00, -1.5473e+00, -2.9861e+01, -3.5902e+00,
        -4.2159e+00,  7.0041e-03,  8.4291e+00, -3.5786e+00,  3.3004e+00,
         1.3246e+00,  2.4886e+00, -3.1515e+00,  2.5345e+00, -3.1293e+00,
         5.2794e+00, -6.1508e+00,  9.1936e+00, -1.7968e+00, -3.2526e+00,
        -1.0222e+00,  7.4691e+00, -3.1648e+00, -1.6389e+00, -9.7188e-01,
        -2.1339e+00, -5.8595e+00, -4.7614e+00, -4.3966e+00, -1.0889e+00,
        -1.5380e+00,  7.0410e+00, -1.5772e+01, -2.4345e-01,  3.2805e+00,
         5.1078e+00, -4.2193e+00, -2.4413e+00, -1.2237e-01, -9.5395e+00,
        -3.3000e+00,  4.7484e+00,  2.1002e+00,  2.5656e+00, -4.1450e+00,
         1.1324e+01, -9.2751e-01, -8.7061e-03, -1.3499e+00, -1.0883e+00,
         1.2036e+00,  6.4077e-01, -2.0958e+00, -9.4460e-01, -4.6134e+00,
        -4.8703e+00,  3.2674e+00, -3.2317e+00,  5.0362e+00, -3.1834e+00,
        -1.3516e+01,  7.0807e+00,  2.0336e+00,  2.0479e+00, -1.9521e+00,
        -8.9104e+00,  3.0803e+00, -3.2048e+00, -9.7705e-01,  3.8135e+00,
         2.4048e-01,  1.3258e+00, -7.1608e-01,  1.2787e+00, -9.8557e-02,
         5.1077e+00, -4.0518e+00, -2.6806e-03, -7.2934e-01, -5.4432e+00,
         3.5619e+00,  6.1031e+00, -7.2877e-01, -4.0819e+00,  2.9329e+00,
         3.8585e+00, -2.9784e+00,  1.1124e+00, -8.2287e+00,  2.7348e+00,
         6.0236e-01, -2.4054e+00, -3.5393e+00, -1.5170e+00, -7.5092e-01,
        -4.3856e+00,  5.0673e+00,  1.6784e+01,  3.3701e+00, -6.4999e+00,
         7.3039e+00,  7.5358e+00, -1.9126e+00,  2.1336e+00, -3.2421e+00,
        -4.4454e+00,  2.0309e+00,  1.3034e+00, -5.0879e+00,  4.4193e+00,
         1.1515e+01,  7.5885e-01,  4.5374e+00, -1.0041e+01,  2.4802e+00,
        -1.9640e+00,  8.6382e+00, -2.9521e-01, -2.5199e+00, -3.1697e+00,
        -4.1011e+00,  2.9947e+00,  2.5317e-01,  3.3526e+00, -8.4459e-01,
        -1.6096e+00,  4.6977e+00, -2.5488e+00, -3.9472e+00, -2.5825e+00,
         1.1431e+00, -9.7997e+01, -6.4164e+00, -1.7173e+00,  4.5707e+00,
         2.2898e+00, -4.0544e+00,  8.1279e+00,  4.0761e+00,  3.0572e+00,
         1.4155e+00,  3.1557e+00,  2.0821e+00,  1.4200e+00,  2.1833e-01,
         4.3272e+00, -3.4302e+00, -1.4085e+00,  2.8076e-01, -8.1994e-01,
        -6.6751e+00,  2.7346e+00,  8.0669e+00, -8.9313e-01,  5.0663e-01,
         2.6838e+00, -1.5756e+00,  1.8478e+00,  1.4335e+00, -3.1085e+00,
        -3.5366e+00, -2.3190e+00, -2.4223e+00,  4.6886e-01,  2.7125e+00,
        -4.6715e-01,  1.9403e+00,  1.6051e+00,  4.9405e+00,  5.9342e+00,
         6.0049e+00,  1.0645e+00,  3.4900e+00,  4.0834e+00,  7.7649e-01,
         2.0025e+00,  2.8585e+00,  4.6692e-01, -4.4438e-01, -5.1792e+01,
        -3.7743e+00,  1.0210e+00, -4.8536e-01, -3.3734e+00,  1.2150e+00,
         2.1420e+00,  6.3237e-01, -2.1060e-01, -1.8375e-01, -5.4864e-01,
         6.0235e+00,  8.7014e+00, -7.0699e+00,  3.2431e+00,  1.7484e+00,
        -7.2326e+00,  3.2734e+00,  3.1882e+00,  6.0536e+00,  9.2431e-01,
        -5.6773e+00, -4.9881e+00,  1.4860e+00, -3.0162e+00, -3.5636e+00,
        -8.4085e-01, -1.0074e+00, -1.6067e+01, -1.3144e+01,  1.1818e-01,
         6.4546e+00, -1.0912e+00,  1.1206e+01, -1.3724e+00,  2.7155e+00,
        -7.4178e+00, -1.2956e+00, -1.6902e+00,  9.5223e+00,  2.4952e+00,
        -3.7531e+00,  5.1116e+00,  3.8512e+00,  3.7123e+00,  7.6355e-01,
        -3.8020e+00,  1.8983e+00,  4.3574e+01, -7.3467e+00, -2.8905e+00,
        -4.3620e+00,  3.8599e+00,  1.2110e+00,  2.5844e+00, -4.4734e+00,
        -2.9198e+00,  2.9884e+00, -1.7416e+00, -1.6009e+02, -1.6512e+01,
        -9.0772e-02,  6.8859e+00, -3.4945e+00,  6.9834e+00,  6.2670e+00,
         5.6965e-01,  3.0969e-02,  2.9006e+00,  8.9760e-01, -2.3006e+00,
         8.3767e-01, -2.9463e+00, -4.3662e+00, -2.5842e+00,  9.1415e+00,
        -5.5953e+00, -4.7413e+00, -3.7179e+00,  2.1987e+01,  1.8265e+00,
        -8.9532e-01,  4.1599e+00, -3.9978e+00, -3.8641e+00, -2.0673e+00,
         3.0306e+00, -2.1554e+00, -2.0086e+00,  4.8663e+00,  2.8928e-01,
        -4.1998e+00,  8.5062e+00,  4.2103e+00,  1.7206e+01, -3.5376e+00,
         2.5671e-01,  2.2228e-01,  6.6561e+00, -4.8118e+00, -3.6561e+00,
        -6.8825e-01, -4.4886e+00, -9.2805e+00, -3.8699e-01, -1.3017e+00,
         2.3756e+00,  4.2634e+00, -3.6821e+00, -6.7862e-01, -3.6553e+00,
        -3.7938e+00, -3.6224e+00,  1.7574e+00,  3.0822e+00, -3.9101e-01,
         7.3470e+01, -5.4396e-01,  9.8533e-01, -3.3989e+00, -1.3025e+00,
        -2.0689e+00,  2.5526e+00, -2.0326e+00,  3.0746e+00, -2.0950e+00,
         4.1649e+00,  6.4935e-01, -7.1394e+00, -1.0906e+00, -7.1690e+00,
         3.2226e+00, -1.2760e+00,  1.4431e+02,  6.0973e+00, -1.3528e+00,
        -4.9426e+00, -8.5969e-01, -8.5405e+00,  5.7662e+00, -2.4922e+00,
        -7.4357e+00, -7.8989e+00,  4.2059e+00, -3.9071e+00, -2.3410e+00,
        -1.0605e+00, -1.5846e+00, -6.7736e+00,  5.0953e-01,  1.5078e+01,
        -3.7121e-02,  3.4718e+00, -2.8174e+00, -3.7044e+00,  1.0058e+01,
        -2.1485e+00,  3.8376e+00, -8.5943e-01,  1.8444e+00,  3.8978e+00,
         3.0776e+00, -1.0642e+00,  4.3514e+00,  2.2874e+00,  5.2097e+00,
        -9.6646e+01,  1.6328e+02, -1.2562e-01,  1.4275e+00,  1.6121e+00,
        -8.4209e+00,  1.1070e+00, -4.9542e+00,  1.7836e+00, -1.0035e+01,
        -2.0094e+00,  2.7393e+00,  5.4366e-01, -7.8556e-01, -3.2903e-01,
         5.5690e+00,  1.6298e+02,  1.2884e+00,  6.0525e+00,  2.3517e+00,
        -5.1991e+00, -8.9318e+00,  3.2929e+00, -1.2494e+01,  4.4547e+00,
        -4.6449e-01,  1.0015e+00,  3.5933e-01,  2.1539e+00,  8.7483e-01,
        -1.9571e+00,  4.6167e+00, -4.1041e+00,  4.1621e+00, -3.2981e+00,
         5.2509e-02,  6.9099e+00,  1.9714e-01,  8.3853e+00,  1.2267e+00,
        -3.2992e+00,  4.8448e+00,  2.9776e+00, -2.0985e-01,  1.9363e+00,
        -3.0710e+00, -4.9856e+00,  9.9332e-01,  1.8208e+00,  6.0879e+00,
         3.6373e+00, -2.4278e-01,  2.3095e+00, -1.1847e+00,  5.4024e+00,
         4.2026e+00,  2.0059e-01, -6.7174e+00, -6.0502e+00, -7.2487e+00,
         4.4279e+00, -2.2406e+00,  5.9508e+00,  6.9945e+00,  1.2061e+00,
        -3.2929e+00, -1.5007e+00,  2.5605e+00, -1.6327e+00,  6.6576e+00,
         4.1315e+00, -2.5557e+00,  2.7006e+00,  6.6053e+00, -1.9167e+00,
        -1.5669e+00, -2.4749e+00,  4.2183e-01,  2.4309e+00,  1.5208e-01,
        -9.1254e+00, -3.7077e+00,  4.5885e+00, -1.9774e+00,  1.3449e+01,
         3.4506e+00, -7.2922e+00, -5.2381e-01,  1.1032e+00,  2.7615e+00,
        -6.3168e+00, -3.8688e-02,  6.5302e+00, -3.0305e-01,  1.4317e+00,
        -3.4285e+00, -4.6842e+00, -3.0345e+00,  5.0907e+00, -2.3043e+00,
         4.8652e+00, -3.4368e+00, -8.0064e+00, -6.0210e+00,  5.9621e+00,
         1.6881e+00, -1.6544e+00, -6.9820e+00,  7.2363e-01, -7.4828e+00,
        -1.5182e+00, -1.0165e+00, -4.8026e+00, -2.3627e+00,  2.1291e+00,
        -6.0510e+00, -4.2262e+00, -2.6994e+00,  2.1767e-01, -3.2933e+00,
         6.8151e+00, -2.8563e+00, -1.1988e+00,  1.6882e+00,  1.8934e+00,
         9.4241e+00,  9.5913e-01,  2.1459e+00,  1.3060e+01, -4.9042e+00,
        -5.1736e+00,  8.4661e+00,  7.1147e-01, -1.0368e+01,  3.1779e+00,
         1.8113e+00,  7.0269e+00,  2.9996e+00, -2.8506e+00, -7.3495e-01,
        -8.3824e+00, -2.6062e+00,  5.2471e+00,  6.6041e+00, -8.8814e-01,
         1.3164e-01,  7.5559e-01,  6.3661e+00,  6.7891e-01, -1.6838e+00,
         3.9709e-02,  4.6224e-01, -5.2465e+00,  5.9410e+00, -4.3000e+00,
        -4.3974e+00, -6.8152e+00, -4.7801e+00, -4.4493e+00, -4.0901e-01,
         1.5374e+01,  2.9918e+00,  1.8610e+00, -3.2574e+00, -3.8129e+00,
        -5.0761e+00,  9.2315e+00, -1.0814e+00,  2.8385e-01, -3.2490e+00,
         3.4800e+00, -5.8052e+00,  1.3061e+00, -2.8434e+00,  6.5931e-02,
         3.2917e-02, -1.9720e+00, -3.3967e+00, -7.4260e+00,  6.1623e+00,
        -4.1470e+00, -4.2324e+00,  9.5378e-01, -1.0738e+00, -1.3570e-01,
        -2.2882e+00,  4.0768e+00,  2.7377e-01, -1.5838e+00,  1.1146e+01,
         6.9866e+00,  1.3019e+01, -1.8275e+00, -6.4068e+00,  1.7844e+00,
        -8.6919e-01, -2.9575e+00,  4.8200e-01, -6.8251e+00, -2.5988e+00,
        -8.7482e+00,  8.0484e-01,  2.6259e+00, -4.5771e+00, -5.8153e-01,
         4.7844e+00, -3.3178e+00,  9.6060e+00,  4.9834e+00,  5.5188e-01,
        -3.0826e-01, -2.0525e+00, -4.3834e+00,  5.7229e+00,  3.6664e+00,
         4.6413e-02,  1.0723e+00, -5.3191e+00,  1.4131e+00, -9.6075e-01,
        -3.2504e+00, -2.3741e+00,  2.7716e+00, -1.4227e+00, -9.3925e-01,
        -7.5700e+00, -3.7153e+00,  1.3690e+00,  1.1079e+01, -8.3346e+00,
        -4.9114e+00, -2.0677e+00, -1.0595e+00, -1.8981e+00,  6.1321e+00,
         1.9002e+00, -1.9221e-02, -1.1126e+00, -1.0669e+01, -3.0803e+00,
        -3.5232e+00,  1.5203e+00, -4.8918e+00,  2.8023e+00,  7.7897e+00,
         3.6398e+00, -6.3831e-01, -5.3203e+00,  1.7896e+00, -4.1591e+00,
        -4.2745e-01,  3.0674e+00,  4.1897e+00, -2.5548e+00,  3.0816e-01,
        -3.4219e-02, -3.2085e+00,  3.0132e+00, -5.1880e+00,  8.2857e-01,
        -9.3511e-01, -2.3891e-01,  3.3661e+00,  2.3398e+00, -7.5823e+00,
        -1.1914e+01,  5.7218e+00, -2.5562e+00, -2.9274e+00,  3.1782e+00,
        -1.9518e+00,  4.3836e+00,  3.5140e+00, -2.8808e+00, -3.9965e+00,
         1.6073e+00, -5.4745e+00, -1.4041e+00, -2.8542e+00, -2.0609e+00,
         8.0068e-02, -2.6370e+00, -4.4448e+00, -6.0635e+00, -4.2056e+00,
         7.6714e+00,  3.0683e+00,  2.0481e+00], device='mps:0',
       grad_fn=<SelectBackward0>)
Modified Output:  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='mps:0', grad_fn=<SelectBackward0>)

We can also replace proxy inputs and outputs outright with tensors of the same shape and type. To get that shape information without executing the model at all, we can run a scan on its own with the scan method. Let’s retrieve the hidden dimension of the layer output:

[3]:
with model.scan(input):

    dim = model.transformer.h[11].output[0].shape[-1]

print(dim)
768
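
With dim in hand, we could, as one illustration, replace the layer output with a noised copy of itself. A minimal sketch, assuming broadcasting handles the batch and sequence dimensions; the 0.001 noise scale is arbitrary:

import torch

with model.trace(input):
    original = model.transformer.h[11].output[0].clone().save()
    # noise of the scanned hidden dimension, broadcast over the batch
    # and sequence dimensions (move it to the model's device if needed)
    noise = 0.001 * torch.randn(dim)
    model.transformer.h[11].output[0][:] = original + noise
    noised = model.transformer.h[11].output[0].save()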