Editing
The edit module sets default nodes on the intervention graph to be executed on every future trace. Let’s start by loading and dispatching a LanguageModel.
[1]:
from nnsight import LanguageModel
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
Editing is useful for attaching default modules, such as LoRAs or SAEs, to the graph. We declare a toy, passthrough SAE class below.
[2]:
import torch

# Create a simple torch module that returns its input unchanged
class SAE(torch.nn.Module):
    def __init__(self):
        super(SAE, self).__init__()

    def forward(self, x):
        return x
To attach a module to a model’s tree, simply set it as an attribute on a desired module. Note that edits must be of type torch.nn.Module in order to be attached to the tree.
To set a default edit on a model’s intervention graph, create an edit context and declare operations as usual.
[3]:
# Create a reference to the l0 Envoy
submodule = model.transformer.h[0]

# Set the SAE as a property on .sae
submodule.sae = SAE()

# Declare an edit context like you would a trace
with model.edit(""):
    acts = submodule.output[0]
    submodule.sae(acts)
Calling the .sae attribute in future trace contexts will return the l0 output as expected.
[4]:
with model.trace("Hello, world!"):
    acts = submodule.sae.output.save()

print(acts.shape)
torch.Size([1, 4, 768])
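Because the toy SAE simply passes its input through, its output should be identical to the layer-0 hidden states. The quick check below is a sketch of our own (not part of the original notebook); it reuses the model and submodule objects from the cells above.

with model.trace("Hello, world!"):
    layer_acts = submodule.output[0]
    sae_acts = submodule.sae.output
    # Expect True: the passthrough SAE returns the layer output unchanged
    match = (layer_acts == sae_acts).all().save()

print(match.item())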
You can also hook into submodules of attached modules. Let’s edit the SAE class to include a passthrough encoder and decoder.
[5]:
class Coder(torch.nn.Module):
    def __init__(self):
        super(Coder, self).__init__()

    def forward(self, x):
        return x


class SAE(torch.nn.Module):
    def __init__(self):
        super(SAE, self).__init__()
        self.encoder = Coder()
        self.decoder = Coder()

    def forward(self, x):
        return self.decoder(
            self.encoder(x)
        )
We make the edit just as before, this time setting the hook kwarg to True. This tells NNsight that we’d like to call the forward method on the SAE module, passing inputs through all subhooks.
[6]:
# Create a reference to the l0 Envoy
submodule = model.transformer.h[0]

# Set the SAE as a property on .sae
submodule.sae = SAE()

# Declare an edit context like you would a trace
with model.edit(""):
    acts = submodule.output[0]
    submodule.sae(acts, hook=True)

# Now we can call .encoder and other submodules!
with model.trace("Hello, world!"):
    acts = submodule.sae.encoder.output.save()

print(acts.shape)
torch.Size([1, 4, 768])
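Because the SAE was called with hook=True, its encoder and decoder are full members of the intervention graph, so their outputs can be modified as well as read. The ablation below is a hypothetical sketch of our own (not from the original notebook) and assumes the edited model from the previous cell.

with model.trace("Hello, world!"):
    # Zero out the encoder activations before they reach the decoder
    submodule.sae.encoder.output = submodule.sae.encoder.output * 0
    decoded = submodule.sae.decoder.output.save()

# If the ablation propagated through the passthrough decoder, this should be zero
print(decoded.abs().sum())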