LoRA

We’re going to train a very simple LoRA (low-rank adapter) that, when applied, will make our model always predict “Paris”, no matter the prompt.

[1]:
import torch
import torch.nn as nn
from nnsight import LanguageModel

# Wrap GPT-2 in an nnsight LanguageModel so its modules can be traced and intervened on.
model = LanguageModel("openai-community/gpt2", device_map="auto")


from nnsight.envoy import Envoy #

# We will define a LORA class.
# The operations in the LORA class's __call__ method are traced just like code written directly inside a .trace context.
class LORA(nn.Module):
    """Low-rank additive adapter attached to a single nnsight module (Envoy).

    The adapter owns two trainable matrices, ``WA`` (dim x r) and ``WB`` (r x dim).
    Calling the instance inside an nnsight tracing context records the intervention
    ``output <- output + alpha * (x @ WA @ WB)`` on the wrapped module, where ``x``
    is the module's first positional input.
    """

    def __init__(self, module: "Envoy", dim: int, r: int) -> None:
        """Init.

        Args:
            module (Envoy): Which model Module we are adding the LORA to.
            dim (int): Dimension of the layer we are adding to. It is used both as the
                input width (rows of WA) and the output width (columns of WB), so the
                wrapped module must map dim -> dim. (This could potentially be
                auto-populated if the user scanned first, so we know the shape.)
            r (int): Inner (rank) dimension of the LORA.
        """
        super().__init__()
        self.r = r
        self.module = module
        # WA starts random and WB starts at zero, so WA @ WB is zero before training:
        # a freshly created LORA leaves the module output unchanged.
        # .save() keeps the parameters accessible after the nnsight session exits
        # (they are reused by later generate calls).
        self.WA = torch.nn.Parameter(torch.randn(dim, self.r), requires_grad=True).save()
        self.WB = torch.nn.Parameter(torch.zeros(self.r, dim), requires_grad=True).save()

    # The call method defines how to actually apply the LORA.
    def __call__(self, alpha: float = 1.0) -> None:
        """Apply the LORA to the wrapped module's output.

        Must be called inside an nnsight tracing context (e.g. ``.trace`` or
        ``.generate``); the operations below are recorded as interventions.

        Args:
            alpha (float, optional): Scale of the low-rank update. Only the LORA delta
                is scaled, so ``alpha=0`` leaves the module output untouched and
                ``alpha=1`` applies the full trained update. Can be altered after
                training for inference. Defaults to 1.0.
        """
        # We apply WA to the first positional arg (the hidden states), then WB.
        hidden = self.module.input[0][0]
        delta = torch.matmul(torch.matmul(hidden, self.WA), self.WB)

        # LORA is additive. Scale only the low-rank delta, never the base module's
        # own output, so alpha behaves as "how much LORA to apply".
        # (An in-place alternative would assign into self.module.output[:].)
        self.module.output = delta * alpha + self.module.output

    def parameters(self):
        """Return the trainable LORA weights ``[WA, WB]`` (e.g. for an optimizer).

        Note: this overrides ``nn.Module.parameters`` and returns a list.
        """
        return [self.WA, self.WB]
