Boundless DAS#

This tutorial is adapted from pyvene; you can find their code here.

Read more about Boundless DAS in the original paper by Zhengxuan Wu et al. here.

Setup (Ignore)#

[2]:
import torch
from tqdm import tqdm, trange
from nnsight import LanguageModel
from pyvene import BoundlessRotatedSpaceIntervention
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from datasets import Dataset as hf_Dataset
from transformers import get_linear_schedule_with_warmup
import gc

from tutorial_price_tagging_utils import factual_sampler, bound_alignment_sampler, lower_bound_alignment_example_sampler
[3]:
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LanguageModel('sharpbai/alpaca-7b-merged', device_map="cuda:0", torch_dtype=torch.bfloat16, dispatch=True)
remote = False
[4]:
def free_unused_cuda_memory():
    """Free unused cuda memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    else:
        raise RuntimeError("not using cuda")
    gc.collect()
[5]:
def calculate_loss(logits, labels, subspace_proj, mask_weight=1.5, vocab_size=32001):
    # Standard cross-entropy between logits and labels (kept as in the pyvene
    # reference implementation; no positional shift is applied here).
    shift_logits = logits[..., :, :].contiguous()
    shift_labels = labels[..., :].contiguous()
    # Flatten the tokens
    loss_fct = torch.nn.CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)

    # Regularize the learned intervention boundaries so the aligned subspace
    # stays as small as possible.
    boundary_loss = mask_weight * subspace_proj.intervention_boundaries.sum()
    loss += boundary_loss

    return loss
[6]:
def compute_metrics(eval_preds, eval_labels, generate_output=False):
    total_count = 0
    correct_count = 0

    if generate_output:
        outputs = []
        gts = []

    for eval_pred, eval_label in zip(eval_preds, eval_labels):

        for i in range(eval_label.shape[0]):
            # Positions where a label is provided (all other positions are -100 and ignored)
            label_idxs = eval_label[i].ne(-100).nonzero().squeeze(-1)

            actual_test_labels = eval_label[i][label_idxs].tolist()
            pred_test_labels = [eval_pred[i][idx].argmax(dim=-1).item() for idx in label_idxs]

            # The example counts as correct only if every labeled token matches
            correct = actual_test_labels == pred_test_labels

            if generate_output:
                outputs.append(pred_test_labels)
                gts.append(actual_test_labels)

            total_count += 1
            if correct:
                correct_count += 1

    return_dict = {"accuracy": round(correct_count/total_count, 2)}
    if generate_output:
        return_dict["outputs"] = outputs
        return_dict["labels"] = gts

    return return_dict

Price Tagging game#

The instruction prompt of the Price Tagging game follows the publicly released template of the Alpaca (7B) model. The core instruction contains the English sentence “Please say yes only if it costs between [X.XX] and [X.XX] dollars, otherwise no.”, followed by an input dollar amount [X.XX], where each [X.XX] is a random real number drawn uniformly from [0.00, 9.99]. The output is a single token, ‘Yes’ or ‘No’.
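
The sketch below illustrates the prompt and label logic; ALPACA_TEMPLATE and make_example are hypothetical names used only for illustration, and the actual data in this tutorial comes from pyvene's samplers.

import random

ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input that provides "
    "further context. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)

def make_example():
    # Draw two bounds and an input amount uniformly from [0.00, 9.99]
    lower, upper = sorted(round(random.uniform(0.00, 9.99), 2) for _ in range(2))
    amount = round(random.uniform(0.00, 9.99), 2)
    instruction = (
        f"Please say yes only if it costs between {lower:.2f} and {upper:.2f} dollars, otherwise no."
    )
    prompt = ALPACA_TEMPLATE.format(instruction=instruction, input=f"{amount:.2f} dollars")
    # The gold answer is a single token
    label = "Yes" if lower <= amount <= upper else "No"
    return prompt, label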

One hypothesis for how the model solves this task is the left boundary causal model, which has one high-level boolean variable representing whether the input amount is above the lower bound, and an output node that combines it with whether the input amount is also below the upper bound. In this tutorial we focus on finding an alignment for this causal model.
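
A minimal sketch of this hypothesized causal model (the function and variable names here are ours, purely for illustration):

def left_boundary_causal_model(amount: float, lower: float, upper: float) -> str:
    # High-level boolean variable: is the amount above the lower bound?
    above_lower = amount >= lower
    # Output node: combine it with the upper-bound check to produce the answer.
    return "Yes" if above_lower and amount <= upper else "No"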

Prealign Task#

To create our datasets, we use code adapted from pyvene.

[7]:
raw_prealign = factual_sampler(model.tokenizer, 5000, game="pricing_tag")

prealign_dataset = hf_Dataset.from_dict(
    {"input_ids": raw_prealign[0], "labels": raw_prealign[1]})
prealign_dataset.set_format('torch', columns=['input_ids','labels'])
prealign_dataloader = DataLoader(
    prealign_dataset, batch_size=8
)

Each instance in the dataset appears in this format:

[8]:
model.tokenizer.decode(prealign_dataset['input_ids'][0])
[8]:
'<s> Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nPlease say yes only if it costs between 2.52 and 7.83 dollars, otherwise no.\n\n### Input:\n9.76 dollars\n\n### Response:\n'
[9]:
with torch.no_grad():

    eval_labels = []
    eval_preds = []

    for step, inputs in enumerate(tqdm(prealign_dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)

        outputs = model.forward(
            input_ids=inputs['input_ids'],
            labels=inputs['labels']
        )
        eval_labels += [inputs['labels'].detach().cpu()]

        eval_preds += [outputs.logits.detach().cpu()]

    eval_metrics = compute_metrics(eval_preds, eval_labels)

eval_dict = eval_metrics
print(f"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {eval_dict['accuracy']}")
100%|██████████| 625/625 [01:42<00:00,  6.08it/s]
[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: 0.92

Boundless DAS#

The goal of Boundless DAS is to learn an alignment between potentially distributed neural representations and high-level causal variables.

To train Boundless DAS, we sample two training examples and swap the intermediate boolean value between them to produce a counterfactual output under our causal model. In parallel, we swap the aligned dimensions of the neural representations in rotated space. Finally, we update the rotation matrix so that the network's counterfactual behavior more closely matches that of the causal model.
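
On the neural side, the interchange happens in a learned rotated basis. The following is a simplified sketch of that swap, using a hard 0/1 mask in place of Boundless DAS's learned, temperature-annealed boundary mask (the actual operation is implemented by pyvene's BoundlessRotatedSpaceIntervention used below):

def rotated_space_swap(base, source, rotation, mask):
    """Simplified interchange intervention in rotated space.

    base, source: (batch, hidden) activations at the intervention site.
    rotation: orthogonal (hidden, hidden) matrix (learned in Boundless DAS).
    mask: (hidden,) 0/1 vector marking dimensions to take from source; in
    Boundless DAS this is a soft mask derived from learned boundary parameters.
    """
    rot_base = base @ rotation
    rot_source = source @ rotation
    # Keep the unmasked dimensions from base, take the masked ones from source.
    mixed = (1 - mask) * rot_base + mask * rot_source
    # Rotate back into the model's original representation space.
    return mixed @ rotation.T

During training, only the rotation and the boundary parameters are updated; the language model itself stays frozen.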

We start by creating the training dataset for our trainable intervention.

[10]:
raw_data = bound_alignment_sampler(
    model.tokenizer, 10000, [lower_bound_alignment_example_sampler]
)

raw_train, raw_temp = train_test_split(
    list(zip(*raw_data)), test_size=0.2, random_state=42
)

raw_eval, raw_test = train_test_split(
    raw_temp, test_size=0.5, random_state=42
)

def unpack(data):
    return tuple(map(list, zip(*data)))

raw_train = unpack(raw_train)
raw_eval = unpack(raw_eval)
raw_test = unpack(raw_test)

def create_dataset(data):
    dataset = hf_Dataset.from_dict({
        "input_ids": data[0],
        "source_input_ids": data[1],
        "labels": data[2],
        "intervention_ids": data[3]  # we will not use this field
    }).with_format("torch")
    return DataLoader(dataset, batch_size=8)

train_dataloader = create_dataset(raw_train)
eval_dataloader = create_dataset(raw_eval)
test_dataloader = create_dataset(raw_test)
[11]:
subspace_proj = BoundlessRotatedSpaceIntervention(embed_dim=model.config.hidden_size).to('cuda')

gradient_accumulation_steps = 4
epochs = 3
temperature_start = 50.0
temperature_end = 0.1
intervention_layer = 12

t_total = int(len(train_dataloader) * epochs)
warm_up_steps = 0.1 * t_total

# Define params to be learned
optimizer_params = []
optimizer_params += [{'params': subspace_proj.rotate_layer.parameters()}]
optimizer_params += [{'params': subspace_proj.intervention_boundaries, 'lr': 1e-2}]

optimizer = torch.optim.Adam(
    optimizer_params,
    lr=1e-3,
)

scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warm_up_steps,
    num_training_steps=t_total
)

target_total_step = len(train_dataloader) * epochs

temperature_schedule = torch.linspace(
    temperature_start, temperature_end, target_total_step
).to(torch.bfloat16).to(device)

total_step = 0
subspace_proj.set_temperature(temperature_schedule[total_step])
subspace_proj.train()
[11]:
BoundlessRotatedSpaceIntervention(
  (rotate_layer): ParametrizedRotateLayer(
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _Orthogonal()
      )
    )
  )
)
[12]:
def batch_subspace_swap(inputs, intervention_layer, model: LanguageModel, subspace_proj):
    """
    Batched subspace-swap intervention at a single layer using nnsight
    """
    batch_size = len(inputs['input_ids'])
    all_inds = torch.arange(batch_size)

    base_prompt, source_prompt = inputs['input_ids'][:batch_size], inputs['source_input_ids'][:batch_size]

    # First trace: collect the layer activations for both the base and the
    # source prompts.
    with model.trace(validate=False, remote=remote) as tracer:
        with tracer.invoke(base_prompt, scan=False):
            base = model.model.layers[intervention_layer].output[0].save()

        with tracer.invoke(source_prompt, scan=False):
            source = model.model.layers[intervention_layer].output[0].save()

    # Second trace: rerun the base prompts, patching in the rotated-space
    # mixture of base and source activations at token position 80 (the same
    # prompt position across examples, since the prompts share a fixed format).
    with model.trace(validate=False, remote=remote) as tracer:
        with tracer.invoke(base_prompt, scan=False):
            B = base[all_inds, 80, :]
            S = source[all_inds, 80, :]

            mixed_out = subspace_proj(B, S, batch_size)
            model.model.layers[intervention_layer].output[0][all_inds, 80, :] = mixed_out
        save_out = model.output.save()

    del base, source, B, S
    free_unused_cuda_memory()

    output_logits = save_out.value.logits
    del save_out
    return output_logits
[13]:
train_iterator = trange(
    0, int(epochs), desc="Epoch"
)

for epoch in train_iterator:
    log_dicts = []

    epoch_iterator = tqdm(
        train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True
    )

    for step, inputs in enumerate(epoch_iterator):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)

        counterfactual_outputs = batch_subspace_swap(inputs, intervention_layer, model, subspace_proj)

        eval_metrics = compute_metrics(
            [counterfactual_outputs], [inputs['labels']]
        )

        loss = calculate_loss(counterfactual_outputs, inputs["labels"], subspace_proj)
        loss_str = round(loss.item(), 2)

        log_dict = {'loss': loss_str, 'acc': eval_metrics["accuracy"]}
        epoch_iterator.set_postfix(log_dict)

        log_dicts.append(log_dict)

        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        loss.backward()
        if total_step % gradient_accumulation_steps == 0:
            if not (gradient_accumulation_steps > 1 and total_step == 0):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                subspace_proj.set_temperature(temperature_schedule[total_step])

        total_step += 1
Epoch: 0:   0%|          | 0/1000 [00:00<?, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Epoch: 0: 100%|██████████| 1000/1000 [43:49<00:00,  2.63s/it, loss=0.49, acc=0.88]
Epoch: 1: 100%|██████████| 1000/1000 [49:47<00:00,  2.99s/it, loss=0.4, acc=0.88]
Epoch: 2: 100%|██████████| 1000/1000 [51:38<00:00,  3.10s/it, loss=0.38, acc=0.88]
Epoch: 100%|██████████| 3/3 [2:25:15<00:00, 2905.15s/it]

Evaluation on test set:

[14]:
with torch.no_grad():

    eval_labels = []
    eval_preds = []

    for step, inputs in enumerate(tqdm(test_dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)

        outputs = batch_subspace_swap(inputs, intervention_layer, model, subspace_proj)

        eval_labels += [inputs['labels'].detach().cpu()]
        eval_preds += [outputs.detach().cpu()]

    eval_metrics = compute_metrics(eval_preds, eval_labels)

print(f"Boundless DAS accuracy: {eval_metrics['accuracy']}")
100%|██████████| 125/125 [01:54<00:00,  1.09it/s]
Boundless DAS accuracy: 0.93