Skipping Modules

Summary

NNsight lets you skip a module’s execution using .skip()! You just need to provide the tensor that will replace the module’s output. The replacement can be the output of a prior module or a custom tensor, as long as it is formatted to match the module’s output.

with model.trace(input):

    # skip module and pass its replacement output to the next module instead
    model.module.skip(replacement_output)

When to Use

Use skip() if you would like to test the effects of skipping a module on model execution.

Because you provide the replacement output yourself, skip() may also enable more sophisticated operations, such as model splicing.

How to Use

You can skip a module in NNsight by calling .skip() on it. With .skip(), the module is not run; instead, its output is replaced by the tensor you provide, which then flows on to the next module in its place. It is up to you to select an appropriate replacement tensor. For example, to skip layer 5’s MLP, you can pass layer 4’s MLP output to .skip(), since both MLPs produce outputs in the same format.

Formatting replacement tensors

Be careful when selecting a replacement tensor. Some modules, such as layers, format their inputs and outputs differently (a layer may take a tensor as input but return a tuple), so when skipping layers it is recommended to pass the prior layer’s output. The first layer has no prior layer output, so to skip it you need to construct a mock layer output yourself.

Let’s try it out:

with model.trace(prompt):

    # skip layer 5's MLP and pass layer 4's MLP output on in its place
    model.transformer.h[5].mlp.skip(model.transformer.h[4].mlp.output)

    skipped_output = model.lm_head.output.save()

Let’s say we now want to skip the execution of layers 2 through 6. We can do this iteratively.

with model.trace(prompt):
    replacement_output = model.transformer.h[1].output
    # range(2, 7) covers layers 2 through 6 (the end index is exclusive)
    for ii in range(2, 7):
        model.transformer.h[ii].skip(replacement_output)

You can also pass custom tensors to .skip(), as long as they are formatted correctly (see the note on formatting replacement tensors above).

If multiple invokers are defined, skips need to be specified for every invoker.

with model.trace() as tracer:
    with tracer.invoke(prompt):
        # skipping layer 1
        model.transformer.h[1].skip(model.transformer.h[0].output)

    with tracer.invoke(prompt_2):
        # skipping layer 1
        model.transformer.h[1].skip(model.transformer.h[0].output)