Editing

Editing sets default nodes on a model’s intervention graph that are executed on every future trace. Let’s start by loading and dispatching a LanguageModel.

[1]:
from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
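Conceptually, an edit behaves like a set of default operations that are replayed on every later trace. The pure-Python toy below is a hypothetical illustration of that idea only: ToyTracedModel, edit, and trace are invented names, not nnsight's API, and real edits attach nodes at specific modules in the intervention graph rather than after the final output.

```python
from typing import Callable, List, Sequence

Op = Callable[[float], float]


class ToyTracedModel:
    """Hypothetical toy, NOT nnsight: a model whose edits persist across traces."""

    def __init__(self, forward: Op):
        self.forward = forward
        # Populated by edit() and reused by every subsequent trace().
        self.default_ops: List[Op] = []

    def edit(self, op: Op) -> None:
        # Register an operation that will run on every future trace.
        self.default_ops.append(op)

    def trace(self, x: float, ops: Sequence[Op] = ()) -> float:
        # Run the base computation, then the default ops, then this trace's own ops.
        out = self.forward(x)
        for op in [*self.default_ops, *ops]:
            out = op(out)
        return out


toy = ToyTracedModel(lambda x: x + 1)
toy.edit(lambda h: h * 10)                      # persistent "default node"
print(toy.trace(1.0))                           # 20.0
print(toy.trace(2.0, ops=[lambda h: h - 5]))    # 25.0
```

The edit is registered once and applies to both traces, while the per-trace op affects only the second call.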

Editing is useful for attaching default modules, such as LoRAs or SAEs (sparse autoencoders), to the graph. Below, we declare a toy passthrough SAE class.

[2]:
import torch

# A toy passthrough SAE: forward returns its input unchanged
class SAE(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
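For contrast with the passthrough toy, a minimal ReLU sparse autoencoder might look like the sketch below. It is a hypothetical illustration rather than any specific published SAE: the class name ReLUSAE and the feature width d_sae=3072 are assumptions, while d_model=768 matches GPT-2 small's hidden size.

```python
import torch


class ReLUSAE(torch.nn.Module):
    """Hypothetical minimal sparse autoencoder (not part of nnsight)."""

    def __init__(self, d_model: int = 768, d_sae: int = 3072):
        super().__init__()
        # Encoder maps activations into a wider feature space.
        self.encoder = torch.nn.Linear(d_model, d_sae)
        # Decoder reconstructs activations from those features.
        self.decoder = torch.nn.Linear(d_sae, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = torch.relu(self.encoder(x))  # non-negative feature activations
        return self.decoder(features)            # reconstruction, same shape as x


sae = ReLUSAE()
x = torch.randn(1, 4, 768)  # (batch, tokens, hidden), like GPT-2 small hidden states
recon = sae(x)
print(recon.shape)  # torch.Size([1, 4, 768])
```

Unlike the toy, this version changes its input, so attaching it would alter downstream activations rather than pass them through.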

To attach a module to a model’s tree, simply set it as an attribute on the desired module. Note that attached modules must be instances of torch.nn.Module.

To set a default edit on a model’s intervention graph, create an edit context with model.edit() and declare operations as you would inside a trace.

[3]:
# Create a reference to the layer 0 (l0) Envoy
submodule = model.transformer.h[0]
# Attach the SAE as an attribute named .sae
submodule.sae = SAE()

# Declare an edit context like you would a trace
with model.edit(""):
    # Route layer 0's hidden states through the SAE by default
    acts = submodule.output[0]
    submodule.sae(acts)
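For intuition about the torch.nn.Module requirement: plain PyTorch only registers nn.Module attributes as children in its module tree, and any other value is stored as an ordinary attribute. The sketch below does not involve NNsight; it uses a torch.nn.Linear as a hypothetical stand-in for a GPT-2 block.

```python
import torch


class SAE(torch.nn.Module):
    def forward(self, x):
        return x


layer = torch.nn.Linear(8, 8)  # hypothetical stand-in for a transformer block

# Assigning an nn.Module registers it as a child in the module tree...
layer.sae = SAE()
# ...while assigning a plain callable only sets an ordinary attribute.
layer.helper = lambda x: x

children = dict(layer.named_children())
print(sorted(children))  # ['sae']
```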

Accessing .sae in future trace contexts returns the layer 0 output as expected, since the passthrough SAE leaves its input unchanged.

[4]:
with model.trace("Hello, world!"):
    acts = submodule.sae.output.save()

print(acts.shape)
torch.Size([1, 4, 768])

You can also hook into the submodules of attached modules. Let’s redefine the SAE class to include a passthrough encoder and decoder.

[5]:
# A passthrough submodule used for both the encoder and the decoder
class Coder(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

class SAE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Coder()
        self.decoder = Coder()

    def forward(self, x):
        return self.decoder(
            self.encoder(x)
        )
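Nesting modules this way extends the tree. In plain PyTorch, named_modules() reports the encoder and decoder under dotted paths beneath the attached SAE. A minimal sketch, again using a torch.nn.Linear as a hypothetical stand-in for the layer 0 block:

```python
import torch


class Coder(torch.nn.Module):
    def forward(self, x):
        return x


class SAE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Coder()
        self.decoder = Coder()

    def forward(self, x):
        return self.decoder(self.encoder(x))


block = torch.nn.Linear(8, 8)  # hypothetical stand-in for model.transformer.h[0]
block.sae = SAE()

# named_modules() walks the tree depth-first; "" is the root module itself.
paths = [name for name, _ in block.named_modules()]
print(paths)  # ['', 'sae', 'sae.encoder', 'sae.decoder']
```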

We make the edit just as before, this time passing hook=True. This tells NNsight to call the SAE module’s forward method, passing inputs through the hooks of all its submodules, so that the encoder and decoder can be accessed in later traces.

[6]:
# Create a reference to the layer 0 (l0) Envoy
submodule = model.transformer.h[0]
# Attach the new SAE as an attribute named .sae
submodule.sae = SAE()

# Declare an edit context like you would a trace
with model.edit(""):
    acts = submodule.output[0]
    # hook=True runs the SAE's forward so its submodules are hooked too
    submodule.sae(acts, hook=True)

# Now we can access .encoder and other submodules of the SAE
with model.trace("Hello, world!"):
    acts = submodule.sae.encoder.output.save()

print(acts.shape)
torch.Size([1, 4, 768])
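As a rough plain-PyTorch analogy (not a description of NNsight’s internals), forward hooks registered on the SAE’s encoder and decoder fire only when something actually calls those submodules, which the SAE’s forward method does. The sketch below captures each submodule’s output shape; make_hook and captured are illustrative names.

```python
import torch


class Coder(torch.nn.Module):
    def forward(self, x):
        return x


class SAE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Coder()
        self.decoder = Coder()

    def forward(self, x):
        return self.decoder(self.encoder(x))


sae = SAE()
captured = {}


def make_hook(name):
    # Forward hooks receive (module, inputs, output) after the module's forward runs.
    def hook(module, inputs, output):
        captured[name] = output.shape
    return hook


sae.encoder.register_forward_hook(make_hook("encoder"))
sae.decoder.register_forward_hook(make_hook("decoder"))

# Calling sae(x) runs SAE.forward, which calls encoder and decoder and fires their hooks.
sae(torch.zeros(1, 4, 768))
print(captured)  # {'encoder': torch.Size([1, 4, 768]), 'decoder': torch.Size([1, 4, 768])}
```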