LoRA
We’re going to train a very simple LoRA that, when applied, will make our model always predict “Paris” no matter what.
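At its core, the update we are about to implement is simple: project the module’s input down to a small rank r with a matrix WA, project it back up with WB, add the result to the module’s original output, and scale by alpha. Here is a minimal plain-PyTorch sketch of that arithmetic, with made-up sizes for illustration only, before we wire the same thing into nnsight below:

import torch

# Illustrative sizes only (GPT-2's hidden size is 768; r is the LoRA rank).
batch, seq, dim, r = 2, 5, 768, 4

x = torch.randn(batch, seq, dim)    # hidden states entering the module
out = torch.randn(batch, seq, dim)  # the module's original output

WA = torch.randn(dim, r)            # down-projection, random init
WB = torch.zeros(r, dim)            # up-projection, zero init, so the update starts as a no-op

alpha = 1.0
A_x = x @ WA                        # (batch, seq, r)
BA_x = A_x @ WB                     # (batch, seq, dim)
new_out = (BA_x + out) * alpha      # additive low-rank update, scaled by alpha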
[1]:
import torch
import torch.nn as nn
from nnsight import LanguageModel
model = LanguageModel('openai-community/gpt2', device_map='auto')
[2]:
from nnsight.envoy import Envoy

# We will define a LoRA class.
# The LoRA class call method operations are simply traced like you would normally do in a .trace.
class LORA(nn.Module):

    def __init__(self, module: Envoy, dim: int, r: int) -> None:
        """Init.

        Args:
            module (Envoy): Which model Module we are adding the LoRA to.
            dim (int): Dimension of the layer we are adding to (this could potentially be auto-populated if the user scanned first so we know the shape).
            r (int): Inner dimension of the LoRA.
        """
        super(LORA, self).__init__()

        self.r = r
        self.module = module
        self.WA = torch.nn.Parameter(torch.randn(dim, self.r), requires_grad=True).save()
        self.WB = torch.nn.Parameter(torch.zeros(self.r, dim), requires_grad=True).save()

    # The __call__ method defines how to actually apply the LoRA.
    def __call__(self, alpha: float = 1.0):
        """Call.

        Args:
            alpha (float, optional): How much to apply the LoRA. Can be altered after training for inference. Defaults to 1.0.
        """
        # We apply WA to the first positional arg (the hidden states).
        A_x = torch.matmul(self.module.input[0][0], self.WA)
        BA_x = torch.matmul(A_x, self.WB)

        # LoRA is additive.
        h = BA_x + self.module.output

        # Replace the output with our new one * alpha.
        # Could also have been self.module.output[:] = h * alpha, for in-place.
        self.module.output = h * alpha

    def parameters(self):
        # Some way to get all the parameters.
        return [self.WA, self.WB]
Let’s define all the variables to use in LoRA training.
[3]:
# We need the token id of the correct answer.
answer = " Paris"
answer_token = model.tokenizer.encode(answer)[0]
# Inner LORA dimension
lora_dim = 4
# Module to train LORA on
module = model.transformer.h[-1].mlp
We can use the .scan() method to get the shape of the module without having to fully run the model.
[4]:
with model.scan(" "):
    dim = module.output.shape[-1]

print(dim)
768
It’s time to run the LoRA training loop! We use the Session and the Iterator contexts to achieve this.
[5]:
from torch.utils.data import DataLoader

# The LORA object itself isn't transmitted to the server. Only the forward / call method.
# The parameters are created remotely and never sent, only retrieved.
with model.session() as session:

    # Create a dataset of 100 pairs of a blank prompt and the " Paris" token id.
    dataset = [["_", answer_token]] * 100

    # Create a dataloader from it.
    dataloader = DataLoader(dataset, batch_size=10)

    # Create our LORA on the last mlp.
    lora = LORA(module, dim, lora_dim)

    # Create an optimizer. Use the parameters from LORA.
    optimizer = torch.optim.AdamW(lora.parameters(), lr=3)

    # Iterate over the dataloader using .iter.
    with session.iter(dataloader, return_context=True) as (batch, iterator):

        prompt = batch[0]
        correct_token = batch[1]

        # Run .trace with the prompt.
        with model.trace(prompt) as tracer:

            # Apply the LoRA to the intervention graph just by calling it within .trace.
            lora()

            # Get the logits.
            logits = model.lm_head.output

            # Do cross entropy on the last predicted token and correct_token.
            loss = torch.nn.functional.cross_entropy(logits[:, -1], correct_token)

            # Call backward.
            loss.backward()

        # Call methods on the optimizer. Graphs that aren't from .trace (so in this case the session
        # and the iterator each have their own graph) are executed sequentially.
        # The graph of the Iterator here will be:
        # 1.) Index batch at 0 for prompt
        # 2.) Index batch at 1 for correct_token
        # 3.) Execute the .trace using the prompt
        # 4.) Call .step() on the optimizer
        optimizer.step()
        # 5.) Call .zero_grad() on the optimizer
        optimizer.zero_grad()
        # 6.) Print out the LoRA WA weights to show they are indeed changing
        iterator.log(lora.WA)
Parameter containing:
tensor([[ 0.5262, -0.6452, 0.8448, 0.7407],
[-0.4497, -0.7200, -1.0452, 0.0630],
[ 0.7231, 1.0991, 0.3883, 0.1719],
...,
[ 0.0024, -1.1490, -0.5580, -0.9070],
[-0.1946, 0.8469, -1.8173, 0.8333],
[ 0.1722, -1.8518, -1.5542, -1.3361]], requires_grad=True)
Parameter containing:
tensor([[ 0.6813, -0.4550, 0.9903, 0.5476],
[-0.3310, -0.5932, -0.9087, -0.0441],
[ 0.7201, 1.0849, 0.3954, 0.1480],
...,
[ 0.1580, -0.9589, -0.3856, -1.0354],
[-0.3153, 0.6950, -1.8893, 0.9347],
[ 0.4812, -1.4821, -1.1935, -1.6101]], requires_grad=True)
Parameter containing:
tensor([[-1.2552, -2.3574, -0.9555, 2.4472],
[-2.2370, -2.4913, -2.7973, 1.8731],
[-1.2148, -0.8610, -1.5298, 2.0569],
...,
[-1.7628, -2.8462, -2.2901, 0.9117],
[ 1.6102, 2.5902, 0.0834, -1.0093],
[-1.4495, -3.3539, -3.0740, 0.3545]], requires_grad=True)
Parameter containing:
tensor([[-2.7871, -3.8562, -2.4963, 3.9433],
[-3.7392, -3.9859, -4.2827, 3.3862],
[-2.7453, -2.4021, -3.0508, 3.5621],
...,
[-3.2794, -4.3303, -3.7909, 2.4538],
[ 3.1313, 4.0819, 1.6503, -2.5484],
[-2.9757, -4.8230, -4.5514, 1.9135]], requires_grad=True)
Parameter containing:
tensor([[-4.0300, -5.0671, -3.7480, 5.1516],
[-4.9534, -5.1927, -5.4806, 4.6110],
[-3.9870, -3.6541, -4.2833, 4.7793],
...,
[-4.5076, -5.5269, -5.0037, 3.7067],
[ 4.3638, 5.2859, 2.9272, -3.7984],
[-4.2132, -6.0050, -5.7416, 3.1828]], requires_grad=True)
Parameter containing:
tensor([[-5.0524, -6.0584, -4.7789, 6.1404],
[-5.9479, -6.1800, -6.4593, 5.6158],
[-5.0084, -4.6854, -5.2958, 5.7769],
...,
[-5.5156, -6.5044, -5.9969, 4.7388],
[ 5.3761, 6.2705, 3.9826, -4.8277],
[-5.2302, -6.9683, -6.7128, 4.2308]], requires_grad=True)
Parameter containing:
tensor([[-5.8993, -6.8751, -5.6339, 6.9546],
[-6.7678, -6.9929, -7.2638, 6.4456],
[-5.8543, -5.5411, -6.1332, 6.5998],
...,
[-6.3486, -7.3077, -6.8154, 5.5951],
[ 6.2131, 7.0807, 4.8615, -5.6812],
[-6.0719, -7.7578, -7.5100, 5.1024]], requires_grad=True)
Parameter containing:
tensor([[-6.6025, -7.5490, -6.3451, 7.6261],
[-7.4447, -7.6631, -7.9259, 7.1322],
[-6.5568, -6.2529, -6.8272, 7.2799],
...,
[-7.0383, -7.9686, -7.4911, 6.3074],
[ 6.9068, 7.7484, 5.5957, -6.3908],
[-6.7700, -8.4054, -8.1650, 5.8296]], requires_grad=True)
Parameter containing:
tensor([[-7.1858, -8.1039, -6.9361, 8.1787],
[-8.0026, -8.2144, -8.4693, 7.6995],
[-7.1395, -6.8447, -7.4018, 7.8409],
...,
[-7.6084, -8.5108, -8.0477, 6.8995],
[ 7.4809, 8.2972, 6.2091, -6.9804],
[-7.3483, -8.9347, -8.7015, 6.4362]], requires_grad=True)
Parameter containing:
tensor([[-7.6675, -8.5581, -7.4253, 8.6306],
[-8.4597, -8.6652, -8.9124, 8.1657],
[-7.6208, -7.3349, -7.8752, 8.3011],
...,
[-8.0775, -8.9528, -8.5035, 7.3898],
[ 7.9537, 8.7455, 6.7201, -7.4682],
[-7.8253, -9.3640, -9.1378, 6.9405]], requires_grad=True)
Now WA and WB are optimized! So we generate with the LoRA just by calling lora() inside .generate, and save the output to then de-tokenize it.
[6]:
# With the LoRA applied. Should produce "Hello Paris".
with model.generate("Hello") as generator:
    lora()
    out = model.generator.output.save()

print(model.tokenizer.batch_decode(out.value))

# Then without. Should produce "Hello,".
with model.generate("Hello") as generator:
    out = model.generator.output.save()

print(model.tokenizer.batch_decode(out.value))
['Hello Paris']
['Hello,']
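Because __call__ takes an alpha argument, the adapter’s strength can also be adjusted at inference time without retraining. Note that as written above, alpha scales the entire replaced output (h * alpha), not just the low-rank term, so alpha=1.0 recovers the behavior shown here. A sketch of generating at a different strength (illustrative only; the resulting completion is not shown):

# Apply the LoRA at reduced strength (the exact output will differ from the alpha=1.0 run).
with model.generate("Hello") as generator:
    lora(alpha=0.5)
    out = model.generator.output.save()

print(model.tokenizer.batch_decode(out.value))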