We’re going to train a very simple LoRA (low-rank adaptation) module that, when applied, will make our model always predict “Paris” no matter what.

import torch
import torch.nn as nn
from nnsight import LanguageModel

# Wrap GPT-2 ("openai-community/gpt2") in an nnsight LanguageModel so its
# submodules can be traced and intervened on below (model.scan / model.trace /
# model.generate). device_map='auto' is presumably forwarded to the underlying
# Hugging Face loader for device placement — confirm for the installed version.
model = LanguageModel('openai-community/gpt2', device_map='auto')

from nnsight.envoy import Envoy #

# We will define a LORA class.
# The operations in the LORA's __call__ method are simply traced, just like
# operations you would normally write inside a .trace context.
class LORA(nn.Module):
    """A minimal low-rank adapter applied to an nnsight Envoy through tracing.

    When called inside a trace, it replaces the wrapped module's output with
    ``output + alpha * (x @ WA @ WB)``, where ``x`` is the module's first
    positional input.
    """

    def __init__(self, module: Envoy, dim: int, r: int) -> None:
        """Create the adapter's low-rank factors.

        Args:
            module (Envoy): Which model module we are adding the LORA to.
            dim (int): Dimension of the layer we are adding to (this could
                potentially be auto-populated if the user scanned first, so we
                know the shape).
            r (int): Inner dimension of the LORA.
        """
        super().__init__()
        self.r = r
        self.module = module
        # WA starts random and WB starts at zero, so WA @ WB is zero and the
        # adapter leaves the module output unchanged before training.
        # Plain torch Parameters have no .save(); this relies on LORA being
        # instantiated inside model.session(), where nnsight traces these calls.
        self.WA = torch.nn.Parameter(torch.randn(dim, self.r), requires_grad=True).save()
        self.WB = torch.nn.Parameter(torch.zeros(self.r, dim), requires_grad=True).save()

    # The call method defines how to actually apply the LORA.
    def __call__(self, alpha: float = 1.0):
        """Apply the LORA to the wrapped module inside the current trace.

        Args:
            alpha (float, optional): How much to apply the LORA. Can be altered
                after training for inference. Defaults to 1.0. Only the
                low-rank update is scaled, so alpha=0 leaves the module's
                original output untouched.
        """
        # We apply WA to the first positional arg (the hidden states).
        # Assumes .input is an (args, kwargs) pair, matching the [0][0]
        # indexing — confirm for the installed nnsight version.
        A_x = torch.matmul(self.module.input[0][0], self.WA)
        BA_x = torch.matmul(A_x, self.WB)

        # LORA is additive: scale only the low-rank update by alpha rather than
        # the whole output, so the base module's contribution is preserved.
        # (Could also be done in place with self.module.output[:] = ...)
        self.module.output = self.module.output + BA_x * alpha

    def parameters(self):
        """Return the LORA's trainable factors ``[WA, WB]`` for the optimizer."""
        # Overrides nn.Module.parameters(), presumably because WA/WB are nnsight
        # proxies rather than registered nn.Parameter attributes — confirm.
        return [self.WA, self.WB]
Let’s define all the variables we’ll use in LoRA training.

# We need the token id of the correct answer.
answer = " Paris"
answer_ids = model.tokenizer.encode(answer)
# The loss compares against a single next-token prediction, so the answer must
# encode to exactly one token; otherwise taking [0] would silently train toward
# only the first sub-token.
if len(answer_ids) != 1:
    raise ValueError(
        f"answer {answer!r} must encode to a single token, got ids {answer_ids}"
    )
answer_token = answer_ids[0]
# Inner LORA dimension
lora_dim = 4
# Module to train LORA on (the MLP of the last transformer block)
module = model.transformer.h[-1].mlp

We can use the .scan() method to get the shape of the module's output without having to fully run the model.

# Scan (rather than fully run) the model on a one-space prompt just to read the
# module's output shape; its last dimension, `dim`, sizes the LORA factors
# WA (dim x r) and WB (r x dim) below.
with model.scan(" "):
    dim = module.output.shape[-1]


It’s time to run the LoRA training loop! We use the Session and Iterator contexts to achieve this.

from torch.utils.data import DataLoader

# The LORA object itself isn't transmitted to the server, only the forward /
# call method. The parameters are created remotely and never sent, only
# retrieved.
with model.session() as session:

    # Create a dataset of 100 pairs of a "_" prompt and the " Paris" token id.
    dataset = [["_", answer_token]] * 100

    # Create a dataloader from it (10 batches of 10 pairs).
    dataloader = DataLoader(dataset, batch_size=10)

    # Create our LORA on the last mlp
    lora = LORA(module, dim, lora_dim)

    # Create an optimizer. Use the parameters from LORA
    optimizer = torch.optim.AdamW(lora.parameters(), lr=3)

    # Iterate over dataloader using .iter. return_context=True also yields the
    # Iterator context itself, which we use below to log the weights.
    with session.iter(dataloader, return_context=True) as (batch, iterator):

        prompt = batch[0]
        correct_token = batch[1]

        # Run .trace with prompt
        with model.trace(prompt):

            # Apply LORA to the intervention graph just by calling it within .trace
            lora()

            # Get logits
            logits = model.lm_head.output

            # Do cross entropy on last predicted token and correct_token
            loss = torch.nn.functional.cross_entropy(logits[:, -1], correct_token)

            # Call backward
            loss.backward()

        # Call methods on optimizer. Graphs that aren't from .trace (in this
        # case the session and the iterator each have their own graph) are
        # executed sequentially. The Iterator's graph here is:
        # 1.) Index batch at 0 for prompt
        # 2.) Index batch at 1 for correct_token
        # 3.) Execute the .trace using the prompt
        # 4.) Call .step() on optimizer
        optimizer.step()
        # 5.) Call .zero_grad() on optimizer
        optimizer.zero_grad()
        # 6.) Log the lora WA weights to show they are indeed changing
        iterator.log(lora.WA)

Now WA and WB are optimized! To generate with the LoRA applied, we just call lora() inside .generate, save the output, and then de-tokenize it.

# With lora. Should produce "Hello Paris"
with model.generate("Hello"):

    # Calling the trained LORA inside .generate applies it to the wrapped module.
    lora()

    out = model.generator.output.save()

# Decode the saved token ids so the effect of the LORA is visible.
print(model.tokenizer.batch_decode(out.value))


# Then without. Should produce "Hello,"
with model.generate("Hello"):

    # No lora() call here, so the model runs unmodified.
    out = model.generator.output.save()

# Decode the saved token ids to compare against the LORA-applied generation.
print(model.tokenizer.batch_decode(out.value))


Output of the LoRA-applied generation: ['Hello Paris']