/opt/homebrew/anaconda3/envs/nnsight_local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
/opt/homebrew/anaconda3/envs/nnsight_local/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(

Let’s define the variables used in LoRA training.

[3]:
# We need the token id of the correct answer. The LORA is trained to predict one
# token, so the answer must encode to exactly one id (otherwise taking [0] would
# silently train toward only the first sub-token).
answer = " Paris"
answer_ids = model.tokenizer.encode(answer)
if len(answer_ids) != 1:
    raise ValueError(f"{answer!r} must encode to exactly one token, got ids {answer_ids}")
answer_token = answer_ids[0]
# Inner LORA dimension
lora_dim = 4
# Module to train LORA on: the MLP of the last transformer block
module = model.transformer.h[-1].mlp

We can use the .scan() method to get the shape of the module's output without having to fully run the model.

[4]:
# .scan() on a placeholder input exposes the module's output shape without fully
# running the model; the last axis is the width the LORA must match.
with model.scan(" "):
    dim = module.output.shape[-1]

# For GPT-2 this prints 768.
print(dim)
768

It’s time to run the LoRA training loop! We use the Session and Iterator contexts to achieve this.

[5]:
from torch.utils.data import DataLoader

# The LORA object itself isn't transmitted to the server; only the operations in its
# __call__ method are (as part of the traced graph).
# The parameters are created remotely and are never sent, only retrieved.
with model.session() as session:

    # Create a dataset of 100 pairs: the placeholder prompt "_" and the " Paris" token id.
    dataset = [["_", answer_token]] * 100

    # Create a dataloader from it (10 batches of 10 pairs).
    dataloader = DataLoader(dataset, batch_size=10)

    # Create our LORA on the last mlp
    lora = LORA(module, dim, lora_dim)

    # Create an optimizer over the LORA parameters.
    # NOTE: lr=3 is very large; presumably chosen so this toy task converges in 10 steps.
    optimizer = torch.optim.AdamW(lora.parameters(), lr=3)

    # Iterate over the dataloader using .iter.
    with session.iter(dataloader, return_context=True) as (batch, iterator):

        # With default collation, batch[0] holds the prompts and batch[1] the target ids.
        prompt = batch[0]
        correct_token = batch[1]

        # Run .trace with prompt
        with model.trace(prompt):

            # Apply LORA to the intervention graph just by calling it within .trace
            lora()

            # Get logits
            logits = model.lm_head.output

            # Cross entropy between the last position's logits and the target token
            loss = torch.nn.functional.cross_entropy(logits[:, -1], correct_token)
            # Call backward
            loss.backward()

        # Call methods on the optimizer. Graphs that aren't from .trace (here, session
        # and iterator each have their own graph) are executed sequentially.
        # The graph of the iterator here will be:
        # 1.) Index batch at 0 for prompt
        # 2.) Index batch at 1 for correct_token
        # 3.) Execute the .trace using the prompt
        # 4.) Call .step() on optimizer
        optimizer.step()
        # 5.) Call .zero_grad() on optimizer
        optimizer.zero_grad()
        # 6.) Log the lora WA weights to show they are indeed changing
        iterator.log(lora.WA)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Parameter containing:
tensor([[ 0.5262, -0.6452,  0.8448,  0.7407],
        [-0.4497, -0.7200, -1.0452,  0.0630],
        [ 0.7231,  1.0991,  0.3883,  0.1719],
        ...,
        [ 0.0024, -1.1490, -0.5580, -0.9070],
        [-0.1946,  0.8469, -1.8173,  0.8333],
        [ 0.1722, -1.8518, -1.5542, -1.3361]], requires_grad=True)
Parameter containing:
tensor([[ 0.6813, -0.4550,  0.9903,  0.5476],
        [-0.3310, -0.5932, -0.9087, -0.0441],
        [ 0.7201,  1.0849,  0.3954,  0.1480],
        ...,
        [ 0.1580, -0.9589, -0.3856, -1.0354],
        [-0.3153,  0.6950, -1.8893,  0.9347],
        [ 0.4812, -1.4821, -1.1935, -1.6101]], requires_grad=True)
Parameter containing:
tensor([[-1.2552, -2.3574, -0.9555,  2.4472],
        [-2.2370, -2.4913, -2.7973,  1.8731],
        [-1.2148, -0.8610, -1.5298,  2.0569],
        ...,
        [-1.7628, -2.8462, -2.2901,  0.9117],
        [ 1.6102,  2.5902,  0.0834, -1.0093],
        [-1.4495, -3.3539, -3.0740,  0.3545]], requires_grad=True)
Parameter containing:
tensor([[-2.7871, -3.8562, -2.4963,  3.9433],
        [-3.7392, -3.9859, -4.2827,  3.3862],
        [-2.7453, -2.4021, -3.0508,  3.5621],
        ...,
        [-3.2794, -4.3303, -3.7909,  2.4538],
        [ 3.1313,  4.0819,  1.6503, -2.5484],
        [-2.9757, -4.8230, -4.5514,  1.9135]], requires_grad=True)
Parameter containing:
tensor([[-4.0300, -5.0671, -3.7480,  5.1516],
        [-4.9534, -5.1927, -5.4806,  4.6110],
        [-3.9870, -3.6541, -4.2833,  4.7793],
        ...,
        [-4.5076, -5.5269, -5.0037,  3.7067],
        [ 4.3638,  5.2859,  2.9272, -3.7984],
        [-4.2132, -6.0050, -5.7416,  3.1828]], requires_grad=True)
Parameter containing:
tensor([[-5.0524, -6.0584, -4.7789,  6.1404],
        [-5.9479, -6.1800, -6.4593,  5.6158],
        [-5.0084, -4.6854, -5.2958,  5.7769],
        ...,
        [-5.5156, -6.5044, -5.9969,  4.7388],
        [ 5.3761,  6.2705,  3.9826, -4.8277],
        [-5.2302, -6.9683, -6.7128,  4.2308]], requires_grad=True)
Parameter containing:
tensor([[-5.8993, -6.8751, -5.6339,  6.9546],
        [-6.7678, -6.9929, -7.2638,  6.4456],
        [-5.8543, -5.5411, -6.1332,  6.5998],
        ...,
        [-6.3486, -7.3077, -6.8154,  5.5951],
        [ 6.2131,  7.0807,  4.8615, -5.6812],
        [-6.0719, -7.7578, -7.5100,  5.1024]], requires_grad=True)
Parameter containing:
tensor([[-6.6025, -7.5490, -6.3451,  7.6261],
        [-7.4447, -7.6631, -7.9259,  7.1322],
        [-6.5568, -6.2529, -6.8272,  7.2799],
        ...,
        [-7.0383, -7.9686, -7.4911,  6.3074],
        [ 6.9068,  7.7484,  5.5957, -6.3908],
        [-6.7700, -8.4054, -8.1650,  5.8296]], requires_grad=True)
Parameter containing:
tensor([[-7.1858, -8.1039, -6.9361,  8.1787],
        [-8.0026, -8.2144, -8.4693,  7.6995],
        [-7.1395, -6.8447, -7.4018,  7.8409],
        ...,
        [-7.6084, -8.5108, -8.0477,  6.8995],
        [ 7.4809,  8.2972,  6.2091, -6.9804],
        [-7.3483, -8.9347, -8.7015,  6.4362]], requires_grad=True)
Parameter containing:
tensor([[-7.6675, -8.5581, -7.4253,  8.6306],
        [-8.4597, -8.6652, -8.9124,  8.1657],
        [-7.6208, -7.3349, -7.8752,  8.3011],
        ...,
        [-8.0775, -8.9528, -8.5035,  7.3898],
        [ 7.9537,  8.7455,  6.7201, -7.4682],
        [-7.8253, -9.3640, -9.1378,  6.9405]], requires_grad=True)

Now WA and WB are optimized! To generate with the LoRA, we just call lora() inside the .generate context, then save the output so we can de-tokenize it.

[6]:
# With lora. Should produce "Hello Paris"
# Calling lora() inside .generate applies the trained adapter (alpha defaults to 1.0).
with model.generate("Hello") as generator:

    lora()

    # Save the generated token ids so they remain accessible after the context exits.
    out = model.generator.output.save()

# .value holds the saved ids once the generate context has run.
print(model.tokenizer.batch_decode(out.value))

# Then without. Should produce "Hello,"
# No lora() call here, so the model runs unmodified.
with model.generate("Hello") as generator:

    out = model.generator.output.save()

print(model.tokenizer.batch_decode(out.value))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
['Hello Paris']
['Hello,']