Dictionary Learning

This notebook loads a pretrained sparse autoencoder from the dictionary_learning library, uses nnsight to capture MLP activations from pythia-70m-deduped, encodes those activations into dictionary features, and visualizes the most active features per token with circuitsvis.

[1]:
from nnsight import LanguageModel
from dictionary_learning.dictionary import AutoEncoder
import torch
[2]:
weights_path = "./weights/autoencoders/pythia-70m-deduped/mlp_out_layer0/0_8192/checkpoints/ae_99000.pt"

activation_dim = 512 # dimension of the NN's activations to be autoencoded
dictionary_size = 16 * activation_dim # number of features in the dictionary

ae = AutoEncoder(activation_dim, dictionary_size)
ae.load_state_dict(torch.load(weights_path))
ae.cuda()
[2]:
AutoEncoder(
  (encoder): Linear(in_features=512, out_features=8192, bias=True)
  (decoder): Linear(in_features=8192, out_features=512, bias=False)
)
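The encoder maps each 512-dimensional activation vector to 8192 non-negative feature coefficients, and the decoder maps coefficients back to activation space. As a quick shape sanity check (a minimal sketch; it assumes the dictionary_learning encode/decode pair round-trips between these two spaces, and uses random inputs as stand-ins for real activations):

with torch.no_grad():
    x = torch.randn(4, activation_dim, device="cuda")  # stand-in activations
    f = ae.encode(x)       # (4, 8192): sparse, non-negative feature coefficients
    x_hat = ae.decode(f)   # (4, 512): reconstruction in activation space
    print(f.shape, x_hat.shape)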
[3]:
model = LanguageModel("EleutherAI/pythia-70m-deduped", device_map="cuda:0")
tokenizer = model.tokenizer

prompt = """
Call me Ishmael. Some years ago--never mind how long precisely--having little or no money in my purse, and nothing particular to interest me on shore, I thought I would sail about a little and see the watery part of the world. It is a way I have of driving off the spleen and regulating the circulation. Whenever I find myself growing grim about the mouth; whenever it is a damp, drizzly November in my soul; whenever I find myself involuntarily pausing before coffin warehouses, and bringing up the rear of every funeral I meet; and especially whenever my hypos get such an upper hand of me, that it requires a strong moral principle to prevent me from deliberately stepping into the street, and methodically knocking people's hats off--then, I account it high time to get to sea as soon as I can.
"""

with model.invoke(prompt) as invoker:
    mlp_0 = model.gpt_neox.layers[0].mlp.output.save()  # capture layer 0's MLP output during the forward pass

features = ae.encode(mlp_0.value)  # encode the saved activations into dictionary features
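Because the autoencoder is trained with a sparsity penalty, most feature coefficients should be exactly zero on real activations. A quick check (a sketch; features here has shape (batch, seq_len, dictionary_size), and encode applies a ReLU, so "active" means strictly positive):

frac_active = (features > 0).float().mean().item()
print(f"{frac_active:.2%} of feature coefficients active on average")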
[4]:
summed_activations = features.abs().sum(dim=1)  # sum absolute activations over the sequence dimension
top_activations_indices = summed_activations.topk(20).indices  # indices of the 20 most active features

compounded = []
for i in top_activations_indices[0]:
    compounded.append(features[0, :, i.item()].cpu())  # per-token activations of feature i

compounded = torch.stack(compounded, dim=0)  # (20, seq_len)
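To see which dictionary features these are, you can list the top indices alongside their summed activations (a minimal sketch over the tensors computed above; feature numbering is specific to this particular trained autoencoder):

top = summed_activations.topk(20)
for rank, (idx, val) in enumerate(zip(top.indices[0].tolist(), top.values[0].tolist())):
    print(f"{rank:2d}: feature {idx:5d}, summed |activation| = {val:.2f}")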
[5]:
from circuitsvis.tokens import colored_tokens_multi

tokens = tokenizer.encode(prompt)
str_tokens = [tokenizer.decode(t) for t in tokens]  # one string per token for display

colored_tokens_multi(str_tokens, compounded.T)  # circuitsvis expects a (n_tokens, n_features) matrix
[5]:
(interactive circuitsvis output: each token in the prompt colored by the per-token activation of each of the top 20 features)