nnsight.models#

This module contains the main NNsight model classes which enable the tracing and interleaving functionality of nnsight.

Models allow users to load and wrap torch modules. Here we load gpt2 from HuggingFace using its repo id:

from nnsight import LanguageModel
model = LanguageModel('gpt2', device_map='cuda:0')

In this case, declaring a LanguageModel means the underlying model is a transformers.AutoModelForCausalLM, and any arguments not used by LanguageModel are passed downstream to AutoModelForCausalLM. device_map='cuda:0' leverages the accelerate package to place the model on the first GPU when loading it locally.

Wrapping the underlying model provides both the display of its module structure when the model is printed and direct access to the attributes of the underlying meta_model through the wrapped model itself.
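
For example, attributes and submodules of the underlying model can be read straight through the wrapper. A minimal sketch, assuming the gpt2 model loaded above (the config field and submodule names come from the transformers GPT-2 implementation):

print(model.config.n_layer)   # 12, config of the underlying model read through the wrapper
print(model.transformer.wte)  # Embedding(50257, 768), submodule accessed through the wrapper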

Printing out the wrapped module returns its structure:

print(model)
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2AttentionAltered(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
          (query): WrapperModule()
          (key): WrapperModule()
          (value): WrapperModule()
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

The primary method of interacting with and running the model is .trace(...). This returns a context manager object which, when entered, tracks operations performed on the inputs and outputs of modules.

The trace context has the most explicit control of all levels of nnsight tracing and interleaving, creating a parent context from which input-specific sub-contexts are spawned.

with model.trace("The Eiffel Tower is in the city of") as runner:
    logits = model.lm_head.output.save()

print(logits.value)

See nnsight.contexts for more.

class nnsight.models.NNsightModel.NNsight(model_key: str | Module, *args, dispatch: bool = False, **kwargs)[source]#

Main class to be implemented as a wrapper for PyTorch models wishing to gain this package’s functionality. Can be used “as is” for basic models.

Class Attributes:

proxy_class (Type[InterventionProxy]): InterventionProxy-like type to use as a Proxy for this Model’s inputs and outputs. Can have Model-specific functionality added via a new sub-class.

model_key#

String representing what kind of model this is. Usually the Hugging Face repo id of the model to load, a path to a checkpoint, or the class name of a custom model.

Type:

str

args#

Positional arguments used to initialize model.

Type:

List[Any]

kwargs#

Keyword arguments used to initialize model.

Type:

Dict[str,Any]

dispatched#

Whether ._model has been loaded with real parameters yet.

Type:

bool

custom_model#

Whether the value passed as model_key was a custom, pre-initialized model.

Type:

bool

_model#

Underlying torch module.

Type:

torch.nn.Module

dispatch_model(*args, **kwargs) None[source]#

Dispatch ._model to have real parameters using ._load(…).
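
A minimal sketch of deferred loading, assuming the gpt2 example from above; dispatch defaults to False, so real parameters are only loaded once dispatch_model() is called or the model is first run:

from nnsight import LanguageModel

model = LanguageModel('gpt2', dispatch=False)  # underlying model not yet loaded with real parameters
model.dispatch_model()                         # force loading of the real parameters now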

interleave(fn: Callable, intervention_graph: Graph, *inputs: List[Any], **kwargs) Any[source]#

Runs the given function on the given inputs, interleaving the intervention graph within the appropriate contexts for this model.

Loads and dispatches ._model if that has not already been done.

Re-compiles the Graph with ._model to prepare for a new execution of the graph.

Runs ._prepare_inputs(…) one last time to get total_batch_size.

Handles adding and removing hooks on modules via HookHandler and tracking the number of times a module has been called via InterventionHandler.

After execution, garbage collects and clears CUDA memory.

Parameters:
  • fn (Callable) – Function or method to run.

  • intervention_graph (Graph) – Intervention graph to interleave with model’s computation graph.

  • inputs (List[Any]) – Inputs to give to function.

Returns:

Output of model.

Return type:

Any

proxy_class#

alias of InterventionProxy

to(*args, **kwargs) Self[source]#

Override torch.nn.Module.to so this returns the NNsight model, not the underlying module, when doing: model = model.to(…)

Returns:

Envoy.

Return type:

Envoy
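
A minimal sketch of the chained form, assuming net is some torch module you have wrapped (a hypothetical name) and a GPU is available:

model = NNsight(net).to('cuda:0')  # returns the NNsight wrapper itself, not the underlying module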

trace(*inputs: Any, trace: bool = True, invoker_args: Dict[str, Any] | None = None, scan: bool = True, **kwargs: Dict[str, Any]) Runner | Any[source]#

Entrypoint into the tracing and interleaving functionality nnsight provides.

In short, allows access to the future inputs and outputs of modules in order to trace what operations you would like to perform on them. This can be as simple as accessing and saving activations for inspection, or as complicated as transforming the activations and gradients in a forward pass over multiple inputs.

Parameters:
  • inputs (tuple[Any])

  • trace (bool, optional) – Whether to open a tracing context; otherwise immediately run the model and return the raw output. Defaults to True.

  • invoker_args (Dict[str, Any], optional) – Keyword arguments to pass to Invoker initialization, and then downstream to the model’s .prepare_inputs(…) method. Used when giving input directly to .trace(…). Defaults to None.

  • kwargs (Dict[str, Any]) – Keyword arguments passed to Runner/Tracer initialization, and then downstream to the model’s ._execute(…) method.

Raises:

ValueError – If trace is False and no inputs were provided (nothing to run with).

Returns:

Either the Runner used for tracing, or the raw output if trace is False.

Return type:

Union[Runner, Any]

Examples

There are a few ways you can use .trace(...) depending on your use case.

Let's use this extremely basic model for our examples:

import torch
from collections import OrderedDict

input_size = 5
hidden_dims = 10
output_size = 2

model = torch.nn.Sequential(OrderedDict([
    ('layer1', torch.nn.Linear(input_size, hidden_dims)),
    ('sigma1', torch.nn.Sigmoid()),
    ('layer2', torch.nn.Linear(hidden_dims, output_size)),
    ('sigma2', torch.nn.Sigmoid()),
]))

example_input = torch.rand((1, input_size))

The first example has us running the model with a single example input, and saving the input and output of ‘layer2’ as well as the final output using the tracing context.

from nnsight import NNsight

with NNsight(model).trace(example_input) as model:

    l2_input = model.layer2.input.save()
    l2_output = model.layer2.output.save()

    output = model.output.save()

print(l2_input)
print(l2_output)
print(output)

The second example allows us to divide multiple inputs into one batch, scoping an inner invoker context to each one. We indicate this simply by not passing any positional inputs into .trace(…); the Tracer object then expects you to enter each input via Tracer.invoke(…).

example_input2 = torch.rand((1, input_size))

with NNsight(model).trace() as model:

    with model.invoke(example_input):

        output1 = model.output.save()

    with model.invoke(example_input2):

        output2 = model.output.save()

print(output1)
print(output2)
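
A third, minimal sketch: passing trace=False skips the tracing context entirely, so the model runs immediately on the given input and the raw output is returned (here model is again the plain torch module defined above):

raw_output = NNsight(model).trace(example_input, trace=False)
print(raw_output)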


class nnsight.models.LanguageModel.LanguageModel(model_key: str | ~torch.nn.modules.module.Module, *args, tokenizer: ~transformers.tokenization_utils.PreTrainedTokenizer | None = None, automodel: ~typing.Type[~transformers.models.auto.modeling_auto.AutoModel] = <class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>, **kwargs)[source]#

LanguageModels are NNsight wrappers around transformers language models.

Inputs can be in the form of:

Prompt: (str)
Prompts: (List[str])
Batched prompts: (List[List[str]])
Tokenized prompt: (Union[List[int], torch.Tensor])
Tokenized prompts: (Union[List[List[int]], torch.Tensor])
Direct input: (Dict[str,Any])

If using a custom model, you also need to provide the tokenizer, e.g. LanguageModel(custom_model, tokenizer=tokenizer).

Calls to generate pass arguments downstream to GenerationMixin.generate().
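
A minimal sketch of a few of these input forms, assuming the gpt2 model loaded at the top of this page:

# Single prompt.
with model.trace("The Eiffel Tower is in the city of"):
    logits = model.lm_head.output.save()

# Multiple prompts, batched together in a single invoke.
with model.trace(["Hello world", "Goodbye world"]):
    batched_logits = model.lm_head.output.save()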

config#

Huggingface config file loaded from repository or checkpoint.

Type:

PretrainedConfig

tokenizer#

Tokenizer for LMs.

Type:

PreTrainedTokenizer

automodel#

AutoModel type from transformer auto models.

Type:

Type

model#

Meta version of underlying auto model.

Type:

PreTrainedModel

proxy_class#

alias of LanguageModelProxy

class nnsight.models.LanguageModel.LanguageModelProxy(node: Node)[source]#

Indexing hidden states by token can easily be done using .token[<idx>] or .t[<idx>]

with runner.invoke('The Eiffel Tower is in the city of') as invoker:
    logits = model.lm_head.output.t[0].save()

print(logits.value)

This would save only the first token of the output for this module. Token indexing should be used when running multiple invokes, as batching and padding multiple inputs can shift token indices around; token indexing takes care of that.
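
A minimal sketch with two invokes of different lengths in one trace (module names follow the gpt2 example above); .t[0] refers to the first real token of each prompt even though the shorter prompt is padded on the left:

with model.trace() as tracer:

    with tracer.invoke('a short prompt'):
        first_a = model.transformer.h[0].mlp.output.t[0].save()

    with tracer.invoke('a noticeably longer prompt with several more tokens'):
        first_b = model.transformer.h[0].mlp.output.t[0].save()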


property t: TokenIndexer#

Property as alias for InterventionProxy.token

property token: TokenIndexer#

Property used to do token based indexing on a proxy. Directly indexes the second dimension of tensors. Makes positive indices negative as tokens are padded on the left.

Example

model.transformer.h[0].mlp.output.token[0]

Is equivalent to:

model.transformer.h[0].mlp.output[:, -3]

For a proxy tensor with 3 tokens.

Returns:

Object to do token based indexing.

Return type:

TokenIndexer

class nnsight.models.LanguageModel.TokenIndexer(proxy: InterventionProxy)[source]#

Helper class to directly access token indices of hidden states. Directly indexes the second dimension of tensors. Makes positive indices negative as tokens are padded on the left.

Parameters:

proxy (InterventionProxy) – Proxy to aid in token indexing.