Walkthrough
An interactive version of this walkthrough can be found here.
In this era of large-scale deep learning, often the most interesting AI models are massive black boxes that are hard to run. Ordinary commercial inference APIs let us interact with huge models, but they do not let us access model internals.
The nnsight library is different: it provides full access to all neural network internals. When used together with a remote service like the National Deep Inference Fabric (NDIF), it makes it possible to run complex experiments on huge open models easily, with fully transparent access.
Our team wants to enable entire labs and independent researchers alike, as we believe a large, passionate, and collaborative community will produce the next big insights in this profoundly important field.
1️⃣ First, let’s start small
Tracing Context
To demonstrate the core functionality and syntax of nnsight, we’ll define and use a tiny two-layer neural network.
[1]:
# Install nnsight
!pip install nnsight
!pip install --upgrade transformers torch
from IPython.display import clear_output
clear_output()
Our little model here is composed of two submodules, the linear layers ‘layer1’ and ‘layer2’. We specify the size of each of these modules here; we’ll create some example input for the model a little later.
[2]:
from collections import OrderedDict
import torch
input_size = 5
hidden_dims = 10
output_size = 2
net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(input_size, hidden_dims)),
            ("layer2", torch.nn.Linear(hidden_dims, output_size)),
        ]
    )
).requires_grad_(False)
The core object of the nnsight package is NNsight. This wraps around a given PyTorch model to enable the capabilities nnsight provides.
[ ]:
import nnsight
from nnsight import NNsight
tiny_model = NNsight(net)
Printing a PyTorch model shows a named hierarchy of modules, which is very useful when accessing sub-components directly. NNsight reflects the same hierarchy and can be printed the same way.
[4]:
print(tiny_model)
Sequential(
(layer1): Linear(in_features=5, out_features=10, bias=True)
(layer2): Linear(in_features=10, out_features=2, bias=True)
)
Before we actually get to using the model we just created, let’s talk about Python contexts.
Python contexts (context managers) define a scope using the with statement and are often used to create some object, or initiate some logic, that you later want to destroy or conclude.
A common application is opening files, as in the following example:
with open('myfile.txt', 'r') as file:
    text = file.read()
Python uses the with keyword to enter a context-like object. This object defines logic to be run at the start of the with block, as well as logic to be run when exiting. When using with for a file, entering the context opens the file and exiting the context closes it. Being within the context means we can read from the file. Simple enough! Now we can discuss how nnsight uses contexts to enable intuitive access into the internals of a neural network.
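To make the enter/exit ordering concrete, here is a minimal sketch of a class-based context manager. The Recorder class is a hypothetical illustration, not part of nnsight; it simply logs when its entry logic, the body of the with block, and its exit logic run.

```python
class Recorder:
    """A toy context manager that logs when each phase runs."""

    def __init__(self):
        self.events = []

    def __enter__(self):
        # Runs once, at the top of the with block.
        self.events.append("enter")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs once, when the with block finishes (even if the body raised).
        self.events.append("exit")
        # Returning False means any exception is re-raised, not swallowed.
        return False


with Recorder() as rec:
    rec.events.append("body")

print(rec.events)  # ['enter', 'body', 'exit']
```

Because __exit__ also runs when the body raises, a file opened with with is closed even if reading it fails.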
The main tool with nnsight is a context for tracing.
We enter the tracing context by calling model.trace(<input>) on an NNsight model, which defines how we want to run the model. Inside the context, we will be able to customize how the neural network runs. The model is actually run upon exiting the tracing context.
[5]:
# random input
input = torch.rand((1, input_size))
with tiny_model.trace(input) as tracer:
    pass
But where’s the output? To get that, we’ll have to learn how to request it from within the tracing context.
Getting
Earlier, we wrapped our little neural net with the NNsight class. This added a couple of properties to each module in the model (including the root model itself). The two most important ones are .input and .output.
model.input
model.output
The names are self-explanatory. They correspond to the inputs and outputs of their respective modules during a forward pass of the model. We can use these attributes inside the with block.
However, it is important to understand that the model is not executed until the end of the tracing context. How can we access inputs and outputs before the model is run? The trick is deferred execution.
.input and .output are Proxies for the eventual inputs and outputs of a module. In other words, when we access model.output, what we are communicating to nnsight is: “When you compute the output of model, please grab it for me and put the value into its corresponding Proxy object’s .value attribute.” Let’s try it:
[6]:
with tiny_model.trace(input) as tracer:
    output = tiny_model.output
print(output.value)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[6], line 5
1 with tiny_model.trace(input) as tracer:
3 output = tiny_model.output
----> 5 print(output.value)
(...)
ValueError: Accessing value before it's been set.
Oh no, an error! ValueError: Accessing value before it's been set.
Why doesn’t our output have a value?
Proxy objects will only have their value at the end of a context if we call .save() on them. This helps to reduce memory costs, since values we never asked to keep do not need to be stored. Adding .save() fixes the error:
[7]:
with tiny_model.trace(input) as tracer:
    output = tiny_model.output.save()
print(output.value)
tensor([[-0.2687, -0.3314]])
Success! We now have the model output. We just completed our first intervention using nnsight.
Each time we access a module’s input or output, we create an intervention in the neural network’s forward pass. Collectively, these requests form the intervention graph. We call the process of executing it alongside the model’s normal computation graph interleaving.
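The deferral and .save() behaviour can be mimicked in a few lines of plain Python. This is a toy sketch under simplifying assumptions (a single "model" function and a single kind of request); ToyProxy and ToyTracer are hypothetical names, and nnsight's real Proxy and graph machinery is far more general.

```python
class ToyProxy:
    """Placeholder for a value that only exists after the 'model' has run."""

    def __init__(self):
        self.saved = False
        self._done = False
        self._value = None

    def save(self):
        # Mark this request as one whose value should survive the context.
        self.saved = True
        return self

    @property
    def value(self):
        if not self._done:
            raise ValueError("Accessing value before it's been set.")
        return self._value


class ToyTracer:
    """Collects requests inside the with block and runs the model on exit."""

    def __init__(self, model_fn, model_input):
        self.model_fn = model_fn
        self.model_input = model_input
        self.requests = []

    def output(self):
        # Asking for the output only creates a request; nothing runs yet.
        proxy = ToyProxy()
        self.requests.append(proxy)
        return proxy

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            return False  # propagate errors raised inside the block
        result = self.model_fn(self.model_input)  # the "model" runs only here
        for proxy in self.requests:
            if proxy.saved:  # unsaved requests are simply dropped
                proxy._value = result
                proxy._done = True
        return False


def double(x):
    return 2 * x


with ToyTracer(double, 21) as tracer:
    unsaved = tracer.output()
    saved = tracer.output().save()

print(saved.value)  # 42
try:
    unsaved.value
except ValueError as err:
    print(err)  # Accessing value before it's been set.
```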
On model output
If we don’t need to access anything other than the final model output, we can call .trace() with trace=False and use its return value directly instead of entering a context. This could be especially useful for easy remote inference.
output = model.trace(<inputs>, trace=False)
Just like we saved the model’s final output, we can save the output of any of its submodules. We access submodules with normal Python attribute syntax, and we can discover their names by printing out the model:
[8]:
print(tiny_model)
Sequential(
(layer1): Linear(in_features=5, out_features=10, bias=True)
(layer2): Linear(in_features=10, out_features=2, bias=True)
)
Let’s access the output of the first layer (non-coincidentally named ‘layer1’):
[9]:
with tiny_model.trace(input) as tracer:
    l1_output = tiny_model.layer1.output.save()
print(l1_output.value)
tensor([[ 0.2341, 0.3416, -0.8637, -0.5382, -0.3792, -0.1253, 0.4137, 0.5758,
-0.3158, -0.1226]])
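For comparison, the closest plain-PyTorch equivalent of reading layer1's output is a forward hook. The sketch below is not how nnsight is implemented internally; it rebuilds a network with the same layer sizes (plain_net, save_output and captured are illustrative names) and uses torch.nn.Module.register_forward_hook to capture the activation during an ordinary forward pass.

```python
from collections import OrderedDict

import torch

torch.manual_seed(0)  # reproducible weights and input for this sketch
plain_net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(5, 10)),
            ("layer2", torch.nn.Linear(10, 2)),
        ]
    )
).requires_grad_(False)

captured = {}


def save_output(module, inputs, output):
    # Forward hooks receive the module, its positional inputs and its output.
    # Storing a clone keeps the value even if something edits it in place later.
    captured["layer1"] = output.clone()


handle = plain_net.layer1.register_forward_hook(save_output)
x = torch.rand((1, 5))
final = plain_net(x)  # the hook fires while layer1 runs inside this call
handle.remove()       # detach the hook so later forward passes are unaffected

print(captured["layer1"].shape)  # torch.Size([1, 10])
```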
Let’s do the same for the input of layer2. While we’re at it, let’s also drop the as tracer, since we won’t need the tracer object for a few sections:
[10]:
with tiny_model.trace(input):
    l2_input = tiny_model.layer2.input.save()
print(l2_input.value)
tensor([[ 0.2341, 0.3416, -0.8637, -0.5382, -0.3792, -0.1253, 0.4137, 0.5758,
-0.3158, -0.1226]])
On module inputs
Notice how the value for l2_input is just a single tensor. By default, the .input attribute of a module returns the first tensor input to the module.
We can also access the full input to a module through the .inputs attribute, which returns the values in the form:
tuple(tuple(args), dictionary(kwargs))
where the first element of the tuple is itself a tuple of all positional arguments, and the second element is a dictionary of the keyword arguments.
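To see the (args, kwargs) layout in plain PyTorch, a forward pre-hook registered with with_kwargs=True receives a module's positional and keyword arguments separately. This sketch assumes PyTorch 2.0 or later (where that flag exists) and only mirrors the structure described above; it is not nnsight's code, and record_inputs and seen are hypothetical names.

```python
import torch

layer = torch.nn.Linear(5, 10)
seen = {}


def record_inputs(module, args, kwargs):
    # With with_kwargs=True, PyTorch passes positional args (a tuple)
    # and keyword args (a dict) separately; returning None changes nothing.
    seen["inputs"] = (args, kwargs)


handle = layer.register_forward_pre_hook(record_inputs, with_kwargs=True)
x = torch.rand((1, 5))
layer(x)
handle.remove()

args, kwargs = seen["inputs"]
print(type(args).__name__, len(args), kwargs)  # tuple 1 {}
first_tensor = args[0]  # the first positional tensor, like what .input returns
```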
Until now, we have been saving the output of the model and its submodules within the Trace context and printing it after exiting the context. We will continue doing this in the rest of the tutorial, since it’s good practice to save computation results for later analysis.
However, we can also log the outputs of the model and its submodules within the Trace context. This is useful for debugging and understanding the model’s behavior while saving memory. Let’s see how to do this:
[11]:
with tiny_model.trace(input) as tracer:
    tracer.log("Layer 1 - out: ", tiny_model.layer1.output)
Layer 1 - out: tensor([[ 0.2341, 0.3416, -0.8637, -0.5382, -0.3792, -0.1253, 0.4137, 0.5758,
-0.3158, -0.1226]])
Functions, Methods, and Operations
Now that we can access activations, we also want to do some post-processing on them. Let’s find out which dimension of layer1’s output has the highest value.
We could do this by calling torch.argmax(...) after the tracing context, or we can leverage the fact that nnsight handles PyTorch functions and methods within the tracing context by creating a Proxy request for them:
[12]:
with tiny_model.trace(input):
    # Note we don't need to call .save() on the output,
    # as we're only using its value within the tracing context.
    l1_output = tiny_model.layer1.output
    # We do need to save the argmax tensor however,
    # as we're using it outside the tracing context.
    l1_amax = torch.argmax(l1_output, dim=1).save()
print(l1_amax[0])
tensor(7)
Nice! That worked seamlessly, but why didn’t we need to call .value[0] on the result? In previous sections, we were just being explicit to build an understanding of Proxies and their value. In practice, however, nnsight knows that outside of the tracing context we only care about the actual value, so printing, indexing, and applying functions all immediately return and reflect the data in .value. So for the rest of the tutorial we won’t use it.
The same principles work for PyTorch methods and all operators as well:
[13]:
with tiny_model.trace(input):
    value = (tiny_model.layer1.output.sum() + tiny_model.layer2.output.sum()).save()
print(value)
tensor(-1.3797)
The code block above is saying to nnsight, “Run the model with the given input. When the output of layer1 is computed, take its sum. Then do the same for layer2. Now that both of those are computed, add them and make sure not to delete this value, as I wish to use it outside of the tracing context.”
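A toy version of what this paragraph describes: operator overloading lets an object record an expression as graph nodes instead of computing it, and the graph is evaluated once real values exist. The Node class here is hypothetical and handles only sum and +; it illustrates the idea of an intervention graph and does not reproduce nnsight's implementation.

```python
class Node:
    """Hypothetical graph node: records an operation instead of running it."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args

    def __add__(self, other):
        return Node("add", self, other)

    def sum(self):
        return Node("sum", self)

    def evaluate(self, env):
        # Resolve this node once real values are available in `env`.
        if self.op == "leaf":
            return env[self.args[0]]
        values = [arg.evaluate(env) for arg in self.args]
        if self.op == "sum":
            return sum(values[0])  # Python's built-in sum over the list
        if self.op == "add":
            return values[0] + values[1]
        raise ValueError(f"unknown op {self.op}")


layer1_out = Node("leaf", "layer1")
layer2_out = Node("leaf", "layer2")

# Building the expression only records nodes; nothing is computed yet.
expr = layer1_out.sum() + layer2_out.sum()

# Later, "running the model" supplies concrete activations.
result = expr.evaluate({"layer1": [1.0, 2.0, 3.0], "layer2": [0.5, -0.5]})
print(result)  # 6.0
```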
Custom Functions
Everything within the tracing context operates on the intervention graph. Therefore, for nnsight to trace a function, the function must also be part of the intervention graph.
Out of the box, nnsight supports PyTorch functions and methods, all operators, as well as the einops library. We don’t need to do anything special to use them. But what do we do if we want to use custom functions? How do we add them to the intervention graph?
Enter nnsight.apply(). It allows us to add new functions to the intervention graph. Let’s see how it works:
[14]:
# Take a tensor and return the sum of its elements
def tensor_sum(tensor):
    flat = tensor.flatten()
    total = 0
    for element in flat:
        total += element.item()
    return torch.tensor(total)
with tiny_model.trace(input) as tracer:
    # Pass the function followed by its arguments (comma-separated) to add it to the intervention graph
    custom_sum = nnsight.apply(tensor_sum, tiny_model.layer1.output).save()
    # Named torch_sum to avoid shadowing Python's built-in sum
    torch_sum = tiny_model.layer1.output.sum()
    torch_sum.save()
print(custom_sum, torch_sum)
tensor(-0.7796) tensor(-0.7796)
nnsight.apply() executes the function it wraps and returns its output as a Proxy object. We can then use this Proxy object as we would any other.
The applications of nnsight.apply are wide: it can wrap any custom function, or functions from libraries that nnsight does not support out of the box.
Setting
Getting and analyzing the activations from various points in a model can be really insightful, and a number of ML techniques do exactly that. However, we often want not only to view the computation of a model but also to influence it.
To demonstrate the effect of editing the flow of information through the model, let’s set the first dimension of the first layer’s output to 0. NNsight makes this really easy using the ‘=’ operator:
[15]:
with tiny_model.trace(input):
    # Save the output before the edit to compare.
    # Notice we apply .clone() before saving as the setting operation is in-place.
    l1_output_before = tiny_model.layer1.output.clone().save()
    # Access the 0th index of the hidden state dimension and set it to 0.
    tiny_model.layer1.output[:, 0] = 0
    # Save the output after to see our edit.
    l1_output_after = tiny_model.layer1.output.save()
print("Before:", l1_output_before)
print("After:", l1_output_after)
Before: tensor([[ 0.2341, 0.3416, -0.8637, -0.5382, -0.3792, -0.1253, 0.4137, 0.5758,
-0.3158, -0.1226]])
After: tensor([[ 0.0000, 0.3416, -0.8637, -0.5382, -0.3792, -0.1253, 0.4137, 0.5758,
-0.3158, -0.1226]])
It seems our change was reflected. Now let’s do the same for the last dimension:
[16]:
with tiny_model.trace(input):
    # Save the output before the edit to compare.
    # Notice we apply .clone() before saving as the setting operation is in-place.
    l1_output_before = tiny_model.layer1.output.clone().save()
    # Access the last index of the hidden state dimension and set it to 0.
    tiny_model.layer1.output[:, hidden_dims] = 0
    # Save the output after to see our edit.
    l1_output_after = tiny_model.layer1.output.save()
print("Before:", l1_output_before)
print("After:", l1_output_after)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
(...)
IndexError: index 10 is out of bounds for dimension 1 with size 10
The above exception was the direct cause of the following exception:
IndexError Traceback (most recent call last)
Cell In[16], line 1
----> 1 with tiny_model.trace(input):
(...)
IndexError: Above exception when execution Node: 'setitem_0' in Graph: '6063279136'
Oh no, we are getting an error. It looks like it’s happening when we set the output.
How can we find what went wrong? Is there an easy way to debug this?
Enter “Scanning” and “Validating”! We can enable these features by setting the scan=True and validate=True flags in the trace method.
Let’s run this again and see what it can do for us:
[17]:
# turn on scan and validate
with tiny_model.trace(input, scan=True, validate=True):
    l1_output_before = tiny_model.layer1.output.clone().save()
    # the error is happening here
    tiny_model.layer1.output[:, hidden_dims] = 0
    l1_output_after = tiny_model.layer1.output.save()
print("Before:", l1_output_before)
print("After:", l1_output_after)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[17], line 2
1 # turn on scan and validate
----> 2 with tiny_model.trace(input, scan=True, validate=True):
(...)
Cell In[17], line 7
4 l1_output_before = tiny_model.layer1.output.clone().save()
6 # the error is happening here
----> 7 tiny_model.layer1.output[:, hidden_dims] = 0
(...)
IndexError: index 10 is out of bounds for dimension 1 with size 10
Ah, of course: we needed to index at hidden_dims - 1, not hidden_dims.
How was nnsight able to catch this error?
Earlier when discussing contexts in Python, we learned some logic happens upon entering, and some logic happens upon exiting. We know the model is actually run on exit, but what happens on enter?
If scan and validate are enabled, our input is run through the model, but under its own “fake” context. This means the input makes its way through all of the model operations, allowing nnsight to record the shapes and data types of module inputs and outputs! The operations are never executed using tensors with real values, so this does not incur the memory cost of a real forward pass. Then, when creating proxy requests like the setting one above, nnsight also attempts to execute the request on the “fake” values we recorded. Hence, it lets us know whether our request is feasible before even running the model, and the error points directly at the offending line inside the with block.
“Scanning” is what we call running “fake” inputs through the model to collect information like shapes and types. “Validating” is what we call trying to execute the intervention proxies with “fake” inputs to see if they work. “Validating” depends on “Scanning” to work correctly, so we need to scan the model at least once to debug with validate.
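The walkthrough describes scanning as running “fake” tensors through the model. As a rough stand-in, this sketch uses PyTorch's meta device, whose tensors carry shapes and dtypes but no data, to propagate shapes through a copy of the model and to test the indexing edit from above before touching real values. It illustrates the idea only and is not nnsight's implementation; names like meta_net and record_shape are hypothetical.

```python
import copy
from collections import OrderedDict

import torch

plain_net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(5, 10)),
            ("layer2", torch.nn.Linear(10, 2)),
        ]
    )
).requires_grad_(False)

# Data-free copy: parameters keep their shapes and dtypes but hold no values.
meta_net = copy.deepcopy(plain_net).to("meta")
fake_input = torch.empty((1, 5), device="meta")

shapes = {}


def record_shape(name):
    # Build a forward hook that stores the output shape under `name`.
    def hook(module, inputs, output):
        shapes[name] = tuple(output.shape)

    return hook


for name, module in meta_net.named_children():
    module.register_forward_hook(record_shape(name))

meta_net(fake_input)  # "scan": shapes propagate, no real arithmetic happens
print(shapes)  # {'layer1': (1, 10), 'layer2': (1, 2)}

# "Validate": try the proposed edit on a meta tensor of the recorded shape.
hidden = torch.empty(shapes["layer1"], device="meta")
rejected = False
try:
    hidden[:, 10] = 0  # same off-by-one as above: valid indices are 0..9
except IndexError as err:
    rejected = True
    print("rejected:", err)
hidden[:, 10 - 1] = 0  # the corrected index is accepted
```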
A word of caution
Some PyTorch operations and related libraries don’t work well with fake tensors.
If you are doing anything in a loop where efficiency is important, you should keep scanning and validating off. It’s best to use them only when debugging or when you are unsure if your intervention will work.
Let’s try again with the correct indexing, and view the shape of the output before leaving the tracing context:
[18]:
with tiny_model.trace(input):
    # Save the output before the edit to compare.
    # Notice we apply .clone() before saving as the setting operation is in-place.
    l1_output_before = tiny_model.layer1.output.clone().save()
    print(f"Layer 1 output shape: {tiny_model.layer1.output.shape}")
    # Access the last index of the hidden state dimension and set it to 0.
    tiny_model.layer1.output[:, hidden_dims - 1] = 0
    # Save the output after to see our edit.
    l1_output_after = tiny_model.layer1.output.save()
print("Before:", l1_output_before)
print("After:", l1_output_after)
Layer 1 output shape: torch.Size([1, 10])
Before: tensor([[ 0.2341, 0.3416, -0.8637, -0.5382, -0.3792, -0.1253, 0.4137, 0.5758,
-0.3158, -0.1226]])
After: tensor([[ 0.2341, 0.3416, -0.8637, -0.5382, -0.3792, -0.1253, 0.4137, 0.5758,
-0.3158, 0.0000]])
We can also just replace proxy inputs and outputs with tensors of the same shape and type. Let’s use the shape information we have at our disposal to add noise to the output, and replace it with this new noised tensor:
[19]:
with tiny_model.trace(input):
    # Save the output before the edit to compare.
    # Notice we apply .clone() before saving as the setting operation is in-place.
    l1_output_before = tiny_model.layer1.output.clone().save()
    # Create random noise with variance of .001
    noise = (0.001**0.5) * torch.randn(l1_output_before.shape)
    # Add to original value and replace.
    tiny_model.layer1.output = l1_output_before + noise
    # Save the output after to see our edit.
    l1_output_after = tiny_model.layer1.output.save()
print("Before:", l1_output_before)
print("After:", l1_output_after)
Before: tensor([[ 0.2341, 0.3416, -0.8637, -0.5382, -0.3792, -0.1253, 0.4137, 0.5758,
-0.3158, -0.1226]])
After: tensor([[ 0.2283, 0.3262, -0.8443, -0.5498, -0.3424, -0.1178, 0.4780, 0.5522,
-0.2394, -0.1264]])
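For comparison, here are both kinds of edit done with plain PyTorch forward hooks: an in-place assignment (the hook returns None, so the edited output is kept) and a replacement (the hook returns a new tensor). It also shows why the examples above call .clone() before saving: a plain reference to the output sees the later in-place edit, while a clone does not. The helper names are hypothetical, and this is not how nnsight implements setting.

```python
from collections import OrderedDict

import torch

torch.manual_seed(0)
plain_net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(5, 10)),
            ("layer2", torch.nn.Linear(10, 2)),
        ]
    )
).requires_grad_(False)
x = torch.rand((1, 5))
record = {}


def zero_last_dim(module, inputs, output):
    record["before"] = output.clone()  # independent copy taken before the edit
    record["alias"] = output           # same tensor object: it will see the edit
    output[:, -1] = 0                  # in-place edit of the last hidden dimension
    # Returning None tells PyTorch to keep the (now edited) output.


def add_noise(module, inputs, output):
    # Standard deviation sqrt(0.001) gives noise with variance 0.001.
    noise = (0.001 ** 0.5) * torch.randn(output.shape)
    # Returning a tensor replaces the module's output for the rest of the pass.
    return output + noise


handle = plain_net.layer1.register_forward_hook(zero_last_dim)
plain_net(x)
handle.remove()
print(record["alias"][0, -1].item())  # 0.0

handle = plain_net.layer1.register_forward_hook(add_noise)
noisy = plain_net.layer1(x)
handle.remove()
clean = plain_net.layer1(x)
print((noisy - clean).abs().max())  # small, on the order of the noise scale
```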
There is also another way to check the shapes of the inputs and outputs of a model’s modules: we can run .scan manually to get the module’s dimensions before running the model.
[20]:
with tiny_model.scan(input):
    dim = tiny_model.layer1.output.shape[-1]
print(dim)
10
Gradients
NNsight also lets us apply backpropagation and access gradients with respect to a loss. Like .input and .output on modules, nnsight exposes .grad on Proxies themselves (assuming they are proxies of tensors):
[21]:
with tiny_model.trace(input):
    # We need to explicitly have the tensor require grad
    # as the model we defined earlier turned off requiring grad.
    tiny_model.layer1.output.requires_grad = True
    # We call .grad on a tensor Proxy to communicate we want to store its gradient.
    # We need to call .save() since .grad is its own Proxy.
    layer1_output_grad = tiny_model.layer1.output.grad.save()
    layer2_output_grad = tiny_model.layer2.output.grad.save()
    # Need a loss to propagate through the later modules in order to have a grad.
    loss = tiny_model.output.sum()
    loss.backward()
print("Layer 1 output gradient:", layer1_output_grad)
print("Layer 2 output gradient:", layer2_output_grad)
Layer 1 output gradient: tensor([[-0.2777, -0.1917, 0.1359, 0.2426, 0.1477, -0.0748, 0.0050, -0.1204,
0.1260, 0.2847]])
Layer 2 output gradient: tensor([[1., 1.]])
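The numbers above follow from the chain rule, which a plain-PyTorch sketch makes explicit. Assuming a model shaped like ours (rebuilt here without nnsight), the gradient of a sum with respect to layer2's output is all ones, and the gradient with respect to layer1's output is that row of ones multiplied by layer2's weight matrix, i.e. the column sums of the weight. The retain_grad call is needed because layer2's output is not a leaf tensor.

```python
import torch

torch.manual_seed(0)
layer1 = torch.nn.Linear(5, 10).requires_grad_(False)
layer2 = torch.nn.Linear(10, 2).requires_grad_(False)
x = torch.rand((1, 5))

h = layer1(x)           # nothing requires grad yet, so h is a leaf tensor...
h.requires_grad_(True)  # ...and we opt in, like layer1.output.requires_grad = True
y = layer2(h)
y.retain_grad()         # y is not a leaf, so ask autograd to keep its .grad

loss = y.sum()
loss.backward()

print(y.grad)  # tensor([[1., 1.]]): d(sum)/dy is 1 for every element
# Chain rule: dL/dh = dL/dy @ W2, which here is the column sums of layer2.weight.
print(torch.allclose(h.grad, layer2.weight.sum(dim=0, keepdim=True)))  # True
```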
All of the features we learned previously also apply to .grad. In other words, we can apply operations to and edit the gradients. Let’s zero the grad of layer1 and double the grad of layer2.
[22]:
with tiny_model.trace(input):
    # We need to explicitly have the tensor require grad
    # as the model we defined earlier turned off requiring grad.
    tiny_model.layer1.output.requires_grad = True
    tiny_model.layer1.output.grad[:] = 0
    tiny_model.layer2.output.grad = tiny_model.layer2.output.grad * 2
    layer1_output_grad = tiny_model.layer1.output.grad.save()
    layer2_output_grad = tiny_model.layer2.output.grad.save()
    # Need a loss to propagate through the later modules in order to have a grad.
    loss = tiny_model.output.sum()
    loss.backward()
print("Layer 1 output gradient:", layer1_output_grad)
print("Layer 2 output gradient:", layer2_output_grad)
Layer 1 output gradient: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
Layer 2 output gradient: tensor([[2., 2.]])
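Plain PyTorch offers a similar mechanism through Tensor.register_hook, where a hook that returns a tensor replaces the gradient flowing through that tensor. This sketch (not nnsight's implementation; arrived and record are hypothetical names) doubles the gradient leaving layer2's output and then zeroes the gradient at layer1's output, recording what arrived at layer1 before the zeroing to confirm that the doubling propagated.

```python
import torch

torch.manual_seed(0)
layer1 = torch.nn.Linear(5, 10).requires_grad_(False)
layer2 = torch.nn.Linear(10, 2).requires_grad_(False)
x = torch.rand((1, 5))

h = layer1(x)
h.requires_grad_(True)
y = layer2(h)

arrived = {}


def record(grad):
    # Store the gradient reaching h; returning None leaves it unchanged.
    arrived["h"] = grad.clone()


# Tensor hooks run during backward; a hook that returns a tensor replaces the gradient.
y.register_hook(lambda grad: grad * 2)                # double the gradient at y
h.register_hook(record)                               # runs first on h...
h.register_hook(lambda grad: torch.zeros_like(grad))  # ...then zeroes it

y.sum().backward()

print(torch.allclose(arrived["h"], 2 * layer2.weight.sum(dim=0, keepdim=True)))  # True
print(h.grad)  # all zeros
```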
Early Stopping
If we are only interested in a model’s intermediate computations, we can halt a forward pass at any module, reducing runtime and conserving compute resources. One example where this could be particularly useful is working with SAEs (sparse autoencoders): if we train an SAE on one layer’s activations, we can stop model execution right after that layer.
[23]:
with tiny_model.trace(input):
    # Save the output of the first layer, then stop the forward pass there.
    l1_out = tiny_model.layer1.output.save()
    tiny_model.layer1.output.stop()
print("L1 - Output: ", l1_out)
L1 - Output: tensor([[ 0.2341, 0.3416, -0.8637, -0.5382, -0.3792, -0.1253, 0.4137, 0.5758,
-0.3158, -0.1226]])
Interventions within the tracing context do not necessarily execute in the order they are defined. Instead, their execution is tied to the module they are associated with.
As a result, if the forward pass is terminated early, any interventions linked to modules beyond that point are skipped, even if they were defined earlier in the context.
In the example below, the output of layer 2 cannot be accessed since the model’s execution was stopped at layer 1.
[24]:
with tiny_model.trace(input):
    l2_out = tiny_model.layer2.output.save()
    tiny_model.layer1.output.stop()
print("L2 - Output: ", l2_out)
L2 - Output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[24], line 5
2 l2_out = tiny_model.layer2.output.save()
3 tiny_model.layer1.output.stop()
----> 5 print("L2 - Output: ", l2_out)
ValueError: Accessing value before it's been set.
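The key point is that interventions are tied to modules, not to the order of the lines. The toy runner below illustrates this in plain Python and is not nnsight's implementation. Hooks fire when their module runs, regardless of the order in which they were registered. Stopping after layer1 means the layer2 hook never fires, so nothing is ever saved for it. The names run and StopForward are hypothetical.

```python
class StopForward(Exception):
    """Raised inside the toy forward pass to halt it early."""

def run(layers, x, hooks, stop_after=None):
    # layers: ordered (name, fn) pairs; hooks: name -> callback on that module's output
    saved = {}
    try:
        for name, fn in layers:
            x = fn(x)
            if name in hooks:                 # hooks fire when their module runs,
                saved[name] = hooks[name](x)  # not in the order they were registered
            if name == stop_after:
                raise StopForward
    except StopForward:
        pass
    return saved

layers = [("layer1", lambda v: v * 2), ("layer2", lambda v: v + 1)]
hooks = {"layer2": lambda out: out, "layer1": lambda out: out}  # layer2 registered first

full = run(layers, 3, hooks)
early = run(layers, 3, hooks, stop_after="layer1")
print(full)    # {'layer1': 6, 'layer2': 7}
print(early)   # {'layer1': 6}  (the layer2 hook never ran)
```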
Conditional Interventions
Interventions can also be made conditional.
Inside the tracing context we can specify a new conditional context. This context will only execute the interventions within it if the condition is met.
[25]:
with tiny_model.trace(input) as tracer:
    rand_int = torch.randint(low=-10, high=10, size=(1,))
    with tracer.cond(rand_int % 2 == 0):
        tracer.log("Random Integer ", rand_int, " is Even")
    with tracer.cond(rand_int % 2 == 1):
        tracer.log("Random Integer ", rand_int, " is Odd")
Random Integer tensor([-5]) is Odd
In the example above, we have two conditional contexts with mutually exclusive conditions, just like a regular if-else statement.
Conditional contexts can also be nested, if we want our interventions to depend on more than one condition at a time.
[26]:
with tiny_model.trace(input) as tracer:
    non_rand_int = 8
    with tracer.cond(non_rand_int > 0):
        with tracer.cond(non_rand_int % 2 == 0):
            tracer.log("Rand Int ", non_rand_int, " is Positive and Even")
Rand Int 8 is Positive and Even
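Why a conditional context rather than a plain Python if? Values like rand_int are only known when the graph executes, so the condition has to be recorded and checked later. Here is a toy deferred tracer in plain Python that captures that idea. ToyTracer and its methods are hypothetical and not nnsight's implementation. Each log call remembers every condition open around it, so nested contexts behave like a logical AND.

```python
from contextlib import contextmanager

class ToyTracer:
    def __init__(self):
        self.nodes = []   # (conditions, action) pairs recorded in definition order
        self._open = []   # conditions currently in scope

    @contextmanager
    def cond(self, predicate):
        self._open.append(predicate)
        try:
            yield
        finally:
            self._open.pop()

    def log(self, make_message):
        # Record the action with a snapshot of all enclosing conditions.
        self.nodes.append((tuple(self._open), make_message))

    def execute(self, env):
        # Conditions are evaluated only now, against runtime values.
        return [msg(env) for conds, msg in self.nodes if all(c(env) for c in conds)]

tracer = ToyTracer()
with tracer.cond(lambda env: env["n"] > 0):
    with tracer.cond(lambda env: env["n"] % 2 == 0):
        tracer.log(lambda env: f"{env['n']} is positive and even")
with tracer.cond(lambda env: env["n"] % 2 == 1):
    tracer.log(lambda env: f"{env['n']} is odd")

out_8 = tracer.execute({"n": 8})
out_m5 = tracer.execute({"n": -5})
print(out_8)   # ['8 is positive and even']
print(out_m5)  # ['-5 is odd']
```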
2️⃣ Bigger
Now that we have the basics of nnsight under our belt, we can scale our model up and combine the techniques we’ve learned into more interesting experiments.
The NNsight class is very bare bones. It wraps a pre-defined model and does no pre-processing on the inputs we enter. It’s designed to be extended with more complex and powerful types of models, and we’re excited to see what can be done to leverage its features.
LanguageModel
LanguageModel is a subclass of NNsight. While we could define and create a model to pass in directly, LanguageModel includes special support for HuggingFace language models, including automatically loading models from a HuggingFace ID and loading the model together with the appropriate tokenizer.
Here is how we can use LanguageModel to load GPT-2:
[27]:
from nnsight import LanguageModel
llm = LanguageModel("openai-community/gpt2", device_map="auto")
print(llm)
GPT2LMHeadModel(
(transformer): GPT2Model(
(wte): Embedding(50257, 768)
(wpe): Embedding(1024, 768)
(drop): Dropout(p=0.1, inplace=False)
(h): ModuleList(
(0-11): 12 x GPT2Block(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): GPT2SdpaAttention(
(c_attn): Conv1D()
(c_proj): Conv1D()
(attn_dropout): Dropout(p=0.1, inplace=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): Conv1D()
(c_proj): Conv1D()
(act): NewGELUActivation()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=768, out_features=50257, bias=False)
(generator): WrapperModule()
)
On Model Initialization
A few important things to note:
Keyword arguments passed to the initialization of LanguageModel are forwarded to HuggingFace-specific loading logic. In this case, device_map specifies which devices to use, and its value "auto" distributes the model across all available GPUs (or the CPU if no GPUs are available). Other arguments can be found here: https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM
When we initialize LanguageModel, we aren’t yet loading the parameters of the model into memory. Instead, we load a ‘meta’ version of the model that takes up no memory but still lets us view and trace actions on it. After the first tracing context exits, the model is fully loaded into memory. To load it into memory on initialization instead, pass dispatch=True into LanguageModel, like LanguageModel('openai-community/gpt2', device_map="auto", dispatch=True).
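The deferred-loading behavior described above can be pictured with a toy lazy wrapper in plain Python. This is an assumption-level illustration, not how nnsight or HuggingFace load weights. ToyLazyModel and load_weights are hypothetical names. Nothing is loaded at construction, the first real execution triggers a single load, and a dispatch flag loads eagerly instead.

```python
class ToyLazyModel:
    """Holds only a description of the model until weights are actually needed."""
    def __init__(self, load_weights, dispatch=False):
        self._load_weights = load_weights
        self.weights = None
        if dispatch:                  # load eagerly, analogous to dispatch=True
            self.dispatch()

    def dispatch(self):
        if self.weights is None:      # load at most once
            self.weights = self._load_weights()

    def trace(self, x):
        self.dispatch()               # first real execution triggers loading
        return [w * x for w in self.weights]

loads = []
def load_weights():
    loads.append("loaded")            # record each (expensive) load
    return [1.0, 2.0]

lazy = ToyLazyModel(load_weights)
print(len(loads))                     # 0: nothing loaded at construction
first = lazy.trace(3.0)
lazy.trace(1.0)
print(first, len(loads))              # [3.0, 6.0] 1: loaded once, on first use
eager = ToyLazyModel(load_weights, dispatch=True)
print(len(loads))                     # 2: dispatch=True loads immediately
```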
Let’s put together some of the features we applied to the small model, but now on GPT-2. Unlike NNsight, LanguageModel does define logic to pre-process inputs upon entering the tracing context, so we can pass raw text without calling the tokenizer ourselves.
In the following example, we ablate the value coming from the last layer’s MLP module and decode the logits to see what token the model predicts without influence from that particular module:
[28]:
with llm.trace("The Eiffel Tower is in the city of"):
    # Access the last layer using h[-1] as it's a ModuleList.
    # The MLP output is a single tensor of shape [batch, seq, hidden];
    # index 0 selects the first (and here only) prompt in the batch.
    llm.transformer.h[-1].mlp.output[0][:] = 0
    # Logits come out of model.lm_head and we apply argmax to get the predicted token ids.
    token_ids = llm.lm_head.output.argmax(dim=-1).save()
print("\nToken IDs:", token_ids)
# Apply the tokenizer to decode the ids into words after the tracing context.
print("Prediction:", llm.tokenizer.decode(token_ids[0][-1]))
Token IDs: tensor([[ 262, 12, 417, 8765, 11, 257, 262, 3504, 338, 3576]],
device='mps:0')
Prediction: London
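Note that zeroing the last MLP does not zero the block's output. In GPT-2, each block adds the attention and MLP outputs onto a residual stream, so the ablated block still passes along everything computed before its MLP. The toy below shows this in plain Python. It is a simplification that omits the ln_1/ln_2 LayerNorms GPT-2 applies before attention and the MLP, and the attn and mlp functions are made up.

```python
def toy_block(h, attn, mlp, ablate_mlp=False):
    # Simplified residual block (LayerNorms omitted for brevity).
    h = [a + b for a, b in zip(h, attn(h))]   # residual add after attention
    m = mlp(h)
    if ablate_mlp:
        m = [0.0 for _ in m]                  # the ablation: zero the MLP output
    return [a + b for a, b in zip(h, m)]      # residual add after the MLP

attn = lambda h: [0.5 * v for v in h]         # made-up "attention"
mlp = lambda h: [v * v for v in h]            # made-up "MLP"
h0 = [1.0, 2.0]

out_full = toy_block(h0, attn, mlp)
out_ablated = toy_block(h0, attn, mlp, ablate_mlp=True)
print(out_full)     # [3.75, 12.0]
print(out_ablated)  # [1.5, 3.0]: the residual stream survives the ablation
```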
We just ran a little intervention on a much more complex model with a lot more parameters! An important piece of information we’re missing, though, is what the prediction would look like without our ablation.
Of course, we could just run two tracing contexts and compare the outputs. This, however, would require two forward passes through the model. NNsight can do better than that.
Batching
It’s time to bring back the Tracer object we dropped before.
When we call .trace(...), it actually creates two different contexts behind the scenes. The second one is the invoker context, which defines the values of the .input and .output Proxies.
If we call .trace(...) with some input, the input is passed on to the invoker. Since there is only one input, only one invoker context is created.
If we call .trace() without input, we can call tracer.invoke(...) to manually create the invoker context with our input. Each subsequent call to .invoke(...) creates a new invoker, and the interventions defined inside it only refer to that invoke's input. When exiting the tracing context, the inputs from all of the invokers are batched together and executed in one forward pass! So let’s do the ablation experiment and compute a ‘control’ output to compare to:
More on the invoker context
Note that when injecting data into only the relevant invoker interventions, nnsight tries, but can’t guarantee, to narrow the data to the right batch indices. So there are cases where all invokes get all of the data. Specifically, if the input or output data is stored as an object that is not an arbitrary collection of tensors, it is broadcast to all invokes.
Just like .trace(...) created a Tracer object, .invoke(...) creates an Invoker object. For LanguageModel models, the Invoker prepares the input by running the tokenizer on it. The Invoker stores the pre-processed inputs at invoker.inputs, which can be accessed to see information about our inputs. When we pass a single input to .trace(...) directly, we can still access the invoker object at tracer.invoker without having to call tracer.invoke(...).
Keyword arguments passed to .invoke(...) make their way to the input pre-processing. LanguageModel has keyword arguments max_length and truncation used for tokenization, and they can be passed to the invoker. If we are calling a single-input .trace(...) and want to pass these keyword arguments, we can do so through invoker_args, a dictionary of keyword arguments for the invoker. Here is an example to demonstrate everything we’ve described:
This snippet
with llm.trace("hello", invoker_args={"max_length":10}) as tracer:
    invoker = tracer.invoker
does the same as
with llm.trace() as tracer:
    with tracer.invoke("hello", max_length=10) as invoker:
        invoker = invoker
[29]:
with llm.trace() as tracer:
    with tracer.invoke("The Eiffel Tower is in the city of"):
        # Ablate the last MLP for only this batch.
        llm.transformer.h[-1].mlp.output[0][:] = 0
        # Get the output for only the intervened on batch.
        token_ids_intervention = llm.lm_head.output.argmax(dim=-1).save()
    with tracer.invoke("The Eiffel Tower is in the city of"):
        # Get the output for only the original batch.
        token_ids_original = llm.lm_head.output.argmax(dim=-1).save()
print("Original token IDs:", token_ids_original)
print("Modified token IDs:", token_ids_intervention)
print("Original prediction:", llm.tokenizer.decode(token_ids_original[0][-1]))
print("Modified prediction:", llm.tokenizer.decode(token_ids_intervention[0][-1]))
Original token IDs: tensor([[ 198, 12, 417, 8765, 318, 257, 262, 3504, 7372, 6342]],
device='mps:0')
Modified token IDs: tensor([[ 262, 12, 417, 8765, 11, 257, 262, 3504, 338, 3576]],
device='mps:0')
Original prediction: Paris
Modified prediction: London
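As a mental model for what just happened, here is a toy batched trace in plain Python. ToyBatchedTrace is hypothetical and ignores real details such as padding and the broadcasting caveat mentioned earlier. Each invoke contributes its inputs to one batch, the model runs once, and each invoke's intervention only sees its own slice of the outputs.

```python
class ToyBatchedTrace:
    """Collects inputs from several invokes and runs the model once on all of them."""
    def __init__(self, model):
        self.model = model
        self.invokes = []                     # (inputs, intervention) per invoke

    def invoke(self, inputs, intervention):
        self.invokes.append((list(inputs), intervention))

    def run(self):
        batch, spans = [], []
        for inputs, _ in self.invokes:
            spans.append((len(batch), len(batch) + len(inputs)))
            batch.extend(inputs)
        outputs = self.model(batch)           # a single forward pass over the whole batch
        # Each intervention only sees its own slice of the batch.
        return [fn(outputs[s:e]) for (_, fn), (s, e) in zip(self.invokes, spans)]

calls = []
def model(batch):
    calls.append(len(batch))                  # record each forward pass and its batch size
    return [x * 2 for x in batch]

trace = ToyBatchedTrace(model)
trace.invoke([1, 2], lambda out: [0 for _ in out])  # "ablate" this invoke's slice
trace.invoke([3], lambda out: out)                  # control, left untouched
results = trace.run()
print(results, calls)  # [[0, 0], [6]] [3]
```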
So it did end up affecting what the model predicted. That’s pretty neat!
Another cool thing about multiple invokes is that Proxies can interact across them. Here we transfer the word token embeddings from a real prompt into a placeholder prompt, so the placeholder prompt produces the output of the real one. The placeholder has the same number of tokens as the real prompt, which is what lets the embeddings be swapped directly:
[30]:
with llm.trace() as tracer:
    with tracer.invoke("The Eiffel Tower is in the city of"):
        embeddings = llm.transformer.wte.output
    with tracer.invoke("_ _ _ _ _ _ _ _ _ _"):
        llm.transformer.wte.output = embeddings
        token_ids_intervention = llm.lm_head.output.argmax(dim=-1).save()
    with tracer.invoke("_ _ _ _ _ _ _ _ _ _"):
        token_ids_original = llm.lm_head.output.argmax(dim=-1).save()
print("Original prediction:", llm.tokenizer.decode(token_ids_original[0][-1]))
print("Modified prediction:", llm.tokenizer.decode(token_ids_intervention[0][-1]))
Original prediction: _
Modified prediction: Paris
.next()
Some HuggingFace models define methods that generate multiple tokens, calling the model once per new token. LanguageModel wraps that functionality and provides the same tracing features through .generate(...) instead of .trace(...). This calls the underlying model’s .generate method and passes the output through a .generator module that we’ve added onto the model, so the generated output is available at .generator.output.
In a case like this, the underlying model is called more than once, so its modules produce more than one output. Which iteration should a given module.output refer to? That’s where Module.next() comes in.
Each module has a call index associated with it, and .next() simply increments that attribute. At execution time, data is injected into the intervention graph only at the iteration that matches the call index.
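The call-index mechanism can be sketched with a toy module in plain Python. ToyModule, request_output, and the "last token + 1" prediction rule are all hypothetical. Each request remembers the call index current at the time, next() advances that cursor, and values are captured only on the run whose count matches.

```python
class ToyModule:
    """A module that may run several times; next() moves which run we refer to."""
    def __init__(self, fn):
        self.fn = fn
        self.call_count = 0      # how many times the module has actually run
        self.cursor = 0          # which run the next request refers to
        self.wanted = set()
        self.saved = {}

    def next(self):
        self.cursor += 1
        return self

    def request_output(self):
        self.wanted.add(self.cursor)
        return self.cursor       # handle used to look the value up later

    def __call__(self, x):
        out = self.fn(x)
        if self.call_count in self.wanted:   # inject only on the matching run
            self.saved[self.call_count] = out
        self.call_count += 1
        return out

lm_head = ToyModule(lambda seq: seq[-1] + 1)   # toy "prediction": last token + 1
h1 = lm_head.request_output()
h2 = lm_head.next().request_output()
h3 = lm_head.next().request_output()

seq = [10]
for _ in range(3):                 # toy generation loop: one call per new token
    seq.append(lm_head(seq))

values = [lm_head.saved[h] for h in (h1, h2, h3)]
print(values)  # [11, 12, 13]
print(seq)     # [10, 11, 12, 13]
```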
[31]:
with llm.generate("The Eiffel Tower is in the city of", max_new_tokens=3):
    token_ids_1 = llm.lm_head.output.argmax(dim=-1).save()
    token_ids_2 = llm.lm_head.next().output.argmax(dim=-1).save()
    token_ids_3 = llm.lm_head.next().output.argmax(dim=-1).save()
    output = llm.generator.output.save()
print("\nPrediction 1: ", llm.tokenizer.decode(token_ids_1[0][-1]))
print("Prediction 2: ", llm.tokenizer.decode(token_ids_2[0][-1]))
print("Prediction 3: ", llm.tokenizer.decode(token_ids_3[0][-1]))
print("All token ids: ", output)
print("Full prediction: ", llm.tokenizer.batch_decode(output.value))
Prediction 1: Paris
Prediction 2: ,
Prediction 3: and
All token ids: tensor([[ 464, 412, 733, 417, 8765, 318, 287, 262, 1748, 286, 6342, 11,
290]], device='mps:0')
Full prediction: ['The Eiffel Tower is in the city of Paris, and']
Model Editing
NNsight’s model editing feature lets you create persistently modified versions of a model with .edit(). Unlike interventions in a tracing context, which are temporary, the Editor context enables you to make lasting changes to a model instance.
This feature is useful for:
* Creating modified model variants without altering the original
* Applying changes that persist across multiple forward passes
* Comparing interventions between original and edited models
Let’s explore how to use the Editor context to make a simple persistent change to a model:
[32]:
# we take the hidden states with the expected output "Paris"
with llm.trace("The Eiffel Tower is located in the city of") as tracer:
    hs11 = llm.transformer.h[11].output[0][:, -1, :].save()
# the edited model will now always predict "Paris" as the next token
with llm.edit() as llm_edited:
    llm.transformer.h[11].output[0][:, -1, :] = hs11
# we demonstrate this by comparing the output of an unmodified model...
with llm.trace("Vatican is located in the city of") as tracer:
    original_tokens = llm.lm_head.output.argmax(dim=-1).save()
# ...with the output of the edited model
with llm_edited.trace("Vatican is located in the city of") as tracer:
    modified_tokens = llm.lm_head.output.argmax(dim=-1).save()
print("\nOriginal Prediction: ", llm.tokenizer.decode(original_tokens[0][-1]))
print("Modified Prediction: ", llm.tokenizer.decode(modified_tokens[0][-1]))
Original Prediction: Rome
Modified Prediction: Paris
By default, edits defined within an Editor context create a new, modified version of the model and leave the original untouched. This allows for safe experimentation with model changes. If you wish to modify the original model directly, you can set inplace=True when calling .edit().
Use this option cautiously, as in-place edits alter the base model for all subsequent model calls.
[33]:
# we use the hidden state we saved above (hs11)
with llm.edit(inplace=True) as llm_edited:
    llm.transformer.h[11].output[0][:, -1, :] = hs11
# the base model itself is now modified, so a plain llm.trace picks up the edit
with llm.trace("Vatican is located in the city of") as tracer:
    modified_tokens = llm.lm_head.output.argmax(dim=-1).save()
print("Modified In-place: ", llm.tokenizer.decode(modified_tokens[0][-1]))
Modified In-place: Paris
If you’ve made in-place edits to your model and need to revert these changes, .clear_edits() can help. This method removes all edits applied to the model, effectively restoring it to its original state.
[34]:
llm.clear_edits()
with llm.trace("Vatican is located in the city of"):
    modified_tokens = llm.lm_head.output.argmax(dim=-1).save()
print("Edits cleared: ", llm.tokenizer.decode(modified_tokens[0][-1]))
Edits cleared: Rome
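The three behaviors above (a copied edit, an in-place edit, and clearing edits) can be summarized with a toy model in plain Python. ToyEditableModel and its methods are hypothetical and are not how nnsight stores edits. An edit is a persistent intervention applied on every call. By default it goes into a copy, with inplace=True it goes into the model itself, and clear_edits() drops them all.

```python
import copy

class ToyEditableModel:
    def __init__(self, fn):
        self.fn = fn
        self.edits = []                       # persistent output interventions

    def edit(self, intervention, inplace=False):
        target = self if inplace else copy.copy(self)
        target.edits = self.edits + [intervention]   # new list, so a copy never shares edits
        return target

    def clear_edits(self):
        self.edits = []

    def __call__(self, x):
        out = self.fn(x)
        for intervention in self.edits:       # applied on every call, not just one trace
            out = intervention(out)
        return out

model = ToyEditableModel(lambda x: x * 2)
edited = model.edit(lambda out: 99)
base_out, edited_out = model(3), edited(3)
print(base_out, edited_out)   # 6 99
model.edit(lambda out: 99, inplace=True)
inplace_out = model(3)
print(inplace_out)            # 99
model.clear_edits()
cleared_out = model(3)
print(cleared_out)            # 6
```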
3️⃣ I thought you said huge models?
NNsight is only one part of our project to democratize access to AI internals. The other half is NDIF (National Deep Inference Fabric).
The interaction between the two is fairly straightforward. The intervention graph we create via the tracing context can be encoded into a custom JSON format and sent via an HTTP request to the NDIF servers. NDIF then decodes the intervention graph and interleaves it alongside the specified model.
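To make the idea of shipping a graph rather than code concrete, here is a toy round trip in plain Python. The node schema, the op names, and the "@name" reference convention are all invented for illustration and are not nnsight's actual JSON format. The client serializes a small graph with json.dumps, and a "server" function decodes it and executes it against the module outputs it has.

```python
import json

# A made-up intervention graph: each node names an op and its arguments,
# which may refer to earlier nodes as "@name".
graph = [
    {"name": "hidden", "op": "module_output", "args": ["layer2"]},
    {"name": "doubled", "op": "mul", "args": ["@hidden", 2]},
    {"name": "result", "op": "save", "args": ["@doubled"]},
]
payload = json.dumps(graph)          # what a client might send over HTTP

def execute(payload, module_outputs):
    """Toy 'server side': decode the graph and run it node by node."""
    values, saved = {}, {}
    def resolve(a):
        return values[a[1:]] if isinstance(a, str) and a.startswith("@") else a
    for node in json.loads(payload):
        args = [resolve(a) for a in node["args"]]
        if node["op"] == "module_output":
            values[node["name"]] = module_outputs[args[0]]
        elif node["op"] == "mul":
            values[node["name"]] = args[0] * args[1]
        elif node["op"] == "save":
            values[node["name"]] = saved[node["name"]] = args[0]
        else:
            raise ValueError(f"unknown op {node['op']}")
    return saved                     # only saved values go back to the client

saved = execute(payload, {"layer2": 21})
print(saved)  # {'result': 42}
```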
To see which models are currently being hosted, check out the following status page: https://nnsight.net/status/
Remote execution
In its current state, NDIF requires an API key. Therefore, to run the rest of this colab, you need one of your own. To get one, simply go to https://login.ndif.us and sign up.
With a valid API key, you can then configure nnsight as follows:
[35]:
from nnsight import CONFIG
CONFIG.set_default_api_key("YOUR_API_KEY")
This only needs to be run once, as it saves the API key as the default in a config file alongside the nnsight installation.
To amp things up a few levels, let’s demonstrate using nnsight’s tracing context with one of the larger open-source language models, Llama-3.1-70B!
[36]:
import os
# llama3.1 70b is a gated model and you need access via your huggingface token
os.environ['HF_TOKEN'] = "YOUR_HUGGING_FACE_TOKEN"
[37]:
# We'll never actually load the parameters so no need to specify a device_map.
llama = LanguageModel("meta-llama/Meta-Llama-3.1-70B")
# The only change needed to run on NDIF instead of locally is remote=True.
with llama.trace("The Eiffel Tower is in the city of", remote=True) as runner:
    hidden_states = llama.model.layers[-1].output.save()
    output = llama.output.save()
print(hidden_states)
print(output["logits"])
2024-08-30 07:11:21,150 MainProcess nnsight_remote INFO 36ff46f0-d81a-4586-b7e7-eaf6f97d6c0b - RECEIVED: Your job has been received and is waiting approval.
2024-08-30 07:11:21,184 MainProcess nnsight_remote INFO 36ff46f0-d81a-4586-b7e7-eaf6f97d6c0b - APPROVED: Your job was approved and is waiting to be run.
2024-08-30 07:11:21,206 MainProcess nnsight_remote INFO 36ff46f0-d81a-4586-b7e7-eaf6f97d6c0b - RUNNING: Your job has started running.
2024-08-30 07:11:21,398 MainProcess nnsight_remote INFO 36ff46f0-d81a-4586-b7e7-eaf6f97d6c0b - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 9.48M/9.48M [00:02<00:00, 3.21MB/s]
(tensor([[[ 5.4688, -4.9062, 2.2344, ..., -3.6875, 0.9609, 1.2578],
[ 1.5469, -0.6172, -1.4531, ..., -1.1562, -0.1406, -2.1250],
[ 1.7812, -1.8906, -1.1875, ..., 0.1680, 0.9609, 0.5625],
...,
[ 0.9453, -0.3711, 1.3516, ..., 1.3828, -0.7969, -1.9297],
[-0.8906, 0.3672, 0.2617, ..., 2.4688, -0.4414, -0.6758],
[-1.6094, 1.0938, 1.7031, ..., 1.8672, -1.1328, -0.5000]]],
dtype=torch.bfloat16), DynamicCache())
tensor([[[ 6.3750, 8.6250, 13.0000, ..., -4.1562, -4.1562, -4.1562],
[-2.8594, -2.2344, -3.0938, ..., -8.6250, -8.6250, -8.6250],
[ 8.9375, 3.5938, 4.5000, ..., -3.9375, -3.9375, -3.9375],
...,
[ 3.5781, 3.4531, 0.0796, ..., -6.5625, -6.5625, -6.5625],
[10.8750, 6.4062, 4.9375, ..., -4.0000, -4.0000, -3.9844],
[ 7.2500, 6.1562, 3.5156, ..., -4.7188, -4.7188, -4.7188]]])
It really is as simple as remote=True. All of the techniques we went through in earlier sections work just the same when running locally or remotely.
Sessions
NDIF uses a queue to handle concurrent requests from multiple users. To optimize the execution of our experiments, we can use the session context to package multiple interventions together into one request to the server.
This offers the following benefits:
1) All interventions within a session are executed one after another without additional waiting in the queue.
2) All intermediate outputs of each intervention are stored on the server and can be accessed by other interventions in the same session, without moving the data back and forth between NDIF and the local machine.
Let’s take a look:
[38]:
with llama.session(remote=True) as session:
    with llama.trace("The Eiffel Tower is in the city of") as t1:
        # capture the hidden state from layer 79 (the last layer) at the last token
        hs_79 = llama.model.layers[79].output[0][:, -1, :]  # no .save()
        t1_tokens_out = llama.lm_head.output.argmax(dim=-1).save()
    with llama.trace("Buckingham Palace is in the city of") as t2:
        # overwrite layer 1's last-token hidden state with t1's layer-79 state
        llama.model.layers[1].output[0][:, -1, :] = hs_79[:]
        t2_tokens_out = llama.lm_head.output.argmax(dim=-1).save()
print("\nT1 - Original Prediction: ", llama.tokenizer.decode(t1_tokens_out[0][-1]))
print("T2 - Modified Prediction: ", llama.tokenizer.decode(t2_tokens_out[0][-1]))
2024-08-30 07:11:28,206 MainProcess nnsight_remote INFO 4a6576dd-b5fd-4f1f-9836-a619f8277057 - RECEIVED: Your job has been received and is waiting approval.
2024-08-30 07:11:28,206 MainProcess nnsight_remote INFO 4a6576dd-b5fd-4f1f-9836-a619f8277057 - APPROVED: Your job was approved and is waiting to be run.
2024-08-30 07:11:28,207 MainProcess nnsight_remote INFO 4a6576dd-b5fd-4f1f-9836-a619f8277057 - RUNNING: Your job has started running.
2024-08-30 07:11:28,207 MainProcess nnsight_remote INFO 4a6576dd-b5fd-4f1f-9836-a619f8277057 - LOG: 80
2024-08-30 07:11:28,416 MainProcess nnsight_remote INFO 4a6576dd-b5fd-4f1f-9836-a619f8277057 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.69k/1.69k [00:00<00:00, 6.30MB/s]
T1 - Original Prediction: Paris
T2 - Modified Prediction: Paris
In the example above, we replace the hidden state of an early layer (layer 1) in Tracer 2 with the hidden state of a later layer (layer 79) from Tracer 1. Since we are using a session, we don’t have to save the hidden state from Tracer 1 to reference it in Tracer 2.
It is important to note that all the traces defined within the session context are executed sequentially, strictly in the order they are defined (i.e., t2 runs after t1, t3 after t2, and so on).
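The two session properties just described are ordered execution and values shared without .save(). A toy session in plain Python can model both. ToySession is hypothetical and only mimics the idea. Traces are queued in definition order and run back to back against one state dict that stays "on the server" for the whole session, so a later trace can read what an earlier one stored.

```python
class ToySession:
    """Queues traces and runs them in definition order as one 'request'."""
    def __init__(self):
        self.traces = []

    def trace(self, fn):
        self.traces.append(fn)       # used as a decorator: record, don't run yet
        return fn

    def run(self):
        state = {}                   # lives "on the server" for the whole session
        for fn in self.traces:       # strictly in definition order
            fn(state)
        return state

session = ToySession()

@session.trace
def t1(state):
    state["hs"] = [1.0, 2.0]         # no .save() needed: it stays in the shared state

@session.trace
def t2(state):
    state["t2_out"] = [v * 10 for v in state["hs"]]   # reads t1's value directly

state = session.run()
print(state["t2_out"])  # [10.0, 20.0]
```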
The session context object has its own methods to log values and to terminate early.
[39]:
with llama.session(remote=True) as session:
    session.log("-- Early Stop --")
    session.exit()
2024-08-30 07:11:30,900 MainProcess nnsight_remote INFO 28ac8e47-fa48-45a1-acb0-3e17960e36b8 - RECEIVED: Your job has been received and is waiting approval.
2024-08-30 07:11:30,934 MainProcess nnsight_remote INFO 28ac8e47-fa48-45a1-acb0-3e17960e36b8 - APPROVED: Your job was approved and is waiting to be run.
2024-08-30 07:11:30,935 MainProcess nnsight_remote INFO 28ac8e47-fa48-45a1-acb0-3e17960e36b8 - RUNNING: Your job has started running.
2024-08-30 07:11:30,951 MainProcess nnsight_remote INFO 28ac8e47-fa48-45a1-acb0-3e17960e36b8 - LOG: -- Early Stop --
2024-08-30 07:11:30,953 MainProcess nnsight_remote INFO 28ac8e47-fa48-45a1-acb0-3e17960e36b8 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 928/928 [00:00<00:00, 4.11MB/s]
In addition to the benefits mentioned above, the session context also enables interesting experiments not possible with other nnsight tools. Since every trace runs on its own model, a single session can run interventions between different models. For example, we could swap activations between the base and instruct versions of the Llama model and compare their outputs. The session context can also be used to run experiments entirely locally!
Looping
We mentioned earlier that the session context enables multi-tracing execution. But how can we optimize a process that requires running an intervention graph in a loop? If we create a simple for loop with a Tracer context inside, it creates a new intervention graph at every iteration, which does not scale.
We solve this problem the nnsight way, with the Iterator context: an intervention loop that iteratively executes and updates a single intervention graph.
Use a session to define the Iterator context, and pass in the sequence of items you want to loop over:
[40]:
with llama.session(remote=True) as session:
    with session.iter([0, 1, 2]) as item:
        # define intervention body here ...
        with llama.trace("_"):
            # define interventions here ...
            pass
        with llama.trace("_"):
            # define interventions here ...
            pass
2024-08-30 07:11:34,689 MainProcess nnsight_remote INFO 667c4310-9041-451d-99a5-f713f639abb8 - RECEIVED: Your job has been received and is waiting approval.
2024-08-30 07:11:34,708 MainProcess nnsight_remote INFO 667c4310-9041-451d-99a5-f713f639abb8 - APPROVED: Your job was approved and is waiting to be run.
2024-08-30 07:11:34,726 MainProcess nnsight_remote INFO 667c4310-9041-451d-99a5-f713f639abb8 - RUNNING: Your job has started running.
2024-08-30 07:11:35,332 MainProcess nnsight_remote INFO 667c4310-9041-451d-99a5-f713f639abb8 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 928/928 [00:00<00:00, 6.49MB/s]
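The scaling argument behind the Iterator context comes down to counting graph builds. The toy below uses plain Python, and make_graph is a hypothetical stand-in for "trace and record an intervention graph". A naive loop builds one graph per item, while the iterator style builds a single graph and executes it once per item. The results are identical and only the build count differs.

```python
def make_graph(counter):
    counter["builds"] += 1             # stands in for recording a whole intervention graph
    return lambda item: item * item    # the recorded "intervention body"

items = [0, 1, 2]

naive = {"builds": 0}
naive_out = [make_graph(naive)(i) for i in items]   # new graph every iteration

shared = {"builds": 0}
graph = make_graph(shared)                           # one graph, reused
iter_out = [graph(i) for i in items]

print(naive_out, naive["builds"])   # [0, 1, 4] 3
print(iter_out, shared["builds"])   # [0, 1, 4] 1
```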
The Iterator context supports all of nnsight’s graph-based functionality while closely mimicking a conventional Python for loop, so it can express all kinds of iterative operations through the as item syntax:
[41]:
with llama.session(remote=True) as session:
    li = nnsight.list()
    for num in range(0, 3):
        li.append([num])  # adds [0], [1], [2] to the traced list
    li2 = nnsight.list().save()
    # You can create nested Iterator contexts
    with session.iter(li) as item:
        with session.iter(item) as item_2:
            li2.append(item_2)
print("\nList: ", li2)
2024-08-30 07:11:36,315 MainProcess nnsight_remote INFO 8734ee75-b616-4270-9d4d-7cfabd6d63ae - RECEIVED: Your job has been received and is waiting approval.
2024-08-30 07:11:36,334 MainProcess nnsight_remote INFO 8734ee75-b616-4270-9d4d-7cfabd6d63ae - APPROVED: Your job was approved and is waiting to be run.
2024-08-30 07:11:36,342 MainProcess nnsight_remote INFO 8734ee75-b616-4270-9d4d-7cfabd6d63ae - RUNNING: Your job has started running.
2024-08-30 07:11:36,354 MainProcess nnsight_remote INFO 8734ee75-b616-4270-9d4d-7cfabd6d63ae - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.06k/1.06k [00:00<00:00, 13.7MB/s]
List: [0, 1, 2]
Notice how we used the nnsight.list() method to create a list of lists to loop over. This type of method is what we call an NNsight Built-in: a special type of method that wraps nnsight.apply() to provide a more user-friendly interface for adding common datatypes to the Intervention Graph.
A full list of NNsight Built-Ins:
* nnsight.bool() creates a traceable Boolean
* nnsight.bytes() creates a traceable Bytes
* nnsight.int() creates a traceable Integer
* nnsight.float() creates a traceable Float
* nnsight.str() creates a traceable String
* nnsight.complex() creates a traceable Complex number
* nnsight.bytearray() creates a traceable Bytearray
* nnsight.tuple() creates a traceable Tuple
* nnsight.list() creates a traceable List
* nnsight.set() creates a traceable Set
* nnsight.dict() creates a traceable Dictionary
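The relationship between a built-in and apply() can be pictured with a toy recorder in plain Python. The apply and make_builtin functions here are hypothetical and only mirror the idea, not nnsight's code. apply() records a call instead of running it and returns a handle. A "built-in" is simply apply() pre-bound to a Python constructor. The recorded calls run later, when the graph executes.

```python
graph = []   # recorded (function, args) pairs, in order

def apply(fn, *args):
    """Record a call instead of running it; return a handle to its future result."""
    graph.append((fn, args))
    return len(graph) - 1

def make_builtin(py_type):
    # A toy "built-in" is apply() pre-bound to a Python constructor.
    return lambda *args: apply(py_type, *args)

traced_list = make_builtin(list)
traced_int = make_builtin(int)

h_list = traced_list()          # nothing is constructed yet, only recorded
h_int = traced_int("7")

results = [fn(*args) for fn, args in graph]   # "execute" the graph later
print(h_list, h_int, results)   # 0 1 [[], 7]
```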
We can also expose the iterator context object via a return_context flag. You can then use it to exit the iteration loop early and to log intermediate outputs within the loop:
[42]:
with llama.session(remote=True) as session:
    with session.iter([0, 1, 2, 3], return_context=True) as (item, iterator):
        iterator.log(item)
        with iterator.cond(item == 2):
            iterator.exit()
2024-08-30 07:11:38,551 MainProcess nnsight_remote INFO 5f0b434d-178d-4807-913f-331f403eb0ea - RECEIVED: Your job has been received and is waiting approval.
2024-08-30 07:11:38,571 MainProcess nnsight_remote INFO 5f0b434d-178d-4807-913f-331f403eb0ea - APPROVED: Your job was approved and is waiting to be run.
2024-08-30 07:11:38,593 MainProcess nnsight_remote INFO 5f0b434d-178d-4807-913f-331f403eb0ea - RUNNING: Your job has started running.
2024-08-30 07:11:38,594 MainProcess nnsight_remote INFO 5f0b434d-178d-4807-913f-331f403eb0ea - LOG: 0
2024-08-30 07:11:38,610 MainProcess nnsight_remote INFO 5f0b434d-178d-4807-913f-331f403eb0ea - LOG: 1
2024-08-30 07:11:38,611 MainProcess nnsight_remote INFO 5f0b434d-178d-4807-913f-331f403eb0ea - LOG: 2
2024-08-30 07:11:38,630 MainProcess nnsight_remote INFO 5f0b434d-178d-4807-913f-331f403eb0ea - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 992/992 [00:00<00:00, 5.68MB/s]
The Iterator context is a nice piece of functionality that makes a range of basic code operations traceable by nnsight.
But what kinds of experimental scenarios would make this useful?
In the next section, we see how the Iterator context enables a powerful parameter-efficient fine-tuning tool, Low-Rank Adaptation (LoRA)!
Training a LoRA
For more background information about LoRA, check out this paper.
Training a LoRA brings together everything we have covered in this section: remote execution, the Session context, and iterative interventions. We’re going to train a very simple LoRA that, when applied, makes our model always predict “Paris” no matter what.
[43]:
import torch
from torch import nn
from nnsight.envoy import Envoy
# We will define a LORA class.
# The LORA class call method operations are simply traced like you would normally do in a .trace.
class LORA(nn.Module):
    def __init__(self, module: Envoy, dim: int, r: int) -> None:
        """Init.
        Args:
            module (Envoy): Which model Module we are adding the LORA to.
            dim (int): Dimension of the layer we are adding to (This could potentially be auto populated if the user scanned first so we know the shape)
            r (int): Inner dimension of the LORA
        """
        super(LORA, self).__init__()
        self.r = r
        self.module = module
        self.WA = torch.nn.Parameter(torch.randn(dim, self.r), requires_grad=True).save()
        self.WB = torch.nn.Parameter(torch.zeros(self.r, dim), requires_grad=True).save()
    # The Call method defines how to actually apply the LORA.
    def __call__(self, alpha: float = 1.0):
        """Call.
        Args:
            alpha (float, optional): How much to apply the LORA. Can be altered after training for inference. Defaults to 1.0.
        """
        # We apply WA to the first positional arg (the hidden states)
        A_x = torch.matmul(self.module.input[0][0], self.WA)
        BA_x = torch.matmul(A_x, self.WB)
        # LORA is additive: only the low-rank update is scaled by alpha
        h = self.module.output + BA_x * alpha
        # Replace the output with our new one
        # Could also have been self.module.output[:] = h, for in-place
        self.module.output = h
    def parameters(self):
        # Some way to get all the parameters.
        return [self.WA, self.WB]
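Two properties of this LoRA setup are easy to check numerically. First, because WB starts at zero, the low-rank update x·WA·WB is zero at initialization, so training starts from the unmodified model. Second, the trainable parameter count is 2·dim·r rather than the dim·dim of a full square update. The plain-Python sketch below shows both. It uses made-up small sizes for the matrix math, plus dim=8192 and lora_dim=4 from this walkthrough for the count. matmul and lora_delta are hypothetical helpers, not the LORA class above.

```python
import random

def matmul(X, Y):
    # Plain nested-list matrix product: rows of X times columns of Y
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]

def lora_delta(x, WA, WB, alpha=1.0):
    """alpha * (x @ WA) @ WB for a batch of row vectors x."""
    return [[alpha * v for v in row] for row in matmul(matmul(x, WA), WB)]

dim, r = 6, 2
random.seed(0)
WA = [[random.gauss(0, 1) for _ in range(r)] for _ in range(dim)]  # like torch.randn(dim, r)
WB = [[0.0] * dim for _ in range(r)]                                # like torch.zeros(r, dim)
x = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]

delta = lora_delta(x, WA, WB)
print(delta)   # [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]: untrained LoRA leaves the output unchanged

# Parameter count for the MLP in this walkthrough (dim=8192, lora_dim=4)
lora_params = 2 * 8192 * 4
full_params = 8192 * 8192
print(lora_params, full_params)  # 65536 67108864
```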
Let’s define all the variables to use in LoRA training.
[44]:
# We need the token id of the correct answer.
answer = " Paris"
# Index 1 skips the beginning-of-text token the Llama tokenizer prepends.
answer_token = llama.tokenizer.encode(answer)[1]
# Inner LORA dimension
lora_dim = 4
# Module to train LORA on
module = llama.model.layers[-1].mlp
We can use the .scan() method to get the shape of the module without having to fully run the model.
[46]:
with llama.scan(" "):
    dim = module.output.shape[-1]
print(dim)
8192
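The idea behind scanning is to propagate shapes rather than values, which is why it is cheap. Here is a plain-Python sketch that does this for a stack of linear layers, using the tiny two-layer network's sizes from the start of the walkthrough (5 → 10 → 2). infer_shapes is a hypothetical helper, not what .scan() runs internally. It only tracks the last dimension and raises on a mismatch, which is the kind of error a shape pass can catch before any real computation.

```python
def infer_shapes(input_shape, layers):
    """Propagate only shapes through (name, in_features, out_features) layers."""
    shapes, shape = {}, tuple(input_shape)
    for name, in_f, out_f in layers:
        if shape[-1] != in_f:
            raise ValueError(f"{name}: expected last dim {in_f}, got {shape[-1]}")
        shape = (*shape[:-1], out_f)   # a linear layer only changes the last dimension
        shapes[name] = shape
    return shapes

shapes = infer_shapes((1, 5), [("layer1", 5, 10), ("layer2", 10, 2)])
print(shapes)  # {'layer1': (1, 10), 'layer2': (1, 2)}
```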
It’s time to run the LoRA training loop! We will be using the Session and the Iterator contexts to achieve this.
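One detail in the loop below is why batch[0] is a list of prompts and batch[1] holds the token ids. PyTorch's default collation transposes a list of [prompt, token] samples into [all prompts, all tokens], turning the ints into a tensor. The plain-Python sketch below mimics that transpose with lists instead of tensors. toy_collate, toy_dataloader, and PLACEHOLDER_TOKEN are hypothetical, and the id is a stand-in, not the real " Paris" token. It also shows that 100 samples at batch_size=10 give 10 iterations, that is, 10 optimizer steps.

```python
def toy_collate(samples):
    """Transpose [[prompt, token], ...] into [[prompts...], [tokens...]]."""
    prompts, tokens = zip(*samples)
    return [list(prompts), list(tokens)]   # torch would turn the tokens into a tensor

def toy_dataloader(dataset, batch_size):
    for start in range(0, len(dataset), batch_size):
        yield toy_collate(dataset[start:start + batch_size])

PLACEHOLDER_TOKEN = 1234                   # hypothetical id standing in for answer_token
dataset = [["_", PLACEHOLDER_TOKEN]] * 100
batches = list(toy_dataloader(dataset, batch_size=10))

print(len(batches))                         # 10 iterations, one optimizer step each
print(batches[0][0][:3], batches[0][1][:3]) # ['_', '_', '_'] [1234, 1234, 1234]
```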
[47]:
from torch.utils.data import DataLoader
# The LORA object itself isn't transmitted to the server, only the forward / call method.
# The parameters are created remotely and never sent, only retrieved.
with llama.session(remote=True) as session:
    # Create a dataset of 100 pairs of a blank prompt and the " Paris" id
    dataset = [["_", answer_token]] * 100
    # Create a dataloader from it.
    dataloader = DataLoader(dataset, batch_size=10)
    # Create our LORA on the last mlp
    lora = LORA(module, dim, lora_dim)
    # Create an optimizer. Use the parameters from LORA
    optimizer = torch.optim.AdamW(lora.parameters(), lr=3)
    # Iterate over dataloader using .iter.
    with session.iter(dataloader, return_context=True) as (batch, iterator):
        prompt = batch[0]
        correct_token = batch[1]
        # Run .trace with prompt
        with llama.trace(prompt) as tracer:
            # Apply LORA to intervention graph just by calling it with .trace
            lora()
            # Get logits
            logits = llama.lm_head.output
            # Do cross entropy on last predicted token and correct_token
            loss = torch.nn.functional.cross_entropy(logits[:, -1], correct_token)
            # Call backward
            loss.backward()
        # Call methods on optimizer. Graphs that aren't from .trace (so in this case session and iterator both have their own graph) are executed sequentially.
        # The Graph of Iterator here will be:
        # 1.) Index batch at 0 for prompt
        # 2.) Index batch at 1 for correct_token
        # 3.) Execute the .trace using the prompt
        # 4.) Call .step() on optimizer
        optimizer.step()
        # 5.) Call .zero_grad() in optimizer
        optimizer.zero_grad()
        # 6.) Print out the lora WA weights to show they are indeed changing
        iterator.log(lora.WA)
2024-08-30 07:12:21,091 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - RECEIVED: Your job has been received and is waiting approval.
2024-08-30 07:12:21,146 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - APPROVED: Your job was approved and is waiting to be run.
2024-08-30 07:12:21,166 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - RUNNING: Your job has started running.
2024-08-30 07:12:21,704 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - LOG: Parameter containing:
tensor([[-0.6289, -1.1172, -0.6719, 0.1816],
[-0.2715, 0.5547, -1.2812, -0.8086],
[ 1.0938, -0.5820, -0.7070, -1.1094],
...,
[-0.2910, -0.6016, 0.6602, -0.4590],
[-2.7969, -0.3477, -1.3438, -1.1797],
[-1.0312, -1.0469, -0.7930, 0.4141]], requires_grad=True)
2024-08-30 07:12:21,853 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - LOG: Parameter containing:
tensor([[-0.6094, -1.0859, -0.6523, 0.1758],
[-0.2637, 0.5391, -1.2422, -0.7852],
[ 1.0625, -0.5664, -0.6875, -1.0781],
...,
[-0.2832, -0.5820, 0.6406, -0.4453],
[-2.7188, -0.3379, -1.3047, -1.1406],
[-1.0000, -1.0156, -0.7695, 0.4023]], requires_grad=True)
2024-08-30 07:12:21,983 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - LOG: Parameter containing:
tensor([[-0.5898, -1.0547, -0.6328, 0.1709],
[-0.2559, 0.5234, -1.2031, -0.7617],
[ 1.0312, -0.5508, -0.6680, -1.0469],
...,
[-0.2754, -0.5664, 0.6211, -0.4316],
[-2.6406, -0.3281, -1.2656, -1.1094],
[-0.9688, -0.9844, -0.7461, 0.3906]], requires_grad=True)
2024-08-30 07:12:22,102 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - LOG: Parameter containing:
tensor([[-0.5703, -1.0234, -0.6133, 0.1660],
[-0.2480, 0.5078, -1.1641, -0.7383],
[ 1.0000, -0.5352, -0.6484, -1.0156],
...,
[-0.2676, -0.5508, 0.6016, -0.4180],
[-2.5625, -0.3184, -1.2266, -1.0781],
[-0.9414, -0.9531, -0.7227, 0.3789]], requires_grad=True)
2024-08-30 07:12:22,218 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - LOG: Parameter containing:
tensor([[-0.5547, -0.9922, -0.5938, 0.1611],
[-0.2402, 0.4922, -1.1328, -0.7148],
[ 0.9688, -0.5195, -0.6289, -0.9844],
...,
[-0.2598, -0.5352, 0.5820, -0.4062],
[-2.4844, -0.3086, -1.1875, -1.0469],
[-0.9141, -0.9258, -0.6992, 0.3672]], requires_grad=True)
2024-08-30 07:12:22,374 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - LOG: Parameter containing:
tensor([[-0.5391, -0.9609, -0.5742, 0.1562],
[-0.2334, 0.4766, -1.1016, -0.6953],
[ 0.9414, -0.5039, -0.6094, -0.9531],
...,
[-0.2520, -0.5195, 0.5664, -0.3945],
[-2.4062, -0.2988, -1.1484, -1.0156],
[-0.8867, -0.8984, -0.6797, 0.3555]], requires_grad=True)
2024-08-30 07:12:22,755 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - LOG: Parameter containing:
tensor([[-0.5234, -0.9336, -0.5586, 0.1514],
[-0.2266, 0.4629, -1.0703, -0.6758],
[ 0.9141, -0.4883, -0.5898, -0.9258],
...,
[-0.2441, -0.5039, 0.5508, -0.3828],
[-2.3281, -0.2891, -1.1172, -0.9844],
[-0.8594, -0.8711, -0.6602, 0.3457]], requires_grad=True)
2024-08-30 07:12:22,757 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - LOG: Parameter containing:
tensor([[-0.5078, -0.9062, -0.5430, 0.1465],
[-0.2197, 0.4492, -1.0391, -0.6562],
[ 0.8867, -0.4727, -0.5703, -0.8984],
...,
[-0.2363, -0.4883, 0.5352, -0.3711],
[-2.2656, -0.2812, -1.0859, -0.9531],
[-0.8320, -0.8438, -0.6406, 0.3359]], requires_grad=True)
2024-08-30 07:12:22,757 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - LOG: Parameter containing:
tensor([[-0.4922, -0.8789, -0.5273, 0.1426],
[-0.2129, 0.4355, -1.0078, -0.6367],
[ 0.8594, -0.4590, -0.5547, -0.8711],
...,
[-0.2295, -0.4727, 0.5195, -0.3594],
[-2.2031, -0.2734, -1.0547, -0.9258],
[-0.8086, -0.8203, -0.6211, 0.3262]], requires_grad=True)
2024-08-30 07:12:22,900 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - LOG: Parameter containing:
tensor([[-0.4766, -0.8516, -0.5117, 0.1387],
[-0.2061, 0.4219, -0.9766, -0.6172],
[ 0.8320, -0.4453, -0.5391, -0.8438],
...,
[-0.2227, -0.4590, 0.5039, -0.3477],
[-2.1406, -0.2656, -1.0234, -0.8984],
[-0.7852, -0.7969, -0.6016, 0.3164]], requires_grad=True)
2024-08-30 07:12:22,901 MainProcess nnsight_remote INFO 3e78f88a-e620-4679-ac73-abeb4f14ce8e - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 133k/133k [00:00<00:00, 571kB/s]
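The same loop can be run locally in plain PyTorch on a tiny frozen stand-in model, which shows what each iteration does. Everything here is an illustrative assumption: the sizes, the `Linear` layers standing in for the last MLP and `lm_head`, the random "hidden states", the target id, and the smaller learning rate (the remote run above uses `lr=3`). Only the pattern follows the session above. The LoRA factors are the only trainable parameters, the loss is cross entropy against the target token, and each step is followed by `zero_grad`.

```python
import torch

torch.manual_seed(0)

dim, vocab, r = 16, 8, 4
target_token = 3  # hypothetical "correct answer" id for this toy

# Frozen toy stand-ins for the model's last MLP and its lm_head.
mlp = torch.nn.Linear(dim, dim).requires_grad_(False)
lm_head = torch.nn.Linear(dim, vocab).requires_grad_(False)

# LoRA factors, initialised like the LORA class: WA random, WB zeros.
WA = torch.nn.Parameter(torch.randn(dim, r))
WB = torch.nn.Parameter(torch.zeros(r, dim))
optimizer = torch.optim.AdamW([WA, WB], lr=1e-2)

x = torch.randn(10, dim)  # one batch of 10 stand-in hidden states
targets = torch.full((10,), target_token, dtype=torch.long)


def forward(alpha: float = 1.0) -> torch.Tensor:
    # Frozen MLP output plus the scaled low-rank update, then project to logits.
    h = mlp(x) + alpha * (x @ WA @ WB)
    return lm_head(h)


initial_loss = torch.nn.functional.cross_entropy(forward(), targets).item()
for step in range(200):
    loss = torch.nn.functional.cross_entropy(forward(), targets)
    loss.backward()        # gradients reach only WA and WB
    optimizer.step()
    optimizer.zero_grad()
final_loss = torch.nn.functional.cross_entropy(forward(), targets).item()
print(initial_loss, final_loss)
```

On the first step only `WB` receives a gradient, because `WB = 0` zeroes the gradient flowing back to `WA`. Once `WB` moves, both factors train.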
Now WA and WB are optimized! To generate with the LoRA, we simply call lora() inside the .generate context, then save the output so we can de-tokenize it.
[49]:
# With the LORA. Should produce "Hello Paris"
with llama.generate("Hello", remote=True) as generator:
    lora()
    out = llama.generator.output.save()

print(llama.tokenizer.batch_decode(out.value))

# Then without the LORA: the model's ordinary continuation (here, "Hello.")
with llama.generate("Hello", remote=True) as generator:
    out = llama.generator.output.save()

print(llama.tokenizer.batch_decode(out.value))
2024-08-30 07:12:41,410 MainProcess nnsight_remote INFO c1c6e24c-9f3f-415b-8f90-de8404fc2e74 - RECEIVED: Your job has been received and is waiting approval.
2024-08-30 07:12:41,539 MainProcess nnsight_remote INFO c1c6e24c-9f3f-415b-8f90-de8404fc2e74 - APPROVED: Your job was approved and is waiting to be run.
2024-08-30 07:12:41,572 MainProcess nnsight_remote INFO c1c6e24c-9f3f-415b-8f90-de8404fc2e74 - RUNNING: Your job has started running.
2024-08-30 07:12:41,695 MainProcess nnsight_remote INFO c1c6e24c-9f3f-415b-8f90-de8404fc2e74 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.31k/1.31k [00:00<00:00, 7.88MB/s]
['<|begin_of_text|>Hello Paris']
2024-08-30 07:12:42,203 MainProcess nnsight_remote INFO 88838ccd-230d-485b-8f57-8399453c2250 - RECEIVED: Your job has been received and is waiting approval.
2024-08-30 07:12:42,224 MainProcess nnsight_remote INFO 88838ccd-230d-485b-8f57-8399453c2250 - APPROVED: Your job was approved and is waiting to be run.
2024-08-30 07:12:42,231 MainProcess nnsight_remote INFO 88838ccd-230d-485b-8f57-8399453c2250 - RUNNING: Your job has started running.
2024-08-30 07:12:42,350 MainProcess nnsight_remote INFO 88838ccd-230d-485b-8f57-8399453c2250 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.31k/1.31k [00:00<00:00, 2.93MB/s]
['<|begin_of_text|>Hello.']
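Calling lora() inside a trace or generate context affects only that one run. This is why the second generation above behaves like the unmodified model. A loose plain-PyTorch analogy, which is our own and not how nnsight works internally, is a forward hook that adds the low-rank term while it is registered and is then removed. The layer sizes and the random `WA`/`WB` (standing in for trained factors) are hypothetical.

```python
import torch

torch.manual_seed(0)

dim, r = 8, 2
layer = torch.nn.Linear(dim, dim).requires_grad_(False)
WA = torch.randn(dim, r)  # stand-ins for trained LoRA factors
WB = torch.randn(r, dim)


def lora_hook(module, args, output, alpha=1.0):
    # A forward hook that returns a tensor replaces the module's output.
    return output + alpha * (args[0] @ WA @ WB)


x = torch.randn(3, dim)
base = layer(x)

handle = layer.register_forward_hook(lora_hook)
with_lora = layer(x)      # adapter active for this call only
handle.remove()
without_lora = layer(x)   # back to the original behavior
```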
Next Steps
Check out nnsight.net/tutorials for more walkthroughs implementing classic interpretability techniques using nnsight.
Getting Involved!
Note that both nnsight and NDIF are in active development, so changes may be made and errors may arise during use.
If you’re interested in following updates to nnsight, contributing, giving feedback, or finding collaborators, please join the NDIF discord. We’d love to hear about your work using nnsight!
You can also follow us on LinkedIn and X/Twitter: @ndif_team.
💟