LoRA

We’re going to train a very simple LoRA that, when applied, will make our model always predict “Paris” no matter what.
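
As a quick refresher on the technique itself: LoRA freezes a pretrained weight and learns two small matrices, WA (dim × r) and WB (r × dim), whose product forms a low-rank update added to the layer’s output. Here is a toy plain-PyTorch sketch of that update (illustrative only, not nnsight code):

[ ]:
import torch

dim, r = 768, 4                  # hidden size and inner LoRA rank
x = torch.randn(1, dim)          # hidden states entering the layer
W = torch.randn(dim, dim)        # frozen pretrained weight
WA = torch.randn(dim, r)         # trainable low-rank factor
WB = torch.zeros(r, dim)         # trainable low-rank factor, initialized to zero

h = x @ W + (x @ WA) @ WB        # LoRA is additive: base output + low-rank update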

[ ]:
!pip install nnsight

First, let’s import our dependencies, load GPT-2, and define the LORA class.

[3]:
import torch
import torch.nn as nn
from nnsight import LanguageModel
from nnsight import Envoy

model = LanguageModel('openai-community/gpt2', device_map='auto')

# We define a LORA class.
# Operations in the LORA's __call__ method are traced just as they would be
# inside a normal .trace context.
class LORA(nn.Module):
    def __init__(self, module: Envoy, dim: int, r: int) -> None:
        """Init.

        Args:
            module (Envoy): Which model Module we are adding the LORA to.
            dim (int): Dimension of the layer we are adding to. (This could be
                auto-populated if the user scanned the module first to learn its shape.)
            r (int): Inner dimension of the LORA
        """
        super(LORA, self).__init__()
        self.r = r
        self.module = module
        self.WA = torch.nn.Parameter(torch.randn(dim, self.r), requires_grad=True).save()
        self.WB = torch.nn.Parameter(torch.zeros(self.r, dim), requires_grad=True).save()

    # The __call__ method defines how to actually apply the LORA.
    def __call__(self, alpha: float = 1.0):
        """Call.

        Args:
            alpha (float, optional): How much to apply the LORA. Can be altered after training for inference. Defaults to 1.0.
        """

        # We apply WA to the first positional arg (the hidden states)
        A_x = torch.matmul(self.module.input[0][0], self.WA)
        BA_x = torch.matmul(A_x, self.WB)

        # LORA is additive
        h = BA_x + self.module.output

        # Replace the output with our new one * alpha
        # Could also have been self.module.output[:] = h * alpha, for in-place
        self.module.output = h * alpha

    def parameters(self):
        # Expose the LORA's trainable parameters to the optimizer.
        return [self.WA, self.WB]
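
To make the accessors in __call__ concrete: in the nnsight version this tutorial assumes, module.input is a tuple of (positional args, keyword args), so module.input[0][0] is the hidden states entering the module, and module.output is the tensor the LoRA adds to. A quick shape check on the module we will train on:

[ ]:
# Sketch: inspect the tensors the LORA will read and write (shapes only).
with model.trace(" "):
    hidden = model.transformer.h[-1].mlp.input[0][0].save()  # hidden states in
    out = model.transformer.h[-1].mlp.output.save()          # MLP output

print(hidden.shape, out.shape)  # both torch.Size([1, 1, 768]) for a 1-token prompt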

Let’s define all the variables to use in LoRA training.

[4]:
# We need the token id of the correct answer.
answer = " Paris"
answer_token = model.tokenizer.encode(answer)[0]
# Inner LORA dimension
lora_dim = 4
# Module to train LORA on
module = model.transformer.h[-1].mlp

We can use the .scan() method to get the shape of the module without having to fully run the model.

[5]:
with model.scan(" "):
    dim = module.output.shape[-1]

print(dim)
768
It’s time to run the LoRA training loop! We will be using the Session and the Iterator contexts to achieve this.

[6]:
import nnsight
from torch.utils.data import DataLoader

# The LORA object itself isn't transmitted to the server; only its call method is traced.
# The parameters are created remotely and are never sent, only retrieved.
with model.session() as session:

    # Create a dataset of 100 pairs of a placeholder prompt and the " Paris" token id.
    dataset = [["_", answer_token]] * 100

    # Create a dataloader from it.
    dataloader = DataLoader(dataset, batch_size=10)

    # Create our LORA on the last mlp
    lora = LORA(module, dim, lora_dim)

    # Create an optimizer. Use the parameters from LORA
    optimizer = torch.optim.AdamW(lora.parameters(), lr=3)

    # Iterate over dataloader using .iter.
    with session.iter(dataloader) as batch:

        prompt = batch[0]
        correct_token = batch[1]

        # Run .trace with prompt
        with model.trace(prompt) as tracer:

            # Apply the LORA to the intervention graph simply by calling it inside .trace
            lora()

            # Get logits
            logits = model.lm_head.output

            # Cross entropy between the logits for the last predicted token and correct_token
            loss = torch.nn.functional.cross_entropy(logits[:, -1], correct_token)
            # Call backward
            loss.backward()

        # Call methods on the optimizer. Graphs that aren't from .trace (here, the
        # session and the iterator each have their own graph) execute sequentially.
        # The iterator's graph will be:
        # 1.) Index batch at 0 for prompt
        # 2.) Index batch at 1 for correct_token
        # 3.) Execute the .trace using the prompt
        # 4.) Call .step() on optimizer
        optimizer.step()
        # 5.) Call .zero_grad() in optimizer
        optimizer.zero_grad()
        # 6.) Print out the lora WA weights to show they are indeed changing
        nnsight.log(lora.WA)

Parameter containing:
tensor([[-1.6894, -2.2130, -1.4245,  0.9411],
        [-0.1197, -0.9315, -0.2023, -0.1131],
        [ 0.4321, -1.0991,  1.7861, -0.0823],
        ...,
        [-0.5378,  0.2232, -0.3258,  0.1821],
        [ 1.2927,  0.8778, -0.3216,  0.8665],
        [ 0.8403, -0.8087, -1.2318,  0.8117]], requires_grad=True)
Parameter containing:
tensor([[-1.4478e+00, -2.3376e+00, -1.1908e+00,  1.1039e+00],
        [ 1.8683e-03, -1.0216e+00, -7.8252e-02,  8.2429e-03],
        [ 4.4023e-01, -1.0872e+00,  1.7536e+00, -5.8724e-02],
        ...,
        [-3.4760e-01,  4.2430e-02, -1.4198e-01,  3.5068e-01],
        [ 1.1123e+00,  9.9310e-01, -4.5363e-01,  6.9885e-01],
        [ 1.1631e+00, -1.1325e+00, -8.4679e-01,  1.1354e+00]],
       requires_grad=True)
Parameter containing:
tensor([[-1.2827e+00, -2.3891e+00, -1.0335e+00,  1.1924e+00],
        [ 7.6505e-02, -1.0656e+00, -1.2115e-03,  8.2689e-02],
        [ 4.4028e-01, -1.0678e+00,  1.7143e+00, -4.3705e-02],
        ...,
        [-2.2642e-01, -6.9592e-02, -2.6977e-02,  4.5091e-01],
        [ 9.8906e-01,  1.0532e+00, -5.2989e-01,  5.8801e-01],
        [ 1.3529e+00, -1.3232e+00, -5.9669e-01,  1.3260e+00]],
       requires_grad=True)
Parameter containing:
tensor([[-1.1573, -2.4044, -0.9155,  1.2436],
        [ 0.1274, -1.0868,  0.0520,  0.1334],
        [ 0.4365, -1.0452,  1.6722, -0.0330],
        ...,
        [-0.1406, -0.1466,  0.0529,  0.5165],
        [ 0.8953,  1.0856, -0.5781,  0.5063],
        [ 1.4742, -1.4454, -0.4169,  1.4482]], requires_grad=True)
Parameter containing:
tensor([[-1.0565, -2.3983, -0.8220,  1.2723],
        [ 0.1639, -1.0945,  0.0908,  0.1697],
        [ 0.4305, -1.0209,  1.6292, -0.0249],
        ...,
        [-0.0763, -0.2022,  0.1114,  0.5610],
        [ 0.8199,  1.1017, -0.6093,  0.4425],
        [ 1.5537, -1.5257, -0.2807,  1.5284]], requires_grad=True)
Parameter containing:
tensor([[-0.9726, -2.3785, -0.7452,  1.2863],
        [ 0.1907, -1.0935,  0.1198,  0.1964],
        [ 0.4231, -0.9959,  1.5859, -0.0186],
        ...,
        [-0.0266, -0.2436,  0.1554,  0.5916],
        [ 0.7569,  1.1070, -0.6294,  0.3909],
        [ 1.6053, -1.5781, -0.1741,  1.5807]], requires_grad=True)
Parameter containing:
tensor([[-0.9012, -2.3495, -0.6805,  1.2900],
        [ 0.2107, -1.0864,  0.1419,  0.2162],
        [ 0.4149, -0.9705,  1.5428, -0.0135],
        ...,
        [ 0.0126, -0.2747,  0.1892,  0.6123],
        [ 0.7032,  1.1048, -0.6415,  0.3481],
        [ 1.6369, -1.6106, -0.0891,  1.6131]], requires_grad=True)
Parameter containing:
tensor([[-0.8392, -2.3140, -0.6251,  1.2863],
        [ 0.2256, -1.0750,  0.1589,  0.2310],
        [ 0.4062, -0.9451,  1.5002, -0.0094],
        ...,
        [ 0.0440, -0.2982,  0.2153,  0.6256],
        [ 0.6565,  1.0973, -0.6479,  0.3121],
        [ 1.6540, -1.6284, -0.0202,  1.6309]], requires_grad=True)
Parameter containing:
tensor([[-0.7847, -2.2738, -0.5771,  1.2770],
        [ 0.2367, -1.0605,  0.1719,  0.2418],
        [ 0.3971, -0.9199,  1.4583, -0.0060],
        ...,
        [ 0.0693, -0.3159,  0.2354,  0.6335],
        [ 0.6153,  1.0858, -0.6500,  0.2812],
        [ 1.6600, -1.6352,  0.0360,  1.6376]], requires_grad=True)
Parameter containing:
tensor([[-0.7363, -2.2305, -0.5349,  1.2635],
        [ 0.2446, -1.0438,  0.1818,  0.2496],
        [ 0.3879, -0.8949,  1.4172, -0.0032],
        ...,
        [ 0.0898, -0.3289,  0.2509,  0.6370],
        [ 0.5786,  1.0715, -0.6487,  0.2546],
        [ 1.6574, -1.6334,  0.0821,  1.6357]], requires_grad=True)

Now WA and WB are optimized! To generate with the LoRA applied, we simply call lora() inside .generate, then save and decode the output.

[9]:
# With lora. Should produce "Hello Paris"
with model.generate("Hello") as generator:

    lora()

    out = model.generator.output.save()

print(model.tokenizer.batch_decode(out))

# Then without. Should produce "Hello,"
with model.generate("Hello") as generator:

    out = model.generator.output.save()

print(model.tokenizer.batch_decode(out))
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
['Hello Paris']
['Hello,']
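
Since alpha is an argument to the LORA’s __call__, the strength of the intervention can be changed at inference time without retraining. Note that as the class is written above, alpha scales the combined output h (the base output plus the low-rank update), not just the update. A minimal sketch, reusing the trained lora:

[ ]:
# Apply the trained LoRA at a different strength during generation.
# As written, alpha multiplies h = BA_x + module.output, i.e. the whole output.
with model.generate("Hello") as generator:
    lora(alpha=0.5)
    out = model.generator.output.save()

print(model.tokenizer.batch_decode(out))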