How to make your NDIF experiment 130x faster¶
A user recently reached out to ask how they could make their nnsight code run faster on NDIF to meet a project deadline. After looking at their code, I introduced a number of improvements that leverage nnsight features and remote execution principles. The result was a 130x speedup.
The experience taught me many useful lessons, and I want to share the key principles with you so that you too can implement your experiments optimally for remote execution.
TL;DR¶
- If you're doing more than one forward pass, wrap them in a model.session
- Downloading large tensors can be costly; only .save() what you need
- Cache all your activations in one go
- Reduce loops with Batching Invokes
- .skip what you can
Case Study: Attention Heads Patching¶
Our user's experimental setup consists of standard activation patching to compute causal scores for attention heads over a dataset of contrastive pairs. For this project, the user wanted to scale their experiment to Llama-3.1-70B, which they can access on NDIF via remote execution without requiring their own compute.
As the model size grows, patching every attention head in the model in isolation starts to take a very long time unless we leverage every speedup we can. To motivate the principles introduced below, let's begin by getting familiar with the vanilla implementation:
# Outer loop over the model layers
for layer in range(n_layers):
    # A first forward pass caches the heads' activations
    # of a single layer from the clean prompt.
    with model.trace(clean_seq, remote=True):
        o_proj_inp = model.model.layers[layer].self_attn.o_proj.input
        clean_heads = split_heads(o_proj_inp, n_heads, h_dim).save()
    # Inner loop over the attention heads of the layer
    for head in range(n_heads):
        # Runs a forward pass for each attention head
        # on the corrupted prompt,
        # applying the corresponding head patch from the source prompt.
        with model.trace(corr_seq, remote=True):
            o_proj_inp = model.model.layers[layer].self_attn.o_proj.input
            original_shape = o_proj_inp.shape  # [bsz, seq_len, model_dim]
            sub = split_heads(o_proj_inp, n_heads, h_dim)
            sub[:, -2, head, :] = clean_heads[:, -2, head, :]
            sub = sub.view(original_shape)
            new_tup = ((sub,), o_proj_inp[1])
            model.model.layers[layer].self_attn.o_proj.inputs = new_tup
            # Save the model logits from the entire corrupted prompt.
            logits = model.output.logits.save()
        # Transform the saved results locally
        stats_from_logits(logits, entities)
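The listing above calls a `split_heads` helper that is not shown in this post. A minimal sketch of what it might look like, assuming the `o_proj` input has shape `[bsz, seq_len, n_heads * h_dim]` with the heads laid out contiguously along the last dimension. The helper actually used in the experiment may differ, and the toy shapes below are chosen only for illustration:

```python
import torch


def split_heads(x: torch.Tensor, n_heads: int, h_dim: int) -> torch.Tensor:
    """Reshape [bsz, seq_len, n_heads * h_dim] -> [bsz, seq_len, n_heads, h_dim]."""
    bsz, seq_len, model_dim = x.shape
    assert model_dim == n_heads * h_dim, "last dim must equal n_heads * h_dim"
    return x.view(bsz, seq_len, n_heads, h_dim)


# Toy stand-ins for the corrupted and clean activations (hypothetical values).
n_heads, h_dim = 4, 8
corr = torch.zeros(1, 5, n_heads * h_dim)
clean = torch.ones(1, 5, n_heads * h_dim)

sub = split_heads(corr, n_heads, h_dim)
clean_heads = split_heads(clean, n_heads, h_dim)

head = 2
# Patch a single head at the second-to-last position, as in the listing above.
sub[:, -2, head, :] = clean_heads[:, -2, head, :]
patched = sub.view(corr.shape)  # back to [bsz, seq_len, model_dim]
```

Because `view` returns a tensor that shares storage with its input, the in-place assignment also writes into the original activation tensor.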
Performance¶
Because the vanilla implementation is very slow to run, we evaluate it and the proposed enhancements on Llama-3.1-8B.
We run each implementation five times (n=5) and report the average over all runs.
- Total Runtime: 915.72s; this is the time it takes to run the entire head patching code over a single contrastive pair.
  Note: The request queue times are subtracted from this total to obtain a fair benchmark across methods, regardless of NDIF's traffic at the time of testing.
- Total Download Size: 3.11GB; this is how much data was saved and returned by the server to the client across all the requests for a single input.
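The two metrics above could be computed from per-run measurements roughly as follows. This is only a sketch: the `RunRecord` fields and the sample numbers are hypothetical, and how the queue time is obtained depends on what your client reports.

```python
from dataclasses import dataclass
from statistics import mean


@dataclass
class RunRecord:
    total_seconds: float  # wall-clock time of the whole experiment
    queue_seconds: float  # time spent waiting in the request queue
    download_bytes: int   # bytes returned by the server to the client


def summarize(runs):
    """Average the queue-adjusted runtime and the download size over all runs."""
    return {
        "runtime_s": mean(r.total_seconds - r.queue_seconds for r in runs),
        "download_bytes": mean(r.download_bytes for r in runs),
    }


# Made-up measurements, only to show the calculation.
runs = [
    RunRecord(120.0, 20.0, 1_000),
    RunRecord(110.0, 5.0, 1_000),
    RunRecord(130.0, 30.0, 1_000),
]
summary = summarize(runs)
```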
To improve on these results, we will work on optimizing the following parameters:
- Number of NDIF requests
- Number of forward passes
- Result size
In the vanilla implementation, every remote request is a single forward pass. The code runs a forward pass for every head (32) of every layer (32), plus an additional one per layer to cache the head inputs: 32² + 32 = 1,056.
| Parameter | Amount |
|---|---|
| NDIF requests | 1,056 |
| Forward passes | 1,056 |
| Result Size | 3.11 GB |
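The counts in the table follow directly from the loop structure. A quick check of the arithmetic, using the 32 layers and 32 attention heads per layer of Llama-3.1-8B:

```python
n_layers, n_heads = 32, 32

caching_passes = n_layers             # one clean pass per layer
patching_passes = n_layers * n_heads  # one corrupted pass per (layer, head)
forward_passes = caching_passes + patching_passes

# In the vanilla implementation, every forward pass is its own remote request.
ndif_requests = forward_passes

print(forward_passes)  # 1056
```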
How to meet the deadline¶
Sessions¶
Naturally, each remote request introduces latency, and if your requests coincide with high utilization of a model, every one of them gets delayed in the queue. What you ideally want is to run multiple forward passes in a single NDIF request.
NNsight provides an API to do just that: with model.session(remote=True): lets you wrap multiple tracing contexts and execute them all as part of the same NDIF request, saving you the unnecessary back-and-forth latency.
This is powerful because you can also execute code before and after your forward passes directly on the remote environment.
Let's update the Attention Head Patching code to use the Session context.
Note: we only need to set remote=True on the outer Session context.
with model.session(remote=True):
    results = list().save()
    for layer in range(n_layers):
        with model.trace(clean_seq):
            o_proj_inp = model.model.layers[layer].self_attn.o_proj.input
            clean_heads = split_heads(o_proj_inp, n_heads, h_dim)  # no .save()
        for head in range(n_heads):
            with model.trace(corr_seq):
                ...
                logits = model.output.logits.to('cpu')  # free memory
                results.append(logits)
for logits in results:
    stats_from_logits(logits, entities)
The Session context acts as one NNsight execution environment. Therefore, variables can be referenced across traces without having to save them first. For Activation Patching, this means activations can be patched from one forward pass into another without being saved in between.
Performance¶
Introducing the Session context brings our number of NDIF requests down to only 1 and consequently also reduces the total data download.
| Parameter | Amount |
|---|---|
| NDIF requests | 1 |
| Forward passes | 1056 |
| Result Size | 2.50 GB |
We made the experiment code 6x faster!
Save only what you need¶
This may be the most overlooked principle discussed in this post, but it significantly improves your experiment for no added effort.
In the vanilla implementation, the logits are saved for each patched run, and stats_from_logits is then applied to them locally. But all that function really does is extract the probabilities and logits of the top prediction and of the target entity's tokens.
Instead of doing this final operation locally, let's call it on the logits within the tracing context before adding the output to the final results.
with model.session(remote=True):
    results = list().save()
    for layer in range(n_layers):
        with model.trace(clean_seq):
            o_proj_inp = model.model.layers[layer].self_attn.o_proj.input
            clean_heads = split_heads(o_proj_inp, n_heads, h_dim)  # no .save()
        for head in range(n_heads):
            with model.trace(corr_seq):
                ...
                logits = model.output.logits
                # transform the data here
                results.append(stats_from_logits(logits, entities))
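`stats_from_logits` itself is not shown in this post. As a rough illustration of why its output is so much smaller than the logits, here is a hypothetical version, assuming `logits` has shape `[bsz, seq_len, vocab_size]` and `entities` is a list of target token ids. The real helper may compute something different:

```python
import torch


def stats_from_logits(logits: torch.Tensor, entities: list[int]) -> dict:
    """Reduce full logits to a handful of scalars for the final position."""
    last = logits[0, -1].float()  # [vocab_size], first sequence only
    probs = torch.softmax(last, dim=-1)
    top_id = int(torch.argmax(last))
    return {
        "top_id": top_id,
        "top_logit": last[top_id].item(),
        "top_prob": probs[top_id].item(),
        "entity_logits": [last[t].item() for t in entities],
        "entity_probs": [probs[t].item() for t in entities],
    }


# Toy logits: a vocabulary of 10 tokens where token 3 is the top prediction.
logits = torch.zeros(1, 4, 10)
logits[0, -1, 3] = 5.0
stats = stats_from_logits(logits, entities=[3, 7])
```

Returning plain Python numbers rather than tensors keeps the saved payload down to a few values per patched head.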
Performance¶
By transforming the data remotely and only saving the results required for our use case, we bring the result size down by a factor of roughly 27,000 compared to the vanilla implementation.
| Parameter | Amount |
|---|---|
| NDIF requests | 1 |
| Forward passes | 1056 |
| Result Size | 112.81 KB |
This improvement has another significant impact on latency: the code is now almost 3x faster than in the previous step and 21x faster overall!
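The size reduction can be checked from the two reported sizes. The sketch below assumes decimal units (1 GB = 10^9 bytes, 1 KB = 10^3 bytes), which the post does not specify:

```python
GB, KB = 10**9, 10**3  # assumed decimal units

vanilla_bytes = 3.11 * GB    # result size of the vanilla implementation
reduced_bytes = 112.81 * KB  # result size when only the stats are saved

reduction = vanilla_bytes / reduced_bytes
print(f"{reduction:,.0f}x smaller")
```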
Caching in one go¶
Now, we turn our attention to the number of forward passes run on the model. Currently, the code performs one forward pass per layer to cache that layer's activations; let's rewrite it to cache the activations of all layers in a single forward pass.
with model.session(remote=True):
    clean_heads = list()
    with model.trace(clean_seq):
        # loop inside the trace
        for layer in range(n_layers):
            o_proj_inp = model.model.layers[layer].self_attn.o_proj.input
            clean_heads.append(split_heads(o_proj_inp, n_heads, h_dim))
    results = list().save()
    for layer in range(n_layers):
        for head in range(n_heads):
            with model.trace(corr_seq):
                ...
                logits = model.output.logits
                results.append(stats_from_logits(logits, entities))
Performance¶
For Llama-3.1-8B, this refactoring saves us 31 extra forward passes. Everything else remains the same.
| Parameter | Amount |
|---|---|
| NDIF requests | 1 |
| Forward passes | 1025 |
| Result Size | 112.81 KB |
The speedup here is not significant. However, I anticipate it will become more impactful as your model size or sequence lengths grow, i.e. as each individual forward pass becomes more expensive. It is therefore still worth adopting as a good practice, as long as you can afford the memory cost.
Batching¶
One of the most powerful NNsight features available to us is the .invoke API for batching interventions at runtime while retaining the ability to intervene on each batch in isolation. This means you can apply different interventions to different sequences in the same forward pass without having to manage slicing and indices.
The syntax looks like the following:
with model.trace() as tracer:
    with tracer.invoke(batch_1):
        # intervene here
        ...
    with tracer.invoke(batch_2):
        # intervene here
        ...
The Invoker context only applies the interventions to the batch slice it was provided, and indices are treated as if that slice were the only input to the forward pass.
For our Attention Head Activation Patching, we will leverage batching to parallelize the patched runs over multiple heads. For Llama-3.1-8B, we can batch all the heads in a single layer at once. In practice, this means moving the loop over the heads inside the tracing context, such that each iteration creates an invoker context per patched head.
with model.session(remote=True):
    clean_heads = list()
    with model.trace(clean_seq):
        for layer in range(n_layers):
            o_proj_inp = model.model.layers[layer].self_attn.o_proj.input
            clean_heads.append(split_heads(o_proj_inp, n_heads, h_dim))
    results = list().save()
    for layer in range(n_layers):
        with model.trace() as tracer:  # don't pass the input here
            for head in range(n_heads):  # loop over the heads
                with tracer.invoke(corr_seq):  # input passed here
                    ...
                    logits = model.output.logits
                    results.append(stats_from_logits(logits, entities))
Performance¶
Batching allows us to reduce the number of forward passes from (num_layers * num_heads) to (num_layers * num_heads / batch_size), and in our case that would just be 32 runs + 1 (clean prompt run for caching).
| Parameter | Amount |
|---|---|
| NDIF requests | 1 |
| Forward passes | 33 |
| Result Size | 112.81 KB |
Doing so, we reach another significant speedup of 7x. In total, this is a speed improvement of more than 130x.
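If memory does not allow batching every head of a layer in a single pass, the same formula gives the number of forward passes for smaller batches. A small sketch, where `batch_size` is the number of invokes per trace and the helper name is made up for illustration. It rounds up per layer because, in the code above, a batch never spans two layers:

```python
import math


def forward_passes(n_layers: int, n_heads: int, batch_size: int) -> int:
    """Patched passes (heads grouped per layer) plus one clean caching pass."""
    passes_per_layer = math.ceil(n_heads / batch_size)
    return n_layers * passes_per_layer + 1


print(forward_passes(32, 32, batch_size=32))  # 33
print(forward_passes(32, 32, batch_size=8))   # 129
print(forward_passes(32, 32, batch_size=1))   # 1025
```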
Module skipping¶
While we go over skipping redundant computation, I hope you don't decide to skip this section.
So far, we have reduced the number of forward passes to roughly one per layer. However, since all the forward passes take in the same input data, each pass redundantly recomputes the same layer results up to the patching point.
Hence, our next optimization is to pre-cache the unpatched layer outputs and use them to skip all the layers leading up to the point of patching in the model. To achieve that, NNsight implements a .skip(<value>) API that you can call on any module to skip its forward pass entirely. The method takes the value that will act as the output of the skipped module and propagates it to the next module in line.
For our Activation Patching implementation, this means doing an extra forward pass on the corrupted prompt to cache the unpatched layer outputs and adding logic at the top of the patch invoker to skip all prior layers.
with model.session(remote=True):
    clean_heads = list()
    with model.trace(clean_seq):
        for layer in range(n_layers):
            o_proj_inp = model.model.layers[layer].self_attn.o_proj.input
            clean_heads.append(split_heads(o_proj_inp, n_heads, h_dim))
    # cache unpatched activations
    corr_acts = list()
    with model.trace(corr_seq):
        for layer in range(n_layers):
            corr_acts.append(model.model.layers[layer].output)
    results = list().save()
    for layer in range(n_layers):
        with model.trace() as tracer:  # don't pass the input here
            for head in range(n_heads):  # loop over the heads
                with tracer.invoke(corr_seq):  # input passed here
                    for layer_to_skip in range(layer):
                        # skip all the layers leading to the patch
                        model.model.layers[layer_to_skip].skip(corr_acts[layer_to_skip])
                    ...
                    logits = model.output.logits
                    results.append(stats_from_logits(logits, entities))
Performance¶
| Parameter | Amount |
|---|---|
| NDIF requests | 1 |
| Forward passes | 34 |
| Result Size | 112.81 KB |
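To see how much work skipping can save in principle, one can count the decoder-layer evaluations in the patched passes. This is a back-of-the-envelope sketch that ignores the embedding, the LM head, the extra caching pass, and any overhead of `.skip` itself, so it is not a runtime prediction:

```python
def patched_layer_evals(n_layers: int, skip: bool) -> int:
    """Decoder layers actually computed across the per-layer patched passes."""
    if not skip:
        return n_layers * n_layers  # every pass runs every layer
    # The pass that patches layer i skips layers 0..i-1 and runs the rest.
    return sum(n_layers - i for i in range(n_layers))


for n_layers in (32, 80):  # layer counts of Llama-3.1-8B and Llama-3.1-70B
    full = patched_layer_evals(n_layers, skip=False)
    skipped = patched_layer_evals(n_layers, skip=True)
    print(n_layers, full, skipped, round(skipped / full, 3))
```

By this count, skipping roughly halves the number of layer evaluations for either model size; how much of that shows up as wall-clock savings depends on how expensive each layer is relative to the fixed overheads.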
The final improvement we introduce does not turn out to speed things up for our setup and is even a tiny bit slower. But I wonder whether we would see a different effect on a larger model?
Performance on Llama-3.1-70B¶
Scaling the benchmark to a larger model reveals speedup gains using the layer skipping logic. Compared to our last optimized iteration with batching, we obtain a 40% speed increase when we pair it with skipping earlier layers. I anticipate that speedups can also be felt on "smaller" models if working with larger input sequences.
Since the implementation from the previous section is only slightly faster than the one we just covered, I recommend you implement module skipping in your experiment when applicable, so it can be reusable across model sizes and benefit from that extra performance boost.
Benchmarking Local Execution¶
If you came here to figure out how to run your experiments faster and more efficiently on NDIF, I hope you were able to learn a few useful tricks.
To conclude, I want readers who study their models in a local tracing setup to also get some takeaways they can use to speed up their experiments.
Let's re-run some of the implementations we came up with earlier but this time on a local GPU node and see how well the speed improvements transfer.
The Activation Patching implementation using invoker batching and module skipping yields a 7x speedup compared to the base implementation. Leveraging NNsight features can take your experiment a long way!
If you are constrained by memory limits and can't rely on batching, the benchmark shows that you can get a decent 40% speed increase by adding layer skipping even when not paired with batching.
If all of this is new to you and you've never accessed models on NDIF before, visit login.ndif.us and get started with your API key!
System Config
| Component | Value |
|---|---|
| nnsight | 0.6.3 |
| transformers | 5.5.4 |
| torch | 2.11.0+cu130 |
| GPU | RTX 6000 Ada Generation |