Info
Last Execution: 2026-02-17
| Package | Version |
|---|---|
| nnsight | 0.5.15 |
| Python | 3.12.3 |
| torch | 2.10.0+cu128 |
| transformers | 5.2.0 |
Skipping Modules¶
Use .skip() to bypass a module's computation entirely, replacing its output with a value you provide. This lets you test the effect of removing specific components — like individual layers or MLPs — from the forward pass.
Setup¶
from nnsight import LanguageModel
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
Skipping a Module¶
Call .skip(replacement) on any module to bypass its forward pass. The replacement tensor is used as the module's output instead.
# Normal prediction
with model.trace("The Eiffel Tower is in the city of"):
    normal_logits = model.lm_head.output.save()

# Skip layer 5's MLP — pass the previous MLP's output through instead
with model.trace("The Eiffel Tower is in the city of"):
    model.transformer.h[5].mlp.skip(model.transformer.h[4].mlp.output)
    skipped_logits = model.lm_head.output.save()

print(f"Normal: {model.tokenizer.decode(normal_logits[0, -1].argmax(dim=-1))}")
print(f"Skipped: {model.tokenizer.decode(skipped_logits[0, -1].argmax(dim=-1))}")
Normal: Paris
Skipped: Paris
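The top prediction survives here, but a skip can still shift the rest of the distribution. One way to quantify that is the KL divergence between the two next-token distributions. A minimal sketch in plain torch — the helper name `logit_kl` is ours, and the random tensors stand in for the saved `normal_logits[0, -1]` and `skipped_logits[0, -1]`:

```python
import torch
import torch.nn.functional as F

def logit_kl(p_logits: torch.Tensor, q_logits: torch.Tensor) -> float:
    """KL(P || Q) between the next-token distributions implied by two logit vectors."""
    log_p = F.log_softmax(p_logits, dim=-1)
    log_q = F.log_softmax(q_logits, dim=-1)
    return F.kl_div(log_q, log_p, log_target=True, reduction="sum").item()

# Stand-ins for the saved logit vectors (GPT-2's vocab size is 50257)
torch.manual_seed(0)
normal = torch.randn(50257)
skipped = normal + 0.1 * torch.randn(50257)
print(f"KL divergence: {logit_kl(normal, skipped):.4f}")
```

A near-zero value means the skip barely mattered for this prompt, even if some low-probability tokens moved.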
Skipping Multiple Layers¶
You can skip a range of layers in a loop. Pass the output of the last non-skipped layer as the replacement for each skipped layer.
with model.trace("The Eiffel Tower is in the city of"):
    # Skip layers 3 through 8
    replacement = model.transformer.h[2].output
    for i in range(3, 9):
        model.transformer.h[i].skip(replacement)
    logits = model.lm_head.output.save()
print(f"Skipped layers 3-8: {model.tokenizer.decode(logits[0, -1].argmax(dim=-1))}")
Skipped layers 3-8: the
Replacement tensor format
The replacement must match the output format of the skipped module. For GPT-2 transformer blocks, the output is a tuple (hidden_states, ...), so passing a prior layer's .output (which is already a tuple) works directly. For sub-modules like .mlp, the output is a plain tensor.
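To make the two formats concrete, here is a plain-torch illustration, assuming GPT-2 small's hidden size (768) and a nine-token prompt — the actual sequence length depends on your prompt:

```python
import torch

batch, seq, d_model = 1, 9, 768  # GPT-2 small; seq varies with the prompt

# Sub-module like .mlp: output is a plain tensor, so the replacement is too
mlp_replacement = torch.zeros(batch, seq, d_model)

# Transformer block: output is a tuple (hidden_states, ...), so wrap the tensor
block_replacement = (torch.zeros(batch, seq, d_model),)

print(type(mlp_replacement).__name__, type(block_replacement).__name__)
# Tensor tuple
```

Passing a prior layer's `.output` sidesteps this bookkeeping entirely, since it already has the right format.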
Measuring Layer Importance¶
Skip each layer one at a time and check how the prediction changes — a simple way to measure which layers matter most for a given prompt.
prompt = "The Eiffel Tower is in the city of"

with model.trace(prompt):
    baseline = model.lm_head.output[0, -1].argmax(dim=-1).save()
baseline_token = model.tokenizer.decode(baseline)
print(f"Baseline: {baseline_token}\n")

for layer_idx in range(1, 12):
    with model.trace(prompt):
        model.transformer.h[layer_idx].skip(model.transformer.h[layer_idx - 1].output)
        pred = model.lm_head.output[0, -1].argmax(dim=-1).save()
    token = model.tokenizer.decode(pred)
    changed = " ← changed!" if token != baseline_token else ""
    print(f"Skip layer {layer_idx:2d}: {token}{changed}")
Baseline: Paris

Skip layer  1: London ← changed!
Skip layer  2: London ← changed!
Skip layer  3: the ← changed!
Skip layer  4: Paris
Skip layer  5: Paris
Skip layer  6: London ← changed!
Skip layer  7: Paris
Skip layer  8: London ← changed!
Skip layer  9: London ← changed!
Skip layer 10: London ← changed!
Skip layer 11: Paris
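Top-token changes are a coarse signal: a layer can matter even when the argmax survives. A softer importance score is how much probability the skipped run takes away from the baseline token. A minimal helper in plain torch — `target_prob_drop` is our name, and the random tensors stand in for the saved `lm_head` logits:

```python
import torch
import torch.nn.functional as F

def target_prob_drop(baseline_logits: torch.Tensor,
                     skipped_logits: torch.Tensor,
                     token_id: int) -> float:
    """Probability lost by `token_id` when the layer is skipped."""
    p_base = F.softmax(baseline_logits, dim=-1)[token_id].item()
    p_skip = F.softmax(skipped_logits, dim=-1)[token_id].item()
    return p_base - p_skip

torch.manual_seed(0)
baseline = torch.randn(50257)                 # stand-in for baseline logits
skipped = baseline + 0.5 * torch.randn(50257)  # stand-in for a skipped run
tok = baseline.argmax().item()
print(f"Prob drop for baseline token: {target_prob_drop(baseline, skipped, tok):+.4f}")
```

Ranking layers by this drop usually separates "critical" layers from ones whose skip merely reshuffles low-probability alternatives.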
When to use skip
- Ablation studies — measure the causal effect of removing a layer or sub-module
- Layer importance — identify which layers are critical for specific predictions
- Model splicing — replace a module's computation with an alternative output