LoRA
We’re going to train a very simple LoRA that, when applied, will make our model always predict “Paris” no matter what.
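Before writing any nnsight code, it helps to recall what a LoRA actually does: the base weights stay frozen, and a trainable low-rank correction x @ WA @ WB is added to a module's output (with the result scaled by alpha, as in the LORA class we define below). Here is a minimal plain-PyTorch sketch of that update, separate from the nnsight version we build next; the shapes and values are illustrative only.
[ ]:
import torch

dim, r, alpha = 768, 4, 1.0       # hidden size, LoRA rank, scaling factor (illustrative values)
x = torch.randn(1, 5, dim)        # a fake batch of hidden states entering the module
h = torch.randn(1, 5, dim)        # pretend this is the module's original output
A = torch.randn(dim, r)           # "WA": projects down into the small rank-r space
B = torch.zeros(r, dim)           # "WB": projects back up; zeros make the update a no-op at init

h_lora = (h + x @ A @ B) * alpha  # additive low-rank correction, scaled by alpha
print(torch.allclose(h, h_lora))  # True at init, since B starts at zero and alpha is 1.0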
[ ]:
!pip install nnsight
Let’s set everything up for LoRA training: the imports, the model, and the LORA module itself.
[3]:
import torch
import torch.nn as nn

from nnsight import LanguageModel
from nnsight import Envoy

model = LanguageModel('openai-community/gpt2', device_map='auto')

# We will define a LORA class.
# The operations in the LORA class's __call__ method are traced just as they would be inside a normal .trace.
class LORA(nn.Module):
    def __init__(self, module: Envoy, dim: int, r: int) -> None:
        """Init.

        Args:
            module (Envoy): Which model module we are adding the LORA to.
            dim (int): Dimension of the layer we are adding to (this could potentially be auto-populated if the user scanned first, so we know the shape).
            r (int): Inner dimension of the LORA.
        """
        super(LORA, self).__init__()
        self.r = r
        self.module = module
        self.WA = torch.nn.Parameter(torch.randn(dim, self.r), requires_grad=True).save()
        self.WB = torch.nn.Parameter(torch.zeros(self.r, dim), requires_grad=True).save()

    # The __call__ method defines how to actually apply the LORA.
    def __call__(self, alpha: float = 1.0):
        """Call.

        Args:
            alpha (float, optional): How much to apply the LORA. Can be altered after training for inference. Defaults to 1.0.
        """
        # We apply WA to the first positional arg (the hidden states).
        A_x = torch.matmul(self.module.input[0][0], self.WA)
        BA_x = torch.matmul(A_x, self.WB)

        # LORA is additive.
        h = BA_x + self.module.output

        # Replace the output with our new one, scaled by alpha.
        # Could also have been self.module.output[:] = h * alpha, for in-place.
        self.module.output = h * alpha

    def parameters(self):
        # Some way to get all the parameters.
        return [self.WA, self.WB]
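The LORA above reads the module's input (self.module.input[0][0]) and overwrites its output, so it is worth seeing what those hidden states look like. A quick trace lets us inspect them (a sketch; "Hello" is an arbitrary prompt, and the .scan() method used below recovers the same shape without actually running the model):
[ ]:
with model.trace("Hello"):
    hidden = model.transformer.h[-1].mlp.output.save()

# For GPT-2 this should be (batch, sequence, 768), the hidden size the LORA must match.
print(hidden.shape)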
Next, we define the training variables: the target token id, the inner LoRA dimension, and the module to adapt. We can then use the .scan() method to get the shape of that module without having to fully run the model.
[4]:
# We need the token id of the correct answer.
answer = " Paris"
answer_token = model.tokenizer.encode(answer)[0]
# Inner LORA dimension
lora_dim = 4
# Module to train LORA on
module = model.transformer.h[-1].mlp
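Since we only keep the first id returned by encode, it is worth a quick, optional check that " Paris" really is a single GPT-2 token and decodes back to the same string:
[ ]:
# Optional sanity check on the target token.
print(model.tokenizer.encode(answer))                 # should be a one-element list
print(repr(model.tokenizer.decode([answer_token])))   # should be ' Paris'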
[5]:
with model.scan(" "):
    dim = module.output.shape[-1]

print(dim)
768
It’s time to run the LoRA training loop! We will be using the Session and the Iterator contexts to achieve this.
[6]:
import nnsight
from torch.utils.data import DataLoader

# The LORA object itself isn't transmitted to the server. Only the forward / call method is.
# The parameters are created remotely and never sent, only retrieved.
with model.session() as session:

    # Create a dataset of 100 pairs of a blank prompt and the " Paris" token id.
    dataset = [["_", answer_token]] * 100

    # Create a dataloader from it.
    dataloader = DataLoader(dataset, batch_size=10)

    # Create our LORA on the last mlp.
    lora = LORA(module, dim, lora_dim)

    # Create an optimizer. Use the parameters from LORA.
    optimizer = torch.optim.AdamW(lora.parameters(), lr=3)

    # Iterate over the dataloader using .iter.
    with session.iter(dataloader) as batch:

        prompt = batch[0]
        correct_token = batch[1]

        # Run .trace with the prompt.
        with model.trace(prompt) as tracer:

            # Apply the LORA to the intervention graph just by calling it inside .trace.
            lora()

            # Get the logits.
            logits = model.lm_head.output

            # Cross entropy between the last predicted token and correct_token.
            loss = torch.nn.functional.cross_entropy(logits[:, -1], batch[1])

            # Call backward.
            loss.backward()

        # Call methods on the optimizer. Graphs that aren't from .trace (here, the session and the iterator each have their own graph) are executed sequentially.
        # The graph of the Iterator here will be:
        # 1.) Index batch at 0 for prompt
        # 2.) Index batch at 1 for correct_token
        # 3.) Execute the .trace using the prompt
        # 4.) Call .step() on the optimizer
        optimizer.step()
        # 5.) Call .zero_grad() on the optimizer
        optimizer.zero_grad()

        # 6.) Log the LORA WA weights to show they are indeed changing
        nnsight.log(lora.WA)
Parameter containing:
tensor([[-1.6894, -2.2130, -1.4245, 0.9411],
[-0.1197, -0.9315, -0.2023, -0.1131],
[ 0.4321, -1.0991, 1.7861, -0.0823],
...,
[-0.5378, 0.2232, -0.3258, 0.1821],
[ 1.2927, 0.8778, -0.3216, 0.8665],
[ 0.8403, -0.8087, -1.2318, 0.8117]], requires_grad=True)
Parameter containing:
tensor([[-1.4478e+00, -2.3376e+00, -1.1908e+00, 1.1039e+00],
[ 1.8683e-03, -1.0216e+00, -7.8252e-02, 8.2429e-03],
[ 4.4023e-01, -1.0872e+00, 1.7536e+00, -5.8724e-02],
...,
[-3.4760e-01, 4.2430e-02, -1.4198e-01, 3.5068e-01],
[ 1.1123e+00, 9.9310e-01, -4.5363e-01, 6.9885e-01],
[ 1.1631e+00, -1.1325e+00, -8.4679e-01, 1.1354e+00]],
requires_grad=True)
Parameter containing:
tensor([[-1.2827e+00, -2.3891e+00, -1.0335e+00, 1.1924e+00],
[ 7.6505e-02, -1.0656e+00, -1.2115e-03, 8.2689e-02],
[ 4.4028e-01, -1.0678e+00, 1.7143e+00, -4.3705e-02],
...,
[-2.2642e-01, -6.9592e-02, -2.6977e-02, 4.5091e-01],
[ 9.8906e-01, 1.0532e+00, -5.2989e-01, 5.8801e-01],
[ 1.3529e+00, -1.3232e+00, -5.9669e-01, 1.3260e+00]],
requires_grad=True)
Parameter containing:
tensor([[-1.1573, -2.4044, -0.9155, 1.2436],
[ 0.1274, -1.0868, 0.0520, 0.1334],
[ 0.4365, -1.0452, 1.6722, -0.0330],
...,
[-0.1406, -0.1466, 0.0529, 0.5165],
[ 0.8953, 1.0856, -0.5781, 0.5063],
[ 1.4742, -1.4454, -0.4169, 1.4482]], requires_grad=True)
Parameter containing:
tensor([[-1.0565, -2.3983, -0.8220, 1.2723],
[ 0.1639, -1.0945, 0.0908, 0.1697],
[ 0.4305, -1.0209, 1.6292, -0.0249],
...,
[-0.0763, -0.2022, 0.1114, 0.5610],
[ 0.8199, 1.1017, -0.6093, 0.4425],
[ 1.5537, -1.5257, -0.2807, 1.5284]], requires_grad=True)
Parameter containing:
tensor([[-0.9726, -2.3785, -0.7452, 1.2863],
[ 0.1907, -1.0935, 0.1198, 0.1964],
[ 0.4231, -0.9959, 1.5859, -0.0186],
...,
[-0.0266, -0.2436, 0.1554, 0.5916],
[ 0.7569, 1.1070, -0.6294, 0.3909],
[ 1.6053, -1.5781, -0.1741, 1.5807]], requires_grad=True)
Parameter containing:
tensor([[-0.9012, -2.3495, -0.6805, 1.2900],
[ 0.2107, -1.0864, 0.1419, 0.2162],
[ 0.4149, -0.9705, 1.5428, -0.0135],
...,
[ 0.0126, -0.2747, 0.1892, 0.6123],
[ 0.7032, 1.1048, -0.6415, 0.3481],
[ 1.6369, -1.6106, -0.0891, 1.6131]], requires_grad=True)
Parameter containing:
tensor([[-0.8392, -2.3140, -0.6251, 1.2863],
[ 0.2256, -1.0750, 0.1589, 0.2310],
[ 0.4062, -0.9451, 1.5002, -0.0094],
...,
[ 0.0440, -0.2982, 0.2153, 0.6256],
[ 0.6565, 1.0973, -0.6479, 0.3121],
[ 1.6540, -1.6284, -0.0202, 1.6309]], requires_grad=True)
Parameter containing:
tensor([[-0.7847, -2.2738, -0.5771, 1.2770],
[ 0.2367, -1.0605, 0.1719, 0.2418],
[ 0.3971, -0.9199, 1.4583, -0.0060],
...,
[ 0.0693, -0.3159, 0.2354, 0.6335],
[ 0.6153, 1.0858, -0.6500, 0.2812],
[ 1.6600, -1.6352, 0.0360, 1.6376]], requires_grad=True)
Parameter containing:
tensor([[-0.7363, -2.2305, -0.5349, 1.2635],
[ 0.2446, -1.0438, 0.1818, 0.2496],
[ 0.3879, -0.8949, 1.4172, -0.0032],
...,
[ 0.0898, -0.3289, 0.2509, 0.6370],
[ 0.5786, 1.0715, -0.6487, 0.2546],
[ 1.6574, -1.6334, 0.0821, 1.6357]], requires_grad=True)
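The logged weights show WA moving at every step. Since lora is still available after the session ends, we can also check its effect directly on the logits with a quick trace (a sketch; the prompt is arbitrary, and if training worked the top prediction should be " Paris"):
[ ]:
with model.trace("The capital of Italy is"):
    lora()
    top_token = model.lm_head.output[:, -1].argmax(dim=-1).save()

print(model.tokenizer.decode(top_token))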
Now WA and WB are optimized! We can generate with the LoRA just by calling lora() inside the .generate context and saving the output, which we then de-tokenize.
[9]:
# With lora. Should produce "Hello Paris"
with model.generate("Hello") as generator:
    lora()
    out = model.generator.output.save()

print(model.tokenizer.batch_decode(out))

# Then without. Should produce "Hello,"
with model.generate("Hello") as generator:
    out = model.generator.output.save()

print(model.tokenizer.batch_decode(out))
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
['Hello Paris']
['Hello,']
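Finally, since __call__ takes an alpha argument, the strength of the LoRA can be dialed up or down at inference time without retraining. A small sketch (the alpha values are arbitrary; note that in this class alpha scales the whole replaced output, not just the low-rank term):
[ ]:
# Sweep alpha to see how strongly the LoRA steers generation.
for alpha in [0.1, 0.5, 1.0]:
    with model.generate("Hello") as generator:
        lora(alpha=alpha)
        out = model.generator.output.save()

    print(alpha, model.tokenizer.batch_decode(out))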