
Walkthrough#

The API for a transparent science on black-box AI#

In this era of large-scale deep learning, the most interesting AI models are massive black boxes that are hard to run. Ordinary commercial inference service APIs let us interact with huge models, but they do not let us access model internals.

The nnsight library is different: it provides full access to all the neural network internals. When used together with a remote service like the National Deep Inference Fabric (NDIF), it makes it possible to run complex experiments on huge open models easily, with fully transparent access.

Our team wants to enable entire labs and independent researchers alike, as we believe a large, passionate, and collaborative community will produce the next big insights in a profoundly important field.

1 First, let’s start small#

Setup#

Install nnsight:

pip install nnsight

Tracing Context#

To demonstrate the core functionality and syntax of nnsight, we’ll define and use a tiny two-layer neural network.

Our little model here is composed of two submodules – linear layers ‘layer1’ and ‘layer2’. We specify the sizes of each of these modules and create a matching example input.

[1]:
from collections import OrderedDict
import torch

input_size = 5
hidden_dims = 10
output_size = 2

net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(input_size, hidden_dims)),
            ("layer2", torch.nn.Linear(hidden_dims, output_size)),
        ]
    )
).requires_grad_(False)

The core object of the nnsight package is NNsight. This wraps around a given PyTorch model to enable investigation of its internal parameters.

[2]:
import nnsight
from nnsight import NNsight

tiny_model = NNsight(net)

Printing a PyTorch model shows a named hierarchy of modules, which is very useful when accessing sub-components directly. An NNsight model reflects the same hierarchy and can be printed in the same way.

[3]:
print(tiny_model)
Sequential(
  (layer1): Linear(in_features=5, out_features=10, bias=True)
  (layer2): Linear(in_features=10, out_features=2, bias=True)
)

Before we actually get to using the model we just created, let’s talk about Python contexts.

Python contexts define a scope using the with statement and are often used to create some object, or initiate some logic, that you later want to destroy or conclude.

The most common application is opening files as in the following example:

with open('myfile.txt', 'r') as file:
  text = file.read()

Python uses the with keyword to enter a context-like object. This object defines logic to be run at the start of the with block, as well as logic to be run when exiting. When using with for a file, entering the context opens the file and exiting the context closes it. Being within the context means we can read from the file.

Simple enough! Now we can discuss how nnsight uses contexts to enable intuitive access into the internals of a neural network.

The main tool with nnsight is a context for tracing.

We enter the tracing context by calling model.trace(<input>) on an NNsight model, which defines how we want to run the model. Inside the context, we will be able to customize how the neural network runs. The model is actually run upon exiting the tracing context.

[4]:
# random input
input = torch.rand((1, input_size))

with tiny_model.trace(input) as tracer:
    pass

But where’s the output? To get that, we’ll have to learn how to request it from within the tracing context.

Getting#

Earlier, we wrapped our little neural net with the NNsight class. This added a couple of properties to each module in the model (including the root model itself). The two most important ones are .input and .output.

model.input
model.output

The names are self explanatory. They correspond to the inputs and outputs of their respective modules during a forward pass of the model. We can use these attributes inside the with block.

However, it is important to understand that the model is not executed until the end of the tracing context. How can we access inputs and outputs before the model is run? The trick is deferred execution.

.input and .output are Proxies for the eventual inputs and outputs of a module. In other words, when we access model.output what we are communicating to nnsight is, “When you compute the output of model, please grab it for me and put the value into its corresponding Proxy object.” Let’s try it:

[5]:
with tiny_model.trace(input) as tracer:

    output = tiny_model.output

print(output)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 5
      1 with tiny_model.trace(input) as tracer:
      3     output = tiny_model.output
----> 5 print(output)

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/proxy.py:70, in Proxy.__str__(self)
     66 def __str__(self) -> str:
     68     if not self.node.attached:
---> 70         return str(self.value)
     72     return f"{type(self).__name__} ({self.node.target.__name__})"

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/proxy.py:64, in Proxy.value(self)
     56 @property
     57 def value(self) -> Any:
     58     """Property to return the value of this proxy's node.
     59
     60     Returns:
     61         Any: The stored value of the proxy, populated during execution of the model.
     62     """
---> 64     return self.node.value

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/graph/node.py:143, in Node.value(self)
    133 """Property to return the value of this node.
    134
    135 Returns:
   (...)
    139     ValueError: If the underlying ._value is inspect._empty (therefore never set or was destroyed).
    140 """
    142 if not self.done:
--> 143     raise ValueError("Accessing value before it's been set.")
    145 return self._value

ValueError: Accessing value before it's been set.

Oh no, an error! “Accessing value before it’s been set.”

Why doesn’t our output have a value?

Proxy objects will only have their value at the end of a context if we call .save() on them. This helps to reduce memory costs. Adding .save() fixes the error:

[67]:
with tiny_model.trace(input) as tracer:

    output = tiny_model.output.save()

print(output)
tensor([[-0.1301, -0.4906]])

Success! We now have the model output. We just completed our first intervention using nnsight.

Each time we access a module’s input or output, we create an intervention in the neural network’s forward pass. Collectively these requests form the intervention graph. We call the process of executing it alongside the model’s normal computation graph, interleaving.

On Model output


If we don’t need to access anything other than the final model output, we can call the tracing context with trace=False and not use it as a context. This could be especially useful for easy remote inference.

output = model.trace(<inputs>, trace=False)
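
For our toy model, that could look like the following (a minimal sketch; here output is the actual tensor rather than a Proxy):

output = tiny_model.trace(input, trace=False)
print(output)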

Just like we saved the output of the model as a whole, we can save the output of any of its submodules. We use normal Python attribute syntax. We can discover how to access them by name by printing out the model:

[68]:
print(tiny_model)
Sequential(
  (layer1): Linear(in_features=5, out_features=10, bias=True)
  (layer2): Linear(in_features=10, out_features=2, bias=True)
)

Let’s access the output of the first layer (which we’ve named ‘layer1’):

[69]:
with tiny_model.trace(input) as tracer:

    l1_output = tiny_model.layer1.output.save()

print(l1_output)
tensor([[ 0.2732,  0.2355, -0.6433,  0.0475,  0.0904,  0.4407, -0.3099,  1.4903,
         -0.0748,  0.0088]])

Let’s do the same for the input of layer2. While we’re at it, let’s also drop the as tracer, as we won’t be needing the tracer object itself for a few sections:

[70]:
with tiny_model.trace(input):

    l2_input = tiny_model.layer2.input.save()

print(l2_input)
tensor([[ 0.2732,  0.2355, -0.6433,  0.0475,  0.0904,  0.4407, -0.3099,  1.4903,
         -0.0748,  0.0088]])

On module inputs


Notice how the value for l2_input is just a single tensor. By default, the .input attribute of a module will return the first tensor input to the module.

We can also access the full input to a module by using the .inputs attribute which will return the values in the form of:

tuple(tuple(args), dictionary(kwargs))

Where the first index of the tuple is itself a tuple of all positional arguments, and the second index is a dictionary of the keyword arguments.
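
For instance, we could grab the full inputs of layer2 and unpack them ourselves (a minimal sketch; for a Linear layer there is a single positional tensor and no keyword arguments):

with tiny_model.trace(input):
    # (tuple of positional args, dict of keyword args)
    l2_inputs = tiny_model.layer2.inputs.save()

print(l2_inputs)
print(l2_inputs[0][0])  # the first positional argument - the same tensor as .input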


Until now, we have been saving the output of the model and its submodules within the Trace context and then printing it after exiting the context. We will continue doing this throughout the rest of the tutorial, since it’s good practice to save computation results for later analysis.

However, we can also log the outputs of the model and its submodules within the Trace context. This is useful for debugging and understanding the model’s behavior while saving memory. Let’s see how to do this:

[71]:
with tiny_model.trace(input) as tracer:
  tracer.log("Layer 1 - out: ", tiny_model.layer1.output)
Layer 1 - out:  tensor([[ 0.2732,  0.2355, -0.6433,  0.0475,  0.0904,  0.4407, -0.3099,  1.4903,
         -0.0748,  0.0088]])

Functions, Methods, and Operations#

Now that we can access activations, we also want to do some post-processing on them. Let’s find out which dimension of layer1’s output has the highest value.

We could do this by calling torch.argmax(...) after the tracing context, or we can leverage the fact that nnsight handles PyTorch functions and methods within the tracing context by creating a Proxy request for it:

[72]:
with tiny_model.trace(input):

    # Note we don't need to call .save() on the output,
    # as we're only using its value within the tracing context.
    l1_output = tiny_model.layer1.output

    # We do need to save the argmax tensor however,
    # as we're using it outside the tracing context.
    l1_amax = torch.argmax(l1_output, dim=1).save()

print(l1_amax[0])
tensor(7)

Nice! That worked seamlessly. But hold on, how come we didn’t need to call .value on the result before indexing it? In previous sections, we were being explicit in order to build an understanding of Proxies and their values. In practice, however, nnsight knows that outside of the tracing context we only care about the actual value, so printing, indexing, and applying functions all immediately return and reflect the data in .value. For the rest of the tutorial we won’t access it explicitly.
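
If you ever do want to be explicit, the saved Proxy still exposes the underlying tensor through .value (a minimal sketch reusing the l1_amax Proxy saved above):

# Outside the context the Proxy behaves like its value...
print(l1_amax[0])
# ...but the raw tensor can still be accessed explicitly.
print(l1_amax.value[0])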

The same principles work for PyTorch methods and all operators as well:

[73]:
with tiny_model.trace(input):

    value = (tiny_model.layer1.output.sum() + tiny_model.layer2.output.sum()).save()

print(value)
tensor(0.9377)

The code block above is saying to nnsight, “Run the model with the given input. When the output of tiny_model.layer1 is computed, take its sum. Then do the same for tiny_model.layer2. Now that both of those are computed, add them and make sure not to delete this value as I wish to use it outside of the tracing context.”

Custom Functions#

Everything within the tracing context operates on the intervention graph. Therefore, for nnsight to trace a function it must also be a part of the intervention graph.

Out of the box, nnsight supports PyTorch functions and methods, all operators, as well as the einops library. We don’t need to do anything special to use them. But what do we do if we want to use custom functions? How do we add them to the intervention graph?

Enter nnsight.apply(). It allows us to add new functions to the intervention graph. Let’s see how it works:

[74]:
# Take a tensor and return the sum of its elements
def tensor_sum(tensor):
    flat = tensor.flatten()
    total = 0
    for element in flat:
        total += element.item()

    return torch.tensor(total)

with tiny_model.trace(input) as tracer:

    # Pass the function and its arguments (comma-separated) to add it to the intervention graph
    custom_sum = nnsight.apply(tensor_sum, tiny_model.layer1.output).save()
    sum = tiny_model.layer1.output.sum()
    sum.save()


print(custom_sum, sum)
tensor(1.5584) tensor(1.5584)

nnsight.apply() executes the function it wraps and returns its output as a Proxy object. We can then use this Proxy object as we would any other.

The applications of nnsight.apply are wide: it can be used to wrap any custom function or functions from libraries that nnsight does not support out-of-the-box.

Setting#

Getting and analyzing the activations from various points in a model can be really insightful, and a number of ML techniques do exactly that. However, we often want not only to view a model’s computation, but also to influence it.

To demonstrate the effect of editing the flow of information through the model, let’s set the first dimension of the first layer’s output to 0. NNsight makes this really easy using the ‘=’ operator:

[75]:
with tiny_model.trace(input):

    # Save the output before the edit to compare.
    # Notice we apply .clone() before saving as the setting operation is in-place.
    l1_output_before = tiny_model.layer1.output.clone().save()

    # Access the 0th index of the hidden state dimension and set it to 0.
    tiny_model.layer1.output[:, 0] = 0

    # Save the output after to see our edit.
    l1_output_after = tiny_model.layer1.output.save()

print("Before:", l1_output_before)
print("After:", l1_output_after)
Before: tensor([[ 0.2732,  0.2355, -0.6433,  0.0475,  0.0904,  0.4407, -0.3099,  1.4903,
         -0.0748,  0.0088]])
After: tensor([[ 0.0000,  0.2355, -0.6433,  0.0475,  0.0904,  0.4407, -0.3099,  1.4903,
         -0.0748,  0.0088]])

Seems our change was reflected. Now let’s do the same for the last dimension:

[76]:
with tiny_model.trace(input):

    # Save the output before the edit to compare.
    # Notice we apply .clone() before saving as the setting operation is in-place.
    l1_output_before = tiny_model.layer1.output.clone().save()

    # Access the last index of the hidden state dimension and set it to 0.
    tiny_model.layer1.output[:, hidden_dims] = 0

    # Save the output after to see our edit.
    l1_output_after = tiny_model.layer1.output.save()

print("Before:", l1_output_before)
print("After:", l1_output_after)
Traceback (most recent call last):
  File "/Users/emmabortz/Documents/Projects/nnsight/src/nnsight/tracing/graph/node.py", line 297, in execute
    output = self.target(*args, **kwargs)
IndexError: index 10 is out of bounds for dimension 1 with size 10

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/folders/rx/0nl_h2cd54q90chs9hf0n5qr0000gn/T/ipykernel_80599/3404137504.py", line 8, in <module>
    tiny_model.layer1.output[:, hidden_dims] = 0

NNsightError: index 10 is out of bounds for dimension 1 with size 10

Oh no, we are getting an error! Ah, of course: we needed to index at hidden_dims - 1, not hidden_dims.

If you’ve been using nnsight, you are probably familiar with error messages that can be quite difficult to troubleshoot. In nnsight 0.4 we’ve now improved error messaging to be descriptive and line-specific, as you should see in the above example!

Old NNsight error messaging

If you’ve been using NNsight prior to the NNsight 0.4 release, you will be familiar with the following non-descriptive error messaging. If you choose to turn off NNsight 0.4’s new error messaging feature, this is how errors within the tracing context will appear.

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/usr/local/lib/python3.11/dist-packages/nnsight/tracing/Node.py in execute(self)
    379                 # Call the target to get value.
--> 380                 output = self.target(*args, **kwargs)
    381

IndexError: index 10 is out of bounds for dimension 1 with size 10

The above exception was the direct cause of the following exception:

IndexError                                Traceback (most recent call last)
20 frames
<ipython-input-16-5c81de91fb1f> in <cell line: 0>()
----> 1 with tiny_model.trace(input):
      2
      3     # Save the output before the edit to compare.
      4     # Notice we apply .clone() before saving as the setting operation is in-place.
      5     l1_output_before = tiny_model.layer1.output.clone().save()

/usr/local/lib/python3.11/dist-packages/nnsight/contexts/Tracer.py in __exit__(self, exc_type, exc_val, exc_tb)
    100
    101
--> 102         super().__exit__(exc_type, exc_val, exc_tb)
    103
    104     def invoke(self, *inputs: Any, **kwargs) -> Invoker:

/usr/local/lib/python3.11/dist-packages/nnsight/contexts/GraphBasedContext.py in __exit__(self, exc_type, exc_val, exc_tb)
    215             raise exc_val
    216
--> 217         self.backend(self)
    218
    219     ### BACKENDS ########

/usr/local/lib/python3.11/dist-packages/nnsight/contexts/backends/LocalBackend.py in __call__(self, obj)
     25     def __call__(self, obj: LocalMixin):
     26
---> 27         obj.local_backend_execute()

/usr/local/lib/python3.11/dist-packages/nnsight/contexts/Tracer.py in local_backend_execute(self)
    144         self.graph.execute()
    145
--> 146         self.model.interleave(
    147             self.model._execute,
    148             self.graph,

/usr/local/lib/python3.11/dist-packages/nnsight/models/NNsightModel.py in interleave(self, fn, intervention_graph, *inputs, **kwargs)
    467         module_paths = InterventionProtocol.get_interventions(intervention_graph).keys()
    468
--> 469         with HookHandler(
    470             self._model,
    471             list(module_paths),

/usr/local/lib/python3.11/dist-packages/nnsight/intervention.py in __exit__(self, exc_type, exc_val, exc_tb)
    579
    580         if isinstance(exc_val, Exception):
--> 581             raise exc_val
    582
    583

/usr/local/lib/python3.11/dist-packages/nnsight/models/NNsightModel.py in interleave(self, fn, intervention_graph, *inputs, **kwargs)
    478         ):
    479             try:
--> 480                 fn(*inputs, **kwargs)
    481             except protocols.EarlyStopProtocol.EarlyStopException:
    482                 # TODO: Log.

/usr/local/lib/python3.11/dist-packages/nnsight/models/NNsightModel.py in _execute(self, *prepared_inputs, **kwargs)
    585             pass
    586
--> 587         return self._model(
    588             *prepared_inputs,
    589             **kwargs,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1842
   1843         try:
-> 1844             return inner()
   1845         except Exception:
   1846             # run always called hooks if they have not already been run

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in inner()
   1788                 args = bw_hook.setup_input_hook(args)
   1789
-> 1790             result = forward_call(*args, **kwargs)
   1791             if _global_forward_hooks or self._forward_hooks:
   1792                 for hook_id, hook in (

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/container.py in forward(self, input)
    248     def forward(self, input):
    249         for module in self:
--> 250             input = module(input)
    251         return input
    252

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1842
   1843         try:
-> 1844             return inner()
   1845         except Exception:
   1846             # run always called hooks if they have not already been run

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in inner()
   1801                         hook_result = hook(self, args, kwargs, result)
   1802                     else:
-> 1803                         hook_result = hook(self, args, result)
   1804
   1805                     if hook_result is not None:

/usr/local/lib/python3.11/dist-packages/nnsight/intervention.py in output_hook(module, input, output, module_path)
    564
    565                 def output_hook(module, input, output, module_path=module_path):
--> 566                     return self.output_hook(output, module_path)
    567
    568                 self.handles.append(

/usr/local/lib/python3.11/dist-packages/nnsight/models/NNsightModel.py in <lambda>(activations, module_path)
    473                 activations, module_path, "input", intervention_handler
    474             ),
--> 475             output_hook=lambda activations, module_path: InterventionProtocol.intervene(
    476                 activations, module_path, "output", intervention_handler
    477             ),

/usr/local/lib/python3.11/dist-packages/nnsight/intervention.py in intervene(cls, activations, module_path, key, intervention_handler)
    454
    455                 # Value injection.
--> 456                 node.set_value(value)
    457
    458                 # Check if through the previous value injection, there was a 'swap' intervention.

/usr/local/lib/python3.11/dist-packages/nnsight/tracing/Node.py in set_value(self, value)
    408
    409             if listener.fulfilled() and not self.graph.sequential:
--> 410                 listener.execute()
    411
    412         for dependency in self.arg_dependencies:

/usr/local/lib/python3.11/dist-packages/nnsight/tracing/Node.py in execute(self)
    385         except Exception as e:
    386
--> 387             raise type(e)(
    388                 f"Above exception when execution Node: '{self.name}' in Graph: '{self.graph.id}'"
    389             ) from e

IndexError: Above exception when execution Node: 'setitem_0' in Graph: '132147685816016'

The error messaging feature can be toggled using nnsight.CONFIG.APP.DEBUG which defaults to true.

Toggle Error Messaging

Turn off debugging:

import nnsight

nnsight.CONFIG.APP.DEBUG = False
nnsight.CONFIG.save()

Turn on debugging:

import nnsight

nnsight.CONFIG.APP.DEBUG = True
nnsight.CONFIG.save()

Now that we know more about NNsight’s error messaging, let’s try our setting operation again with the correct indexing and view the shape of the output before leaving the tracing context:

[77]:
with tiny_model.trace(input):

    # Save the output before the edit to compare.
    # Notice we apply .clone() before saving as the setting operation is in-place.
    l1_output_before = tiny_model.layer1.output.clone().save()

    print(f"Layer 1 output shape: {tiny_model.layer1.output.shape}")

    # Access the last index of the hidden state dimension and set it to 0.
    tiny_model.layer1.output[:, hidden_dims - 1] = 0

    # Save the output after to see our edit.
    l1_output_after = tiny_model.layer1.output.save()

print("Before:", l1_output_before)
print("After:", l1_output_after)
Layer 1 output shape: InterventionProxy (fetch_attr)
Before: tensor([[ 0.2732,  0.2355, -0.6433,  0.0475,  0.0904,  0.4407, -0.3099,  1.4903,
         -0.0748,  0.0088]])
After: tensor([[ 0.2732,  0.2355, -0.6433,  0.0475,  0.0904,  0.4407, -0.3099,  1.4903,
         -0.0748,  0.0000]])

Scan and Validate#

Error messages are helpful, but sometimes you may want to quickly troubleshoot your code without actually running it.

Enter “Scanning” and “Validating”! We can enable these features by setting the scan=True and validate=True flags in the trace method.

“Scanning” runs “fake” inputs through the model to collect information like shapes and types (i.e., scanning will populate all called .inputs and .outputs).

“Validating” attempts to execute the intervention proxies with “fake” inputs to check if they work (i.e., executes all interventions in your code with fake tensors).

“Validating” is dependent on “Scanning” to work correctly, so we need to run the scan of the model at least once to debug with validate. Let’s try it out on our example above.

[78]:
# turn on scan and validate
with tiny_model.trace(input, scan=True, validate=True):

    l1_output_before = tiny_model.layer1.output.clone().save()

    # the error is happening here
    tiny_model.layer1.output[:, hidden_dims] = 0

    l1_output_after = tiny_model.layer1.output.save()

print("Before:", l1_output_before)
print("After:", l1_output_after)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[78], line 2
      1 # turn on scan and validate
----> 2 with tiny_model.trace(input, scan=True, validate=True):
      4     l1_output_before = tiny_model.layer1.output.clone().save()
      6     # the error is happening here

File ~/Documents/Projects/nnsight/src/nnsight/intervention/contexts/interleaving.py:96, in InterleavingTracer.__exit__(self, exc_type, exc_val, exc_tb)
     92     self.invoker.__exit__(None, None, None)
     94 self._model._envoy._reset()
---> 96 super().__exit__(exc_type, exc_val, exc_tb)

File ~/Documents/Projects/nnsight/src/nnsight/tracing/contexts/tracer.py:25, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
     21 from .globals import GlobalTracingContext
     23 GlobalTracingContext.try_deregister(self)
---> 25 return super().__exit__(exc_type, exc_val, exc_tb)

File ~/Documents/Projects/nnsight/src/nnsight/tracing/contexts/base.py:72, in Context.__exit__(self, exc_type, exc_val, exc_tb)
     69 graph = self.graph.stack.pop()
     71 if isinstance(exc_val, BaseException):
---> 72     raise exc_val
     74 self.add(graph.stack[-1], graph, *self.args, **self.kwargs)
     76 if self.backend is not None:

Cell In[78], line 7
      4     l1_output_before = tiny_model.layer1.output.clone().save()
      6     # the error is happening here
----> 7     tiny_model.layer1.output[:, hidden_dims] = 0
      9     l1_output_after = tiny_model.layer1.output.save()
     11 print("Before:", l1_output_before)

File ~/Documents/Projects/nnsight/src/nnsight/tracing/graph/proxy.py:126, in Proxy.__setitem__(self, key, value)
    125 def __setitem__(self, key: Union[Self, Any], value: Union[Self, Any]) -> None:
--> 126     self.node.create(
    127         operator.setitem,
    128         self.node,
    129         key,
    130         value,
    131     )

File ~/Documents/Projects/nnsight/src/nnsight/tracing/graph/node.py:250, in Node.create(self, *args, **kwargs)
    247     return value
    249 # Otherwise just create the Node on the Graph like normal.
--> 250 return self.graph.create(
    251     *args,
    252     **kwargs,
    253 )

File ~/Documents/Projects/nnsight/src/nnsight/tracing/graph/graph.py:131, in Graph.create(self, target, redirect, *args, **kwargs)
    128 # Redirection.
    129 graph = self.stack[-1] if redirect and self.stack else self
--> 131 return self.proxy_class(self.node_class(target, *args, graph=graph, **kwargs))

File ~/Documents/Projects/nnsight/src/nnsight/intervention/graph/node.py:118, in ValidatingInterventionNode.__init__(self, *args, **kwargs)
    111 super().__init__(*args, **kwargs)
    113 if (
    114     self.attached
    115     and self.fake_value is inspect._empty
    116     and not Protocol.is_protocol(self.target)
    117 ):
--> 118     self.fake_value = validate(self.target, *self.args, **self.kwargs)

File ~/Documents/Projects/nnsight/src/nnsight/intervention/graph/node.py:147, in validate(target, *args, **kwargs)
    141 with FakeTensorMode(
    142     allow_non_fake_inputs=True,
    143     shape_env=ShapeEnv(assume_static_by_default=True),
    144 ) as fake_mode:
    145     with FakeCopyMode(fake_mode):
--> 147         with GlobalTracingContext.exit_global_tracing_context():
    149             if backwards_check(target, *args):
    150                 return None

File ~/Documents/Projects/nnsight/src/nnsight/tracing/contexts/globals.py:100, in GlobalTracingContext.GlobalTracingExit.__exit__(self, exc_type, exc_val, traceback)
     96 GlobalTracingContext.PATCHER.__enter__()
     98 if isinstance(exc_val, BaseException):
--> 100     raise exc_val

File ~/Documents/Projects/nnsight/src/nnsight/intervention/graph/node.py:156, in validate(target, *args, **kwargs)
    150     return None
    152 args, kwargs = InterventionNode.prepare_inputs(
    153     (args, kwargs), fake=True
    154 )
--> 156 return target(
    157     *args,
    158     **kwargs,
    159 )

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:2350, in FakeCopyMode.__torch_function__(self, func, types, args, kwargs)
   2348 else:
   2349     with torch._C.DisableTorchFunctionSubclass():
-> 2350         return func(*args, **kwargs)

IndexError: index 10 is out of bounds for dimension 1 with size 10

The operations are never executed using tensors with real values, so this doesn’t incur any memory cost. Then, when creating proxy requests like the setting operation above, nnsight also attempts to execute the request on the “fake” values it recorded. Hence, it lets us know whether our request is feasible before even running the model. Here is a more detailed example of scan and validate in action!

A word of caution


Some PyTorch operations and related libraries don’t work well with fake tensors.

If you are doing anything in a loop where efficiency is important, you should keep scanning and validating off. It’s best to use them only when debugging or when you are unsure if your intervention will work.


We can also use the .scan() method to get the shape of a module without having to fully run the model. If scan is enabled, our input is run through the model under its own “fake” context. This means the input makes its way through all of the model operations, allowing nnsight to record the shapes and data types of module inputs and outputs!

[79]:
with tiny_model.scan(input):

    dim = tiny_model.layer1.output.shape[-1]

print(dim)
10

We can also just replace proxy inputs and outputs with tensors of the same shape and type. Let’s use the shape information we have at our disposal to add noise to the output, and replace it with this new noised tensor:
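
Here is a minimal sketch of one way to do this (the noise scale of 0.001 is an arbitrary illustrative choice):

with tiny_model.trace(input):

    # Save a copy of the original output to compare against.
    l1_output_before = tiny_model.layer1.output.clone().save()

    # Build noise with the same shape as the output and swap in the noised tensor.
    noise = 0.001 * torch.randn(l1_output_before.shape)
    tiny_model.layer1.output = l1_output_before + noise

    l1_output_after = tiny_model.layer1.output.save()

print("Before:", l1_output_before)
print("After:", l1_output_after)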

Gradients#

NNsight also lets us apply backpropagation and access gradients with respect to a loss. Like .input and .output on modules, nnsight exposes .grad on Proxies themselves (assuming they are proxies of tensors):

[80]:
with tiny_model.trace(input):

    # We need to explicitly have the tensor require grad
    # as the model we defined earlier turned off requiring grad.
    tiny_model.layer1.output.requires_grad = True

    # We call .grad on a tensor Proxy to communicate we want to store its gradient.
    # We need to call .save() since .grad is its own Proxy.
    layer1_output_grad = tiny_model.layer1.output.grad.save()
    layer2_output_grad = tiny_model.layer2.output.grad.save()

    # Need a loss to propagate through the later modules in order to have a grad.
    loss = tiny_model.output.sum()
    loss.backward()

print("Layer 1 output gradient:", layer1_output_grad)
print("Layer 2 output gradient:", layer2_output_grad)
Layer 1 output gradient: tensor([[-0.1296,  0.4562, -0.1182, -0.3536,  0.0703,  0.1411, -0.3201, -0.4890,
          0.0196, -0.0452]])
Layer 2 output gradient: tensor([[1., 1.]])

All of the features we learned previously also apply to .grad. In other words, we can apply operations to and edit the gradients. Let’s zero the grad of layer1 and double the grad of layer2.

[81]:
with tiny_model.trace(input):

    # We need to explicitly have the tensor require grad
    # as the model we defined earlier turned off requiring grad.
    tiny_model.layer1.output.requires_grad = True

    tiny_model.layer1.output.grad[:] = 0
    tiny_model.layer2.output.grad = tiny_model.layer2.output.grad * 2

    layer1_output_grad = tiny_model.layer1.output.grad.save()
    layer2_output_grad = tiny_model.layer2.output.grad.save()

    # Need a loss to propagate through the later modules in order to have a grad.
    loss = tiny_model.output.sum()
    loss.backward()

print("Layer 1 output gradient:", layer1_output_grad)
print("Layer 2 output gradient:", layer2_output_grad)
Layer 1 output gradient: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
Layer 2 output gradient: tensor([[2., 2.]])

Early Stopping#

If we are only interested in a model’s intermediate computations, we can halt the forward pass at any module, reducing runtime and conserving compute resources. One example where this could be particularly useful is when working with SAEs - we can train an SAE on one layer and then stop the execution.

[82]:
with tiny_model.trace(input):
   l1_out = tiny_model.layer1.output.save()
   tiny_model.layer1.output.stop()

# get the output of the first layer and stop tracing
print("L1 - Output: ", l1_out)
L1 - Output:  tensor([[ 0.2732,  0.2355, -0.6433,  0.0475,  0.0904,  0.4407, -0.3099,  1.4903,
         -0.0748,  0.0088]])

Interventions within the tracing context do not necessarily execute in the order they are defined. Instead, their execution is tied to the module they are associated with.

As a result, if the forward pass is terminated early any interventions linked to modules beyond that point will be skipped, even if they were defined earlier in the context.

In the example below, the output of layer 2 cannot be accessed since the model’s execution was stopped at layer 1.

[83]:
with tiny_model.trace(input):
   l2_out = tiny_model.layer2.output.save()
   tiny_model.layer1.output.stop()

print("L2 - Output: ", l2_out)
L2 - Output:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[83], line 5
      2    l2_out = tiny_model.layer2.output.save()
      3    tiny_model.layer1.output.stop()
----> 5 print("L2 - Output: ", l2_out)

File ~/Documents/Projects/nnsight/src/nnsight/tracing/graph/proxy.py:70, in Proxy.__str__(self)
     66 def __str__(self) -> str:
     68     if not self.node.attached:
---> 70         return str(self.value)
     72     return f"{type(self).__name__} ({self.node.target.__name__})"

File ~/Documents/Projects/nnsight/src/nnsight/tracing/graph/proxy.py:64, in Proxy.value(self)
     56 @property
     57 def value(self) -> Any:
     58     """Property to return the value of this proxy's node.
     59
     60     Returns:
     61         Any: The stored value of the proxy, populated during execution of the model.
     62     """
---> 64     return self.node.value

File ~/Documents/Projects/nnsight/src/nnsight/tracing/graph/node.py:143, in Node.value(self)
    133 """Property to return the value of this node.
    134
    135 Returns:
   (...)
    139     ValueError: If the underlying ._value is inspect._empty (therefore never set or was destroyed).
    140 """
    142 if not self.done:
--> 143     raise ValueError("Accessing value before it's been set.")
    145 return self._value

ValueError: Accessing value before it's been set.

Conditional Interventions#

Interventions can also be made conditional.

Inside the tracing context we can specify a new - conditional - context. This context will only execute the interventions within it if the condition is met.

[84]:
with tiny_model.trace(input) as tracer:

  rand_int = torch.randint(low=-10, high=10, size=(1,))

  with tracer.cond(rand_int % 2 == 0):
    tracer.log("Random Integer ", rand_int, " is Even")

  with tracer.cond(rand_int % 2 == 1):
    tracer.log("Random Integer ", rand_int, " is Odd")
Random Integer  tensor([-5])  is Odd

Conditional contexts can also be nested, if we want our interventions to depend on more than one condition at a time.

[21]:
with tiny_model.trace(input) as tracer:

  non_rand_int = 8

  with tracer.cond(non_rand_int > 0):
    with tracer.cond(non_rand_int % 2 == 0):
      tracer.log("Rand Int ", non_rand_int, " is Positive and Even")
Rand Int  8  is Positive and Even

With nnsight 0.4 we can now also use Python if statements within the tracing context to create a conditional context!

Note: Colab behaves a little strangely with this feature the first time you run it - expect some lagging and warnings

[85]:
with tiny_model.trace(input) as tracer:

  rand_int = torch.randint(low=-10, high=10, size=(1,))

  # Since this if statement is inside the tracing context the if will
  # create a conditional context and will only execute the intervention
  # if this condition is met
  if rand_int % 2 == 0:
    tracer.log("Random Integer ", rand_int, " is Even")

  if rand_int % 2 == 1:
    tracer.log("Random Integer ", rand_int, " is Odd")
Random Integer  tensor([2])  is Even

elif statements also work like if statements within the tracing context:

[86]:
with tiny_model.trace(input) as tracer:

  rand_int = torch.randint(low=-10, high=10, size=(1,))

  # Since this if statement is inside the tracing context the if will
  # create a conditional context and will only execute the intervention
  # if this condition is met
  if rand_int % 2 == 0:
    tracer.log("Random Integer ", rand_int, " is Even")
  elif rand_int % 2 == 1:
    tracer.log("Random Integer ", rand_int, " is Odd")
Random Integer  tensor([-3])  is Odd

Iterative Interventions#

With the Iterator context, you can now run an intervention loop at scale. It iteratively executes and updates a single intervention graph. Use a .session() to define the Iterator context and pass in a sequence of items that you want to loop over at each iteration.

[87]:
with tiny_model.session() as session:

  li = nnsight.list() # an NNsight built-in list object
  [li.append([num]) for num in range(0, 3)] # adding [0], [1], [2] to the list
  li2 = nnsight.list().save()

  # You can create nested Iterator contexts
  with session.iter(li) as item:
    with session.iter(item) as item_2:
      li2.append(item_2)

print("\nList: ", li2)

List:  [0, 1, 2]

With nnsight 0.4 we can now also use Python for loops within a tracer context at scale.

NOTE: inline for loops (i.e., ``[x for x in <Proxy object>]``) are not currently supported.

[88]:
# New: Using Python for loops for iterative interventions
with tiny_model.session() as session:

    li = nnsight.list()
    [li.append([num]) for num in range(0, 3)]
    li2 = nnsight.list().save()

    # Using regular for loops
    for item in li:
        for item_2 in item: # for loops can be nested!
            li2.append(item_2)

print("\nList: ", li2)

List:  [0, 1, 2]

2 Bigger#

Now that we have the basics of nnsight under our belt, we can scale our model up and combine the techniques we’ve learned into more interesting experiments.

The NNsight class is very bare bones. It wraps a pre-defined model and does no pre-processing on the inputs we enter. It’s designed to be extended with more complex and powerful types of models, and we’re excited to see what can be done to leverage its features!

However, if you’d like to load a Language Model from HuggingFace with its tokenizer, the LanguageModel subclass greatly simplifies this process.

LanguageModel#

LanguageModel is a subclass of NNsight. While we could define and create a model to pass in directly, LanguageModel includes special support for Huggingface language models, including automatically loading models from a Huggingface ID, and loading the model together with the appropriate tokenizer.

Here is how we can use LanguageModel to load GPT-2:

[6]:
from nnsight import LanguageModel

llm = LanguageModel("openai-community/gpt2", device_map="auto")

print(llm)
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): Generator(
    (streamer): Streamer()
  )
)

On Model Initialization


A few important things to note:

Keyword arguments passed to the initialization of LanguageModel are forwarded to HuggingFace-specific loading logic. In this case, device_map specifies which devices to use, and the value auto indicates that the model should be distributed evenly across all available GPUs (or the CPU if no GPUs are available). Other arguments can be found here: https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM

When we initialize LanguageModel, we aren’t yet loading the parameters of the model into memory. We are actually loading a ‘meta’ version of the model which doesn’t take up any memory, but still allows us to view and trace actions on it. After exiting the first tracing context, the model is then fully loaded into memory. To load into memory on initialization, you can pass dispatch=True into LanguageModel like LanguageModel('openai-community/gpt2', device_map="auto", dispatch=True).
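
As an illustration, here is a minimal sketch (torch_dtype is just one example of a HuggingFace loading keyword argument; these particular values are assumptions, not part of the original notebook):

import torch
from nnsight import LanguageModel

llm = LanguageModel(
    "openai-community/gpt2",
    device_map="auto",          # forwarded to HuggingFace loading logic
    torch_dtype=torch.float16,  # any other from_pretrained kwarg can be forwarded the same way
    dispatch=True,              # load the weights into memory immediately
)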


Let’s now apply some of the features that we used on the small model to GPT-2. Unlike NNsight, LanguageModel does define logic to pre-process inputs upon entering the tracing context. This makes interacting with the model simpler (i.e., you can send prompts to the model without having to directly access the tokenizer).

In the following example, we ablate the value coming from the last layer’s MLP module and decode the logits to see what token the model predicts without influence from that particular module:

[90]:
with llm.trace("The Eiffel Tower is in the city of"):

    # Access the last layer using h[-1] as it's a ModuleList
    # Access the first index of .output as that's where the hidden states are.
    llm.transformer.h[-1].mlp.output[0][:] = 0

    # Logits come out of model.lm_head and we apply argmax to get the predicted token ids.
    token_ids = llm.lm_head.output.argmax(dim=-1).save()

print("\nToken IDs:", token_ids)

# Apply the tokenizer to decode the ids into words after the tracing context.
print("Prediction:", llm.tokenizer.decode(token_ids[0][-1]))

Token IDs: tensor([[ 262,   12,  417, 8765,   11,  257,  262, 3504,  338, 3576]],
       device='mps:0')
Prediction:  London

We just ran a little intervention on a much more complex model with many more parameters! However, we’re missing an important piece of information: what the prediction would have looked like without our ablation.

We could just run two tracing contexts and compare the outputs. However, this would require two forward passes through the model. NNsight can do better than that with batching.

Batching#

Batching is a way to process multiple inputs in one forward pass. To better understand how batching works, we’re going to bring back the Tracer object that we dropped before.

When we call .trace(...), it’s actually creating two different contexts behind the scenes. The first one is the tracing context that we’ve discussed previously, and the second one is the invoker context. The invoker context defines the values of the .input and .output Proxies.

If we call .trace(...) with some input, the input is passed on to the invoker. As there is only one input, only one invoker context is created.

If we call .trace() without an input, then we can call tracer.invoke(input1) to manually create the invoker context with an input, input1. We can also repeatedly call tracer.invoke(...) to create the invoker context for additional inputs. Every subsequent time we call .invoke(...), interventions within its context will only refer to the input in that particular invoke statement.

When exiting the tracing context, the inputs from all of the invokers will be batched together, and they will be executed in one forward pass! To test this out, let’s do the same ablation experiment, but also add a ‘control’ output for comparison:

More on the invoker context


Note that when injecting data to only the relevant invoker interventions, nnsight tries, but can’t guarantee, to narrow the data into the right batch indices. Thus, there are cases where all invokes will get all of the data. Specifically, if the input or output data is stored as an object that is not an arbitrary collection of tensors, it will be broadcasted to all invokes.

Just like .trace(...) created a Tracer object, .invoke(...) creates an Invoker object. For LanguageModel models, the Invoker prepares the input by running a tokenizer on it. Invoker stores pre-processed inputs at invoker.inputs, which can be accessed to see information about our inputs. In a case where we pass a single input to .trace(...) directly, we can still access the invoker object at tracer.invoker without having to call tracer.invoke(...).

Keyword arguments given to .invoke(...) make their way to the input pre-processing.
LanguageModel has keyword arguments max_length and truncation, used for tokenization, which can be passed to the invoker. If we want to pass keyword arguments to the invoker for a single-input .trace(...), we can pass invoker_args as a dictionary of invoker keyword arguments.

Here is an example to demonstrate everything we’ve described:

This snippet

with llm.trace("hello", invoker_args={"max_length":10}) as tracer:
  invoker = tracer.invoker

does the same as

with llm.trace() as tracer:
  with tracer.invoke("hello", max_length=10) as invoker:
    invoker = invoker
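
To peek at what the invoker prepared, we can inspect it after the trace (a minimal sketch; the exact structure of .inputs may vary between nnsight versions):

with llm.trace("hello", invoker_args={"max_length": 10}) as tracer:
    pass

# The tokenized, pre-processed inputs the invoker produced for this trace.
print(tracer.invoker.inputs)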

[91]:
with llm.trace() as tracer:

    with tracer.invoke("The Eiffel Tower is in the city of"):

        # Ablate the last MLP for only this batch.
        llm.transformer.h[-1].mlp.output[0][:] = 0

        # Get the output for only the intervened on batch.
        token_ids_intervention = llm.lm_head.output.argmax(dim=-1).save()

    with tracer.invoke("The Eiffel Tower is in the city of"):

        # Get the output for only the original batch.
        token_ids_original = llm.lm_head.output.argmax(dim=-1).save()


print("Original token IDs:", token_ids_original)
print("Modified token IDs:", token_ids_intervention)

print("Original prediction:", llm.tokenizer.decode(token_ids_original[0][-1]))
print("Modified prediction:", llm.tokenizer.decode(token_ids_intervention[0][-1]))
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Original token IDs: tensor([[ 198,   12,  417, 8765,  318,  257,  262, 3504, 7372, 6342]],
       device='mps:0')
Modified token IDs: tensor([[ 262,   12,  417, 8765,   11,  257,  262, 3504,  338, 3576]],
       device='mps:0')
Original prediction:  Paris
Modified prediction:  London

Based on our control results, our ablation did end up affecting what the model predicted. That’s pretty neat!

Another cool thing with multiple invokes is that Proxies can interact between them.

Here, we transfer the token embeddings from a real prompt into another placeholder prompt. Therefore the latter prompt produces the output of the former prompt:

[92]:
with llm.trace() as tracer:

    with tracer.invoke("The Eiffel Tower is in the city of"):
        embeddings = llm.transformer.wte.output

    with tracer.invoke("_ _ _ _ _ _ _ _ _ _"):
        llm.transformer.wte.output = embeddings
        token_ids_intervention = llm.lm_head.output.argmax(dim=-1).save()

    with tracer.invoke("_ _ _ _ _ _ _ _ _ _"):
      token_ids_original = llm.lm_head.output.argmax(dim=-1).save()

print("original prediction shape", token_ids_original[0][-1].shape)
print("Original prediction:", llm.tokenizer.decode(token_ids_original[0][-1]))

print("modified prediction shape", token_ids_intervention[0][-1].shape)
print("Modified prediction:", llm.tokenizer.decode(token_ids_intervention[0][-1]))
original prediction shape torch.Size([])
Original prediction:  _
modified prediction shape torch.Size([])
Modified prediction:  Paris

Multiple Token Generation#

.next()#

Some HuggingFace models define methods to generate multiple outputs at a time. LanguageModel wraps that functionality to provide the same tracing features by using .generate(...) instead of .trace(...). This calls the underlying model’s .generate method. It passes the output through a .generator module that we’ve added onto the model, allowing us to get the generate output at .generator.output.

In a case like this, the underlying model is called more than once; the modules of said model produce more than one output. Which iteration should a given module.output refer to? That’s where Module.next() comes in!

Each module has a call index associated with it and .next() simply increments that attribute. At the time of execution, data is injected into the intervention graph only at the iteration that matches the call index.

[93]:
with llm.generate('The Eiffel Tower is in the city of', max_new_tokens=3) as tracer:

    hidden_states1 = llm.transformer.h[-1].output[0].save()

    # use module.next() to access the next intervention
    hidden_states2 = llm.transformer.h[-1].next().output[0].save()

    # saving the output allows you to save the hidden state across the initial prompt
    out = llm.generator.output.save()

print(hidden_states1.shape)
print(hidden_states2.shape)
print(out)
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
torch.Size([1, 10, 768])
torch.Size([1, 1, 768])
tensor([[ 464,  412,  733,  417, 8765,  318,  287,  262, 1748,  286, 6342,   11,
          290]], device='mps:0')

using .all()#

With nnsight 0.4 you can now use .all() to recursively apply interventions to a model. Calling .all() on a module within a model will recursively apply its .input and .output across all iterations. Previously, we’d need to loop across each new generated token, saving the intervention for every generated token and calling .next() to move forward.

[94]:
# Old approach:
prompt = 'The Eiffel Tower is in the city of'
layers = llm.transformer.h
n_new_tokens = 3
hidden_states = []
with llm.generate(prompt, max_new_tokens=n_new_tokens) as tracer:
    for i in range(n_new_tokens):
        # Apply intervention - set first layer output to zero
        layers[0].output[0][:] = 0

        # Append desired hidden state post-intervention
        hidden_states.append(layers[-1].output.save())

        # Move to next generated token
        layers[0].next()

print("Hidden state length: ",len(hidden_states))
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Hidden state length:  3

We can also use .all() to streamline the multiple token generation process. We simply call .all() on the module where we are applying the intervention (in this case GPT-2’s layers), apply our intervention, and append our hidden states (stored in an nnsight.list() object).

Let’s test this out for the multiple token generation case:

[95]:
# using .all():
prompt = 'The Eiffel Tower is in the city of'
layers = llm.transformer.h
n_new_tokens = 3
with llm.generate(prompt, max_new_tokens=n_new_tokens) as tracer:
    hidden_states = nnsight.list().save() # Initialize & .save() nnsight list

    # Call .all() to apply intervention to each new token
    layers.all()

    # Apply intervention - set first layer output to zero
    layers[0].output[0][:] = 0

    # Append desired hidden state post-intervention
    hidden_states.append(layers[-1].output) # no need to call .save
    # Don't need to loop or call .next()!

print("Hidden state length: ",len(hidden_states))
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Hidden state length:  3

Easy! Note that because .all() is recursive, it only applies to outputs accessed on children of the module that .all() was called on. See the example below for more information. TL;DR: apply .all() to the highest-level accessed module if your interventions and saved outputs sit in different parts of the model hierarchy.

Recursive properties of .all()

.all() recursively acts on model components. In the code example below, only the first token generation is saved, because .all() was applied to layers, while the saved variable hidden_states is produced from model.lm_head, which is not a child of layers.

prompt = 'The Eiffel Tower is in the city of'
layers = model.transformer.h
n_new_tokens = 3
with model.generate(prompt, max_new_tokens=n_new_tokens) as tracer:
    hidden_states = nnsight.list().save() # Initialize & .save() nnsight list

    # Call .all() on layers
    layers.all()

    # Apply same intervention - set first layer output to zero
    layers[0].output[0][:] = 0

    # Append desired hidden state post-intervention
    hidden_states.append(model.lm_head.output) # no need to call .save, it's already initialized

print("Hidden state length: ",len(hidden_states)) # length is 1, meaning it only saved the first token generation

If you want to apply an intervention during multiple token generation while saving the state of a model component that isn’t a child of that module, you can instead apply .all() to the full model:

prompt = 'The Eiffel Tower is in the city of'
layers = model.transformer.h
n_new_tokens = 3
with model.generate(prompt, max_new_tokens=n_new_tokens) as tracer:
    hidden_states = nnsight.list().save() # Initialize & .save() nnsight list

    # Call .all() on model
    model.all()

    # Apply same intervention - set first layer output to zero
    layers[0].output[0][:] = 0

    # Append desired hidden state post-intervention
    hidden_states.append(model.lm_head.output) # no need to call .save

print("Hidden state length: ",len(hidden_states)) # length is 3, as expected!

Model Editing#

NNsight’s model editing feature allows you to create persistently modified versions of a model using .edit(). Unlike interventions in a tracing context, which are temporary, the Editor context enables you to make lasting changes to a model instance.

This feature is useful for:

* Creating modified model variants without altering the original
* Applying changes that persist across multiple forward passes
* Comparing interventions between original and edited models

Let’s explore how to use the Editor context to make a simple persistent change to a model:

[96]:
# we take the hidden states with the expected output "Paris"
with llm.trace("The Eiffel Tower is located in the city of") as tracer:
    hs11 = llm.transformer.h[11].output[0][:, -1, :].save()

# the edited model will now always predict "Paris" as the next token
with llm.edit() as llm_edited:
    llm.transformer.h[11].output[0][:, -1, :] = hs11

# we demonstrate this by comparing the output of an unmodified model...
with llm.trace("Vatican is located in the city of") as tracer:
    original_tokens = llm.lm_head.output.argmax(dim=-1).save()

# ...with the output of the edited model
with llm_edited.trace("Vatican is located in the city of") as tracer:
    modified_tokens = llm.lm_head.output.argmax(dim=-1).save()


print("\nOriginal Prediction: ", llm.tokenizer.decode(original_tokens[0][-1]))
print("Modified Prediction: ", llm.tokenizer.decode(modified_tokens[0][-1]))

Original Prediction:   Rome
Modified Prediction:   Paris

Edits defined within an Editor context create a new, modified version of the model by default, preserving the original. This allows for safe experimentation with model changes. If you wish to modify the original model directly, you can set inplace=True when calling .edit().

Use this option cautiously, as in-place edits alter the base model for all subsequent calls.

[97]:
# we use the hidden state we saved above (hs11)
with llm.edit(inplace=True) as llm_edited:
    llm.transformer.h[11].output[0][:, -1, :] = hs11

# we demonstrate this by comparing the output of an unmodified model...
with llm.trace("Vatican is located in the city of") as tracer:
    modified_tokens = llm.lm_head.output.argmax(dim=-1).save()

print("Modified In-place: ", llm.tokenizer.decode(modified_tokens[0][-1]))
Modified In-place:   Paris

If you’ve made in-place edits to your model and need to revert these changes, you can apply .clear_edits(). This method removes all edits applied to the model, effectively restoring it to its original state.

[98]:
llm.clear_edits()

with llm.trace("Vatican is located in the city of"):
    modified_tokens = llm.lm_head.output.argmax(dim=-1).save()

print("Edits cleared: ", llm.tokenizer.decode(modified_tokens[0][-1]))
Edits cleared:   Rome

3 I thought you said huge models?#

NNsight is only one part of our project to democratize access to AI internals. The other half is the National Deep Inference Fabric, or NDIF. NDIF hosts large models for shared access using NNsight, so you don’t have to worry about any of the headaches of hosting large models yourself!

The interaction between NDIF and NNsight is fairly straightforward. The intervention graph we create via the tracing context can be encoded into a custom JSON format and sent via an HTTP request to the NDIF servers. NDIF then decodes the intervention graph and interleaves it alongside the specified model.

To see which models are currently being hosted, check out the following status page: https://nnsight.net/status/

Remote execution#

NDIF currently requires an API key, so you will need one of your own to run the rest of this walkthrough. To get one, simply register at https://login.ndif.us.

With a valid API key, you then can configure nnsight as follows:

[100]:
from nnsight import CONFIG

CONFIG.set_default_api_key("YOUR_API_KEY")

If you’re running in a local IDE, this only needs to be run once, as it saves the API key as the default in a .config file alongside your nnsight installation. You can also add your API key to Google Colab secrets.
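For example, on Google Colab you could store the key as a notebook secret (here assumed to be named NDIF_API_KEY, an arbitrary choice) and load it like this:

from google.colab import userdata
from nnsight import CONFIG

# Assumes a Colab secret named "NDIF_API_KEY" exists and that this notebook has been granted access to it.
CONFIG.set_default_api_key(userdata.get("NDIF_API_KEY"))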

To amp things up a few levels, let’s demonstrate using nnsight’s tracing context with Llama-3.1-8b!

[101]:
import os

# Llama 3.1 8b is a gated model, so you need to apply for access on HuggingFace and include your token.
os.environ['HF_TOKEN'] = "YOUR_HUGGING_FACE_TOKEN"
[8]:
from nnsight import LanguageModel

# We'll never actually load the parameters locally, so no need to specify a device_map.
llama = LanguageModel("meta-llama/Meta-Llama-3.1-8B")
# All we need to specify using NDIF vs executing locally is remote=True.
with llama.trace("The Eiffel Tower is in the city of", remote=True) as runner:

    hidden_states = llama.model.layers[-1].output.save()

    output = llama.output.save()

print(hidden_states)

print(output["logits"])
2025-01-31 15:47:57,111 69e5f67e-d0b4-4014-ba43-07c4d2b952bb - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:47:57,317 69e5f67e-d0b4-4014-ba43-07c4d2b952bb - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:47:57,519 69e5f67e-d0b4-4014-ba43-07c4d2b952bb - RUNNING: Your job has started running.
2025-01-31 15:47:58,468 69e5f67e-d0b4-4014-ba43-07c4d2b952bb - COMPLETED: Your job has been completed.
Downloading result:   0%|          | 0.00/4.37M [00:00<?, ?B/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Downloading result: 100%|██████████| 4.37M/4.37M [00:00<00:00, 16.5MB/s]
(tensor([[[ 1.7734,  2.6875,  0.8047,  ..., -1.8594,  2.2344,  3.2656],
         [ 0.0312, -0.0352, -2.8750,  ..., -0.8906, -0.0547,  1.6172],
         [ 1.3594, -2.0156,  1.7031,  ..., -1.7031, -0.7422,  1.4375],
         ...,
         [ 1.0000,  0.3203, -0.2656,  ..., -0.0723, -0.2559,  0.2090],
         [ 0.4707, -0.3496,  0.2422,  ...,  0.7344, -0.0078,  0.1133],
         [-0.0566, -0.3496,  0.4746,  ...,  0.9844,  0.6797, -0.8750]]],
       dtype=torch.bfloat16),)
tensor([[[ 6.3438,  8.3750, 12.8125,  ..., -4.3750, -4.3750, -4.3750],
         [-2.4375, -1.7266, -2.0156,  ..., -9.1250, -9.1250, -9.1250],
         [ 9.6875,  4.5625,  5.8750,  ..., -3.3906, -3.3906, -3.3906],
         ...,
         [ 2.3281,  1.0703, -0.3203,  ..., -7.1562, -7.1562, -7.1562],
         [11.1875,  6.0312,  4.9062,  ..., -3.5156, -3.5156, -3.5156],
         [ 8.0000,  5.2500,  4.3750,  ..., -3.9844, -3.9844, -3.9844]]],
       dtype=torch.bfloat16)

It really is as simple as remote=True. All of the techniques we went through in earlier sections work just the same when running locally or remotely.
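For example, the multi-token .all() pattern from earlier in this walkthrough works unchanged against the remote model. Here is a minimal sketch (assuming, as before, that appending llama.lm_head.output under a model-level .all() captures one entry per generated token):

prompt = 'The Eiffel Tower is in the city of'
n_new_tokens = 3
with llama.generate(prompt, max_new_tokens=n_new_tokens, remote=True) as tracer:
    hidden_states = nnsight.list().save() # Initialize & .save() nnsight list

    # Call .all() on the model so the intervention body runs for every new token
    llama.all()

    # Append the logits for each generated token
    hidden_states.append(llama.lm_head.output)

print("Hidden state length: ", len(hidden_states)) # expect 3, one entry per generated token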

Sessions#

NDIF uses a queue to handle concurrent requests from multiple users. To optimize the execution of our experiments we can use the session context to efficiently package multiple interventions together as one single request to the server.

This offers the following benefits:

  1. All interventions within a session will be executed one after another without additional wait in the NDIF queue

  2. All intermediate outputs for each intervention are stored on the server and can be accessed by other interventions in the same session without moving the data back and forth between NDIF and the local machine

Let’s take a look:

[9]:
with llama.session(remote=True) as session:

  with llama.trace("The Eiffel Tower is in the city of") as t1:
    # capture the hidden state from layer 32 at the last token
    hs_31 = llama.model.layers[31].output[0][:, -1, :] # no .save()
    t1_tokens_out = llama.lm_head.output.argmax(dim=-1).save()

  with llama.trace("Buckingham Palace is in the city of") as t2:
    llama.model.layers[1].output[0][:, -1, :] = hs_31[:]
    t2_tokens_out = llama.lm_head.output.argmax(dim=-1).save()

print("\nT1 - Original Prediction: ", llama.tokenizer.decode(t1_tokens_out[0][-1]))
print("T2 - Modified Prediction: ", llama.tokenizer.decode(t2_tokens_out[0][-1]))
2025-01-31 15:48:17,378 d82e4471-0bb0-4ecc-8bb3-6e8fe35f0c03 - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:48:17,558 d82e4471-0bb0-4ecc-8bb3-6e8fe35f0c03 - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:48:17,755 d82e4471-0bb0-4ecc-8bb3-6e8fe35f0c03 - RUNNING: Your job has started running.
2025-01-31 15:48:18,205 d82e4471-0bb0-4ecc-8bb3-6e8fe35f0c03 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.62k/1.62k [00:00<00:00, 6.57MB/s]

T1 - Original Prediction:   Paris
T2 - Modified Prediction:  ://

In the example above, we are interested in replacing the hidden state of a later layer with an earlier one. Since we are using a session, we don’t have to save the hidden state from Tracer 1 to reference it in Tracer 2.

It is important to note that all the traces defined within the session context are executed sequentially, strictly following the order of definition (i.e. t2 is executed after t1, t3 after t2, and so on).

The session context object has its own log method, and execution can be terminated early with nnsight.stop():

[104]:
with llama.session(remote=True) as session:
  session.log("-- Early Stop --")
  nnsight.stop()
2025-01-31 15:17:51,636 3d5144b4-e0b6-4106-912f-5abec8e78e7f - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:17:51,804 3d5144b4-e0b6-4106-912f-5abec8e78e7f - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:17:52,023 3d5144b4-e0b6-4106-912f-5abec8e78e7f - RUNNING: Your job has started running.
2025-01-31 15:17:52,028 3d5144b4-e0b6-4106-912f-5abec8e78e7f - LOG: -- Early Stop --
2025-01-31 15:17:52,364 3d5144b4-e0b6-4106-912f-5abec8e78e7f - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 928/928 [00:00<00:00, 7.25MB/s]

In addition to the benefits mentioned above, the session context also enables experiments that are not possible with other nnsight tools. Since every trace is run on its own model, within one session we can run interventions between different models: for example, we could swap activations between the base and instruct versions of the Llama model and compare their outputs. The session context can also be used to run similar experiments entirely locally!
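As a hedged sketch of that cross-model idea (assuming the instruct checkpoint "meta-llama/Meta-Llama-3.1-8B-Instruct" is also available through NDIF and can be traced inside the same session; the layer index and prompts are arbitrary), activation swapping between the two variants could look like this:

from nnsight import LanguageModel

# Hypothetical second model; we assume NDIF hosts it as well.
llama_instruct = LanguageModel("meta-llama/Meta-Llama-3.1-8B-Instruct")

with llama.session(remote=True) as session:

  # Grab a mid-layer hidden state from the base model...
  with llama.trace("The Eiffel Tower is in the city of") as t1:
    base_hs = llama.model.layers[15].output[0][:, -1, :] # no .save() needed within a session

  # ...and patch it into the instruct model at the same layer.
  with llama_instruct.trace("The Eiffel Tower is in the city of") as t2:
    llama_instruct.model.layers[15].output[0][:, -1, :] = base_hs
    instruct_tokens = llama_instruct.lm_head.output.argmax(dim=-1).save()

print("Instruct prediction with the base activation patched in: ", llama_instruct.tokenizer.decode(instruct_tokens[0][-1]))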

Streaming#

Streaming enables users to apply functions and datasets locally during remote model execution. This allows users to stream results for immediate consumption (e.g., seeing tokens as they are generated) or to apply non-whitelisted functions such as model tokenizers, large local datasets, and more!

  • nnsight.local() context sends values immediately to user’s local machine from server

  • Intervention graph is executed locally on downstream nodes

  • Exiting local context uploads data back to server

  • @nnsight.trace function decorator enables custom functions to be added to intervention graph when using nnsight.local()

nnsight.local()#

You may sometimes want to locally access and manipulate values during remote execution. Using .local() on a proxy, you can send remote content to your local machine and apply local functions. The intervention graph is then executed locally on downstream nodes (until you send execution back to the remote server by exiting the .local() context).

There are a few use cases for streaming with .local(), including live chat generation and applying large datasets or non-whitelisted local functions to the intervention graph.

Now let’s explore how streaming works. We’ll start by grabbing some hidden states of the model and printing their value using tracer.log(). Without calling nnsight.local(), these operations will all occur remotely.

[120]:
# This will give you a remote LOG response because it's coming from the remote server
with llama.trace("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]

    tracer.log(hs[0,0,0])

    out =  llama.lm_head.output.save()

print(out)
2025-01-31 15:21:29,379 94c37060-d3f6-4f21-b55f-0e72f03d08da - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:21:29,592 94c37060-d3f6-4f21-b55f-0e72f03d08da - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:21:29,786 94c37060-d3f6-4f21-b55f-0e72f03d08da - RUNNING: Your job has started running.
2025-01-31 15:21:29,822 94c37060-d3f6-4f21-b55f-0e72f03d08da - LOG: tensor(1.7656, device='cuda:0')
2025-01-31 15:21:30,161 94c37060-d3f6-4f21-b55f-0e72f03d08da - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 514k/514k [00:00<00:00, 4.80MB/s]
tensor([[[ 6.3438,  8.3750, 12.8125,  ..., -4.3750, -4.3750, -4.3750],
         [10.2500,  2.1094,  2.8281,  ..., -8.2500, -8.2500, -8.2500]]],
       dtype=torch.bfloat16)

Now, let’s try the same operation using the nnsight.local() context. This will send the operations to get and print the hidden states to your local machine, changing how the logging message is formatted (local formatting instead of remote).

[121]:
# This will print locally because it's already local
with llama.trace("hello", remote=True) as tracer:

    with nnsight.local():
        hs = llama.model.layers[-1].output[0]
        tracer.log(hs[0,0,0])

    out =  llama.lm_head.output.save()

print(out)
2025-01-31 15:21:32,674 0414b4df-a160-485b-89dd-4e70a05a43e2 - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:21:32,830 0414b4df-a160-485b-89dd-4e70a05a43e2 - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:21:33,092 0414b4df-a160-485b-89dd-4e70a05a43e2 - RUNNING: Your job has started running.
2025-01-31 15:21:33,352 0414b4df-a160-485b-89dd-4e70a05a43e2 - COMPLETED: Your job has been completed.
tensor(1.7656, dtype=torch.bfloat16)
Downloading result: 100%|██████████| 514k/514k [00:00<00:00, 4.28MB/s]
tensor([[[ 6.3438,  8.3750, 12.8125,  ..., -4.3750, -4.3750, -4.3750],
         [10.2500,  2.1094,  2.8281,  ..., -8.2500, -8.2500, -8.2500]]],
       dtype=torch.bfloat16)

@nnsight.trace function decorator#

We can also use function decorators to create custom functions that can be used inside nnsight.local() calls. This is a handy way to enable live streaming of a chat or to train probing classifiers on model hidden states.

Let’s try out @nnsight.trace and nnsight.local() to access a custom function during remote execution.

[122]:
# first, let's define our function
@nnsight.trace # decorator that enables this function to be added to the intervention graph
def my_local_fn(value):
    return value * 0

# We use a local function to ablate some hidden states
# This downloads the data for the .local context, and then uploads it back to set the value.
with llama.generate("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]

    with nnsight.local():

        hs = my_local_fn(hs)

    llama.model.layers[-1].output[0][:] = hs

    out =  llama.lm_head.output.save()
2025-01-31 15:21:35,853 3eeee112-3074-4b3e-adea-001a85ce554d - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:21:36,035 3eeee112-3074-4b3e-adea-001a85ce554d - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:21:36,258 3eeee112-3074-4b3e-adea-001a85ce554d - RUNNING: Your job has started running.
2025-01-31 15:21:36,653 3eeee112-3074-4b3e-adea-001a85ce554d - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 258k/258k [00:00<00:00, 2.91MB/s]

Note that without calling .local, the remote API does not know about my_local_fn and will throw a whitelist error. A whitelist error occurs because the function is not on the list of functions approved for remote execution.

[123]:
with llama.trace("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]

    hs = my_local_fn(hs) # no .local - will cause an error

    llama.model.layers[-1].output[0][:] = hs * 2

    out =  llama.lm_head.output.save()

print(out)
---------------------------------------------------------------------------
FunctionWhitelistError                    Traceback (most recent call last)
Cell In[123], line 1
----> 1 with llama.trace("hello", remote=True) as tracer:
      3     hs = llama.model.layers[-1].output[0]
      5     hs = my_local_fn(hs) # no .local - will cause an error

File ~/Documents/Projects/nnsight/src/nnsight/intervention/contexts/interleaving.py:96, in InterleavingTracer.__exit__(self, exc_type, exc_val, exc_tb)
     92     self.invoker.__exit__(None, None, None)
     94 self._model._envoy._reset()
---> 96 super().__exit__(exc_type, exc_val, exc_tb)

File ~/Documents/Projects/nnsight/src/nnsight/tracing/contexts/tracer.py:25, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
     21 from .globals import GlobalTracingContext
     23 GlobalTracingContext.try_deregister(self)
---> 25 return super().__exit__(exc_type, exc_val, exc_tb)

File ~/Documents/Projects/nnsight/src/nnsight/tracing/contexts/base.py:82, in Context.__exit__(self, exc_type, exc_val, exc_tb)
     78 graph = graph.stack.pop()
     80 graph.alive = False
---> 82 self.backend(graph)

File ~/Documents/Projects/nnsight/src/nnsight/intervention/backends/remote.py:77, in RemoteBackend.__call__(self, graph)
     72 def __call__(self, graph: Graph):
     74     if self.blocking:
     75
     76         # Do blocking request.
---> 77         result = self.blocking_request(graph)
     79     else:
     80
     81         # Otherwise we are getting the status / result of the existing job.
     82         result = self.non_blocking_request(graph)

File ~/Documents/Projects/nnsight/src/nnsight/intervention/backends/remote.py:289, in RemoteBackend.blocking_request(self, graph)
    280 sio.connect(
    281     self.ws_address,
    282     socketio_path="/ws/socket.io",
    283     transports=["websocket"],
    284     wait_timeout=10,
    285 )
    287 remote_graph = preprocess(graph)
--> 289 data, headers = self.request(remote_graph)
    291 headers["session_id"] = sio.sid
    293 # Submit request via

File ~/Documents/Projects/nnsight/src/nnsight/intervention/backends/remote.py:60, in RemoteBackend.request(self, graph)
     58 def request(self, graph: Graph) -> Tuple[bytes, Dict[str, str]]:
---> 60     data = RequestModel.serialize(graph, self.format, self.zlib)
     62     headers = {
     63         "model_key": self.model_key,
     64         "format": self.format,
   (...)
     67         "sent-timestamp": str(time.time()),
     68     }
     70     return data, headers

File ~/Documents/Projects/nnsight/src/nnsight/schema/request.py:43, in RequestModel.serialize(graph, format, _zlib)
     38 @staticmethod
     39 def serialize(graph: Graph, format:str, _zlib:bool) -> bytes:
     41     if format == "json":
---> 43         data = RequestModel(graph=graph)
     45         json = data.model_dump(mode="json")
     47         data = msgspec.json.encode(json)

File ~/Documents/Projects/nnsight/src/nnsight/schema/request.py:30, in RequestModel.__init__(self, memo, *args, **kwargs)
     28 def __init__(self, *args, memo: Dict = None, **kwargs):
---> 30     super().__init__(*args, memo=memo or dict(), **kwargs)
     32     if memo is None:
     34         self.memo = {**MEMO}

    [... skipping hidden 1 frame]

File ~/Documents/Projects/nnsight/src/nnsight/schema/format/types.py:276, in GraphModel.to_model(value)
    273 @staticmethod
    274 def to_model(value: Graph) -> Self:
--> 276     return GraphModel(graph=value, nodes=value.nodes)

    [... skipping hidden 1 frame]

File ~/Documents/Projects/nnsight/src/nnsight/schema/format/types.py:77, in memoized.<locals>.inner(value)
     75 def inner(value):
---> 77     model = fn(value)
     79     _id = id(value)
     81     MEMO[_id] = model

File ~/Documents/Projects/nnsight/src/nnsight/schema/format/types.py:101, in NodeModel.to_model(value)
     97 @staticmethod
     98 @memoized
     99 def to_model(value: Node) -> Self:
--> 101     return NodeModel(target=value.target, args=value.args, kwargs=value.kwargs)

    [... skipping hidden 1 frame]

File ~/Documents/Projects/nnsight/src/nnsight/schema/format/types.py:244, in FunctionModel.to_model(value)
    239 @staticmethod
    240 def to_model(value:FUNCTION):
    242     model = FunctionModel(function_name=get_function_name(value))
--> 244     FunctionModel.check_function_whitelist(model.function_name)
    246     return model

File ~/Documents/Projects/nnsight/src/nnsight/schema/format/types.py:251, in FunctionModel.check_function_whitelist(cls, qualname)
    248 @classmethod
    249 def check_function_whitelist(cls, qualname: str) -> str:
    250     if qualname not in FUNCTIONS_WHITELIST:
--> 251         raise FunctionWhitelistError(
    252             f"Function with name `{qualname}` not in function whitelist."
    253         )
    255     return qualname

FunctionWhitelistError: Function with name `__main__.my_local_fn` not in function whitelist.
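
As mentioned above, another use case for streaming is probing. Below is a minimal, purely illustrative sketch of that idea: it saves last-token hidden states from a couple of remote traces and then fits a tiny linear probe on them locally with plain PyTorch. The prompts, labels, and probe are all hypothetical; a fully streaming variant would instead wrap the training step in a @nnsight.trace-decorated function and call it inside nnsight.local(), as in the examples above.

import torch

# Hypothetical toy "dataset": two prompts with made-up binary labels.
prompts = ["The Eiffel Tower is in the city of", "The Colosseum is in the city of"]
labels = torch.tensor([0, 1])

# Collect last-token hidden states from the final layer, remotely.
features = []
for prompt in prompts:
    with llama.trace(prompt, remote=True):
        features.append(llama.model.layers[-1].output[0][:, -1, :].save())

# Stack the saved activations into a float32 matrix of shape (n_prompts, hidden_dim).
X = torch.cat([f.value.float() for f in features])

# Fit a tiny linear probe locally.
probe = torch.nn.Linear(X.shape[-1], 2)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
for _ in range(100):
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(probe(X), labels)
    loss.backward()
    optimizer.step()

print("Probe training loss:", loss.item())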

Example: Live-streaming remote chat#

Now that we can access data within the tracing context on our local computer, we can apply non-whitelisted functions, such as the model’s tokenizer, during remote execution.

Let’s build a decoding function that will decode tokens into words and print the result.

[124]:
@nnsight.trace
def my_decoding_function(tokens, model, max_length=80, state=None):
    # Initialize state if not provided
    if state is None:
        state = {'current_line': '', 'current_line_length': 0}

    token = tokens[-1] # only use last token

    # Decode the token
    decoded_token = llama.tokenizer.decode(token).encode("unicode_escape").decode()

    if decoded_token == '\\n':  # Handle explicit newline tokens
        # Print the current line and reset state
        print('',flush=True)
        state['current_line'] = ''
        state['current_line_length'] = 0
    else:
        # Check if adding the token would exceed the max length
        if state['current_line_length'] + len(decoded_token) > max_length:
            print('',flush=True)
            state['current_line'] = decoded_token  # Start a new line with the current token
            state['current_line_length'] = len(decoded_token)
            print(state['current_line'], flush=True, end="")  # Print the current line
        else:
            # Append the token to the current line
            if state['current_line']:
                state['current_line'] += decoded_token
            else:
                state['current_line'] = decoded_token
            state['current_line_length'] += len(decoded_token)
            print(state['current_line'], flush=True, end="")  # Print the current line

    return state

Now we can decode and print our model outputs throughout token generation by accessing our decoding function through nnsight.local().

[125]:
import torch

nnsight.CONFIG.APP.REMOTE_LOGGING = False

prompt = "A press release is an official statement delivered to members of the news media for the purpose of"
# prompt = "Your favorite board game is"

print("Prompt: ",prompt,'\n', end ="")

# Initialize the state for decoding
state = {'current_line': '', 'current_line_length': 0}

with llama.generate(prompt, remote=True, max_new_tokens = 50) as generator:
    # Call .all() to apply to each new token
    llama.all()

    all_tokens = nnsight.list().save()

    # Access model output
    out = llama.lm_head.output.save()

    # Apply softmax to obtain probabilities and save the result
    probs = torch.nn.functional.softmax(out, dim=-1)
    max_probs = torch.max(probs, dim=-1)
    tokens = max_probs.indices.cpu().tolist()
    all_tokens.append(tokens[0]) # no need to call .save, the list is already saved

    with nnsight.local():
        state = my_decoding_function(tokens[0], llama, max_length=20, state=state)
Prompt:  A press release is an official statement delivered to members of the news media for the purpose of
2025-01-31 15:21:48,944 94a118bf-3d5d-4ee8-8cf6-2778e95b745f - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:21:49,121 94a118bf-3d5d-4ee8-8cf6-2778e95b745f - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:21:49,363 94a118bf-3d5d-4ee8-8cf6-2778e95b745f - RUNNING: Your job has started running.
 providing information, an official statement, or making an announcement.A press release is also written or recorded communication directed at members of the news media for the purpose of announcing something ostensibly newsworthy. Typically, they are mailed, faxed, or e
2025-01-31 15:21:52,277 94a118bf-3d5d-4ee8-8cf6-2778e95b745f - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 258k/258k [00:00<00:00, 3.03MB/s]

Looping across sessions#

We mentioned earlier that the session context enables multi-tracing execution. But how can we optimize a process that requires running an intervention graph in a loop? If we create a simple for loop with a Tracer context inside, a new intervention graph is created at each iteration, which is not scalable.

We solve this problem the nnsight way via the Iterator context: an intervention loop that iteratively executes and updates a single intervention graph.

Use a session to define the Iterator context and pass in a sequence of items that you want to loop over at each iteration:

[126]:
with llama.session(remote=True) as session:

  with session.iter([0, 1, 2]) as item:
    # define intervention body here ...

    with llama.trace("_"):
      # define interventions here ...
      pass

    with llama.trace("_"):
      # define interventions here ...
      pass
2025-01-31 15:22:03,714 de3f072d-7ec7-4975-8ebe-4307ba8c7258 - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:22:03,869 de3f072d-7ec7-4975-8ebe-4307ba8c7258 - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:22:04,120 de3f072d-7ec7-4975-8ebe-4307ba8c7258 - RUNNING: Your job has started running.
2025-01-31 15:22:05,086 de3f072d-7ec7-4975-8ebe-4307ba8c7258 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 928/928 [00:00<00:00, 1.82MB/s]

The Iterator context extends all the nnsight graph-based functionalities, but also closely mimics the conventional for loop statement in Python, which allows it to support all kinds of iterative operations using the as item syntax:

[127]:
with llama.session(remote=True) as session:

  li = nnsight.list()
  [li.append([num]) for num in range(0, 3)] # adding [0], [1], [2] to the list
  li2 = nnsight.list().save()

  # You can create nested Iterator contexts
  with session.iter(li) as item:
    with session.iter(item) as item_2:
      li2.append(item_2)

print("\nList: ", li2)
2025-01-31 15:22:07,235 570729dd-1d6d-4cd6-9876-9a096378119c - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:22:07,440 570729dd-1d6d-4cd6-9876-9a096378119c - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:22:07,674 570729dd-1d6d-4cd6-9876-9a096378119c - RUNNING: Your job has started running.
2025-01-31 15:22:10,270 570729dd-1d6d-4cd6-9876-9a096378119c - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 928/928 [00:00<00:00, 1.47MB/s]

List:  [0, 1, 2]

Notice how we used the nnsight.list() method to create a list of lists to loop over. This type of method is what we call an NNsight Built-in: a special type of method that serves as a wrapper around nnsight.apply() to provide a more user-friendly interface for adding common datatypes to the Intervention Graph.

A full list of NNsight Built-ins

nnsight.bool() creates a traceable Boolean

nnsight.bytes() creates a traceable Bytes

nnsight.int() creates a traceable Integer

nnsight.float() creates a traceable Float

nnsight.str() creates a traceable String

nnsight.complex() creates a traceable Complex number

nnsight.bytearray() creates a traceable Bytearray

nnsight.tuple() creates a traceable Tuple

nnsight.list() creates a traceable List

nnsight.set() creates a traceable Set

nnsight.dict() creates a traceable Dictionary

We can also expose the iterator context object via a return_context flag, or simply use nnsight.log, nnsight.cond, and nnsight.stop to log intermediate outputs within the loop and exit it early:

[128]:
with llama.session(remote=True) as session:

  # with session.iter([0, 1, 2, 3], return_context=True) as (item, iterator):
  with session.iter([0, 1, 2, 3]) as item:

      nnsight.log(item)

      with nnsight.cond(item == 2):
        nnsight.stop()
2025-01-31 15:22:12,730 cfa97488-c1e7-4ca7-b47d-dd856905edc9 - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:22:12,900 cfa97488-c1e7-4ca7-b47d-dd856905edc9 - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:22:13,237 cfa97488-c1e7-4ca7-b47d-dd856905edc9 - RUNNING: Your job has started running.
2025-01-31 15:22:13,240 cfa97488-c1e7-4ca7-b47d-dd856905edc9 - LOG: 0
2025-01-31 15:22:13,242 cfa97488-c1e7-4ca7-b47d-dd856905edc9 - LOG: 2
2025-01-31 15:22:13,242 cfa97488-c1e7-4ca7-b47d-dd856905edc9 - LOG: 1
2025-01-31 15:22:13,759 cfa97488-c1e7-4ca7-b47d-dd856905edc9 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 928/928 [00:00<00:00, 4.74MB/s]

The Iterator context is a nice piece of functionality that allows a whole range of basic code operations to become “traceable” by nnsight.

But in what kind of experimental scenario would someone need iterators?

In the next section, we delve into a powerful use case of the Iterator context and see what it enables!

Training a LoRA#

Here is an example of a task that uses everything we have covered in the last section: remote execution, the Session context, and iterative interventions. Using the session and iterator contexts, we’re going to apply a very simple fine-tuning approach called low-rank adaptation (LoRA).

Let’s try training a LoRA that, when applied, makes our model always predict “Paris” no matter what.

[10]:
import torch
import torch.nn as nn
import nnsight
# from nnsight.envoy import Envoy # this moved in 0.4
from nnsight import Envoy

# We will define a LORA class.
# The LORA class call method operations are simply traced like you would normally do in a .trace.
class LORA(nn.Module):
    def __init__(self, module: Envoy, dim: int, r: int) -> None:
        """Init.

        Args:
            module (Envoy): Which model Module we are adding the LORA to.
            dim (int): Dimension of the layer we are adding to (This could potentially be auto populated if the user scanned first so we know the shape)
            r (int): Inner dimension of the LORA
        """
        super(LORA, self).__init__()
        self.r = r
        self.module = module
        self.WA = torch.nn.Parameter(torch.randn(dim, self.r), requires_grad=True).save()
        self.WB = torch.nn.Parameter(torch.zeros(self.r, dim), requires_grad=True).save()

    # The Call method defines how to actually apply the LORA.
    def __call__(self, alpha: float = 1.0):
        """Call.

        Args:
            alpha (float, optional): How much to apply the LORA. Can be altered after training for inference. Defaults to 1.0.
        """

        # We apply WA to the first positional arg (the hidden states)
        A_x = torch.matmul(self.module.input[0][0], self.WA)
        BA_x = torch.matmul(A_x, self.WB)

        # LORA is additive
        h = BA_x + self.module.output

        # Replace the output with our new one * alpha
        # Could also have been self.module.output[:] = h * alpha, for in-place
        self.module.output = h * alpha

    def parameters(self):
        # Some way to get all the parameters.
        return [self.WA, self.WB]

Let’s define all the variables to use in LoRA training.

[11]:
# We need the token id of the correct answer.
answer = " Paris"
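# The tokenizer prepends a BOS token (<|begin_of_text|>), so index 1 below is the id of " Paris" itself.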
answer_token = llama.tokenizer.encode(answer)[1]
# Inner LORA dimension
lora_dim = 4
# Module to train LORA on
module = llama.model.layers[-1].mlp

We can use the .scan() method to get the shape of the module without having to fully run the model.

[12]:
with llama.scan(" "):
    dim = module.output.shape[-1]

print(dim)
4096

It’s time to run the LoRA training loop! We use the Session and Iterator contexts to achieve this.

[13]:
from torch.utils.data import DataLoader

# The LORA object itself isn't transmitted to the server; only its forward / call method is traced.
# The parameters are created remotely and never sent, only retrieved.
with llama.session(remote=True) as session:

    # Create dataset of 100 pairs of a blank prompt and the " Paris " id
    dataset = [["_", answer_token]] * 100

    # Create a dataloader from it.
    dataloader = DataLoader(dataset, batch_size=10)

    # Create our LORA on the last mlp
    lora = LORA(module, dim, lora_dim)

    # Create an optimizer. Use the parameters from LORA
    optimizer = torch.optim.AdamW(lora.parameters(), lr=3)

    # Iterate over dataloader using .iter.
    with session.iter(dataloader) as batch:

        prompt = batch[0]
        correct_token = batch[1]

        # Run .trace with prompt
        with llama.trace(prompt) as tracer:

            # Apply the LORA to the intervention graph just by calling it within .trace
            lora()

            # Get logits
            logits = llama.lm_head.output

            # Do cross entropy on last predicted token and correct_token
            loss = torch.nn.functional.cross_entropy(logits[:, -1], correct_token)
            # Call backward
            loss.backward()

        # Call methods on the optimizer. Graphs that aren't from .trace (in this case both the session and the iterator have their own graph) are executed sequentially.
        # The Graph of Iterator here will be:
        # 1.) Index batch at 0 for prompt
        # 2.) Index batch at 1 for correct_token
        # 3.) Execute the .trace using the prompt
        # 4.) Call .step() on optimizer
        optimizer.step()
        # 5.) Call .zero_grad() in optimizer
        optimizer.zero_grad()
        # 6.) Print out the lora WA weights to show they are indeed changing
        nnsight.log(lora.WA)

2025-01-31 15:48:38,456 810d88f8-f1fc-43fd-b16d-77e11841eecd - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:48:38,634 810d88f8-f1fc-43fd-b16d-77e11841eecd - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:48:38,859 810d88f8-f1fc-43fd-b16d-77e11841eecd - RUNNING: Your job has started running.
2025-01-31 15:48:39,402 810d88f8-f1fc-43fd-b16d-77e11841eecd - LOG: Parameter containing:
tensor([[ 0.7539,  0.0732, -0.4023, -0.0615],
        [-0.5312,  0.5391,  0.5195, -1.1328],
        [-1.5781, -0.2480,  0.6953,  0.3535],
        ...,
        [ 0.1670,  0.5391, -0.5703, -0.6289],
        [-0.0762, -1.3438,  0.8320,  1.2656],
        [ 0.2500,  1.2188,  0.2891, -1.2578]], requires_grad=True)
2025-01-31 15:48:39,410 810d88f8-f1fc-43fd-b16d-77e11841eecd - LOG: Parameter containing:
tensor([[ 0.7305,  0.0713, -0.3906, -0.0598],
        [-0.5156,  0.5234,  0.5039, -1.1016],
        [-1.5312, -0.2402,  0.6758,  0.3438],
        ...,
        [ 0.1621,  0.5234, -0.5547, -0.6094],
        [-0.0737, -1.3047,  0.8086,  1.2266],
        [ 0.2422,  1.1797,  0.2812, -1.2188]], requires_grad=True)
2025-01-31 15:48:39,448 810d88f8-f1fc-43fd-b16d-77e11841eecd - LOG: Parameter containing:
tensor([[ 0.7070,  0.0688, -0.3789, -0.0586],
        [-0.5000,  0.5078,  0.4883, -1.0703],
        [-1.4844, -0.2334,  0.6562,  0.3340],
        ...,
        [ 0.1572,  0.5078, -0.5391, -0.5898],
        [-0.0718, -1.2656,  0.7852,  1.1875],
        [ 0.2344,  1.1406,  0.2734, -1.1797]], requires_grad=True)
2025-01-31 15:48:39,485 810d88f8-f1fc-43fd-b16d-77e11841eecd - LOG: Parameter containing:
tensor([[ 0.6836,  0.0645, -0.3633, -0.0598],
        [-0.4883,  0.4902,  0.4746, -1.0391],
        [-1.4375, -0.2285,  0.6406,  0.3203],
        ...,
        [ 0.1543,  0.4941, -0.5234, -0.5703],
        [-0.0718, -1.2266,  0.7617,  1.1484],
        [ 0.2227,  1.1094,  0.2695, -1.1484]], requires_grad=True)
2025-01-31 15:48:39,521 810d88f8-f1fc-43fd-b16d-77e11841eecd - LOG: Parameter containing:
tensor([[-0.9648, -1.5703,  0.6094, -1.6875],
        [-2.0938, -1.1562,  1.3906, -2.6406],
        [-3.0156, -1.8516,  1.5938, -1.3047],
        ...,
        [ 1.7578,  2.0781, -1.1406,  1.0625],
        [-1.6875, -2.7969,  1.4688, -0.5117],
        [-1.4219, -0.5508,  1.3516, -2.7344]], requires_grad=True)
2025-01-31 15:48:39,558 810d88f8-f1fc-43fd-b16d-77e11841eecd - LOG: Parameter containing:
tensor([[-2.7031e+00, -3.6719e-01, -9.6484e-01, -1.0938e+00],
        [-3.7969e+00,  2.1606e-02, -1.9238e-01, -2.0156e+00],
        [-4.6875e+00, -6.4453e-01,  3.3875e-03, -7.1484e-01],
        ...,
        [ 3.4688e+00,  8.6719e-01,  4.1602e-01,  4.8047e-01],
        [-3.4062e+00, -1.5703e+00, -1.0645e-01,  5.7129e-02],
        [-3.1406e+00,  6.1719e-01, -2.3633e-01, -2.0938e+00]],
       requires_grad=True)
2025-01-31 15:48:39,630 810d88f8-f1fc-43fd-b16d-77e11841eecd - LOG: Parameter containing:
tensor([[-4.7812,  1.3125, -2.5625,  0.5117],
        [-5.8125,  1.6953, -1.8203, -0.3789],
        [-6.6875,  1.0469, -1.6250,  0.8789],
        ...,
        [ 5.5000, -0.8320,  2.0312, -1.1016],
        [-5.4375,  0.1494, -1.7266,  1.6406],
        [-5.1875,  2.2656, -1.8594, -0.4629]], requires_grad=True)
2025-01-31 15:48:39,697 810d88f8-f1fc-43fd-b16d-77e11841eecd - LOG: Parameter containing:
tensor([[-3.2812,  2.9062, -4.0938,  2.1094],
        [-4.3125,  3.2812, -3.3750,  1.2422],
        [-5.1562,  2.6406, -3.1719,  2.4531],
        ...,
        [ 4.0000, -2.4531,  3.5625, -2.6875],
        [-3.9375,  1.7812, -3.2656,  3.2031],
        [-3.6875,  3.8438, -3.4062,  1.1719]], requires_grad=True)
2025-01-31 15:48:39,764 810d88f8-f1fc-43fd-b16d-77e11841eecd - LOG: Parameter containing:
tensor([[-2.0000,  4.2500, -5.3750,  3.4688],
        [-3.0156,  4.6250, -4.6875,  2.6406],
        [-3.8125,  4.0000, -4.5000,  3.7969],
        ...,
        [ 2.6875, -3.8438,  4.8750, -4.0312],
        [-2.6250,  3.1719, -4.5938,  4.5312],
        [-2.3906,  5.1875, -4.7188,  2.5625]], requires_grad=True)
2025-01-31 15:48:39,830 810d88f8-f1fc-43fd-b16d-77e11841eecd - LOG: Parameter containing:
tensor([[-0.8867,  5.4062, -6.4688,  4.6250],
        [-1.8750,  5.7812, -5.8125,  3.8438],
        [-2.6406,  5.1562, -5.6562,  4.9375],
        ...,
        [ 1.5469, -5.0312,  6.0000, -5.1875],
        [-1.4922,  4.3750, -5.7188,  5.6875],
        [-1.2500,  6.3438, -5.8125,  3.7656]], requires_grad=True)
2025-01-31 15:48:40,224 810d88f8-f1fc-43fd-b16d-77e11841eecd - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 67.1k/67.1k [00:00<00:00, 1.65MB/s]

Now WA and WB are optimized! To generate with the LoRA, we simply call lora() inside .generate and save the output so we can de-tokenize it.

[14]:
# With lora. Should produce "Hello Paris"
with llama.generate("Hello", remote=True) as generator:

    lora()

    out = llama.generator.output.save()

print(llama.tokenizer.batch_decode(out.value))

# Then without. Should produce "Hello,"
with llama.generate("Hello", remote=True) as generator:

    out = llama.generator.output.save()

print(llama.tokenizer.batch_decode(out.value))

2025-01-31 15:48:45,111 c9fcc827-5c78-4796-87cf-0f246c75a359 - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:48:45,281 c9fcc827-5c78-4796-87cf-0f246c75a359 - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:48:45,535 c9fcc827-5c78-4796-87cf-0f246c75a359 - RUNNING: Your job has started running.
2025-01-31 15:48:46,013 c9fcc827-5c78-4796-87cf-0f246c75a359 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.24k/1.24k [00:00<00:00, 4.30MB/s]
['<|begin_of_text|>Hello Paris']
2025-01-31 15:48:46,446 9f088820-f00d-492d-9861-c0f0a7c4d784 - RECEIVED: Your job has been received and is waiting approval.
2025-01-31 15:48:46,650 9f088820-f00d-492d-9861-c0f0a7c4d784 - APPROVED: Your job was approved and is waiting to be run.
2025-01-31 15:48:46,862 9f088820-f00d-492d-9861-c0f0a7c4d784 - RUNNING: Your job has started running.
2025-01-31 15:48:47,340 9f088820-f00d-492d-9861-c0f0a7c4d784 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.24k/1.24k [00:00<00:00, 4.61MB/s]
['<|begin_of_text|>Hello,']

Next Steps#

Check out nnsight.net/tutorials for more walkthroughs implementing classic interpretability techniques using nnsight.

Getting Involved!#

Note that both nnsight and NDIF are in active development, so changes may be made and errors may arise during use. If you’re interested in following updates to nnsight, contributing, giving feedback, or finding collaborators, please join the NDIF discord. We’d love to hear about your work using nnsight!

You can also follow us on LinkedIn, Bluesky: @ndif-team.bsky.social, and X: @ndif_team.

💟