Early Stopping

If we are only interested in a model’s intermediate computations, we can halt a forward pass at any module, reducing runtime and conserving compute. One example where this is particularly useful is working with SAEs (sparse autoencoders): to train an SAE on one layer, we only need that layer’s activations, so we can stop execution right after it.

from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto')

with model.trace("The Eiffel Tower is in the city of"):
    # save the output of the first layer, then stop the forward pass there
    l1_out = model.transformer.h[0].output.save()
    model.transformer.h[0].output.stop()

print("L1 - Output: ", l1_out)
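Stopping saves compute because every module after the stop point is simply never called. The sketch below models this idea in plain Python; it is not nnsight’s implementation, and the toy layers, the EarlyStop exception, and the forward helper are all hypothetical. It assumes a stop is signalled by raising an exception from a hook, which unwinds the rest of the forward pass.

```python
class EarlyStop(Exception):
    """Hypothetical signal used to abort the rest of a forward pass."""


def forward(layers, x, hooks):
    """Run `layers` in order and return the indices of the layers that ran.

    `hooks` maps a layer index to a callable that receives that layer's output.
    A hook may raise EarlyStop to skip every remaining layer.
    """
    executed = []
    try:
        for i, layer in enumerate(layers):
            x = layer(x)
            executed.append(i)
            if i in hooks:
                hooks[i](x)
    except EarlyStop:
        pass  # the remaining layers are never called
    return executed


# A toy 12-layer "model" (GPT-2 small also has 12 blocks); layer i adds i.
layers = [lambda x, i=i: x + i for i in range(12)]

saved = {}


def save_and_stop(output):
    saved["layer_0"] = output  # keep the activation we care about...
    raise EarlyStop            # ...then abandon the rest of the forward pass


full_run = forward(layers, 1, hooks={})
early_run = forward(layers, 1, hooks={0: save_and_stop})

print(len(full_run), len(early_run), saved)  # 12 1 {'layer_0': 1}
```

Only one of the twelve toy layers runs in the early-stopped pass, while the activation we needed is still captured.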

Interventions within the tracing context do not necessarily execute in the order they are defined. Instead, each one runs when the module it is associated with executes during the forward pass.

As a result, if the forward pass is terminated early, any interventions linked to modules beyond that point are skipped, even if they were defined earlier in the context.

In the example below, the output of layer 2 (model.transformer.h[1]) cannot be accessed because execution was stopped at layer 1 (model.transformer.h[0]), even though the save is defined before the stop.

with model.trace("The Eiffel Tower is in the city of"):
    l2_out = model.transformer.h[1].output.save()
    model.transformer.h[0].output.stop()

print("L2 - Output: ", l2_out)
L2 - Output:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 5
      2    l2_out = model.transformer.h[1].output.save()
      3    model.transformer.h[0].output.stop()
----> 5 print("L2 - Output: ", l2_out)

ValueError: Accessing value before it's been set.
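The error follows from interventions being keyed to modules rather than to their position in the code. The plain-Python sketch below models that behaviour; it is not nnsight’s implementation, and Saved, EarlyStop, stop and forward are hypothetical names. The save is registered before the stop, yet the stop takes effect first because it is attached to an earlier layer, and reading the unset value raises a ValueError whose message is modeled on the one above.

```python
class EarlyStop(Exception):
    """Hypothetical signal used to abort the rest of a forward pass."""


class Saved:
    """Toy stand-in for a saved value that is filled in during execution."""

    _EMPTY = object()

    def __init__(self):
        self._value = Saved._EMPTY

    def set(self, value):
        self._value = value

    def done(self):
        return self._value is not Saved._EMPTY

    @property
    def value(self):
        if not self.done():
            raise ValueError("Accessing value before it's been set.")
        return self._value


def forward(layers, x, hooks):
    """Run layers in order; after layer i, call the hooks registered for i."""
    executed = []
    try:
        for i, layer in enumerate(layers):
            x = layer(x)
            executed.append(i)
            for hook in hooks.get(i, []):
                hook(x)
    except EarlyStop:
        pass  # later layers, and their hooks, never run
    return executed


def stop(_output):
    raise EarlyStop


layers = [lambda x, i=i: x + i for i in range(12)]
hooks = {}

l2_out = Saved()
hooks.setdefault(1, []).append(l2_out.set)  # defined first, attached to layer 2
hooks.setdefault(0, []).append(stop)        # defined second, attached to layer 1

executed = forward(layers, 0, hooks)
print(executed, l2_out.done())  # [0] False

try:
    l2_out.value
except ValueError as err:
    print(err)  # Accessing value before it's been set.
```

Execution order is decided by the layer loop, not by the order in which hooks were registered, which is why the stop on layer 1 prevents the save on layer 2 from ever firing.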