from __future__ import annotations
import inspect
import os
import warnings
from functools import wraps
from types import BuiltinFunctionType, BuiltinMethodType, FunctionType, MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
import torch
from torch.nn.modules.module import _addindent
from .. import CONFIG, base_deprecation_message, deprecated, util
from ..util import apply, Patch
from .batching import Batchable
from .inject import convert as inject
from .tracing.base import Tracer, WithBlockNotFoundError
from .tracing.editing import EditingTracer
from .tracing.globals import Object
from .tracing.iterator import IteratorProxy
from .tracing.tracer import InterleavingTracer, ScanningTracer
from .interleaver import Interleaver
def trace_only(fn: Callable):
    """Decorator restricting *fn* to be called only while a trace is active.

    Raises:
        ValueError: If the Envoy has no interleaver (i.e. not within a trace).
    """

    @wraps(fn)
    def guarded(self: Envoy, *args, **kwargs):
        if self._interleaver is not None:
            return fn(self, *args, **kwargs)
        raise ValueError(f"Must be within a trace to use `.{fn.__name__}(...)`")

    return guarded
[docs]
class Envoy(Batchable):
"""
A proxy class that wraps a PyTorch module to enable intervention during execution.
This class provides access to module inputs and outputs during forward passes,
and allows for modification of these values through an interleaving mechanism.
It serves as the primary interface for inspecting and modifying the behavior
of neural network modules during execution.
Attributes:
path (str): The module's location in the model hierarchy.
Example: "model.encoder.layer1" indicates this module is the first layer of the encoder in the model.
_module (torch.nn.Module): The underlying PyTorch module
_source (Optional[EnvoySource]): Source code representation of the module
_interleaver (Optional[Interleaver]): Interleaver for managing execution flow
_default_mediators (List[List[str]]): List of default mediators created with .edit
_children (List[Envoy]): List of child Envoys
_alias (Aliaser): Aliaser object for managing aliases
"""
def __init__(
    self,
    module: torch.nn.Module,
    interleaver: Optional[Interleaver] = None,
    path: Optional[str] = "model",
    rename: Optional[Dict[str, Union[str, List[str]]]] = None,
) -> None:
    """
    Initialize an Envoy for a PyTorch module.

    Args:
        module (torch.nn.Module): The PyTorch module to wrap.
        interleaver (Optional[Interleaver]): Optional interleaver for managing execution flow.
            A fresh Interleaver is created when omitted.
        path (Optional[str]): Optional path string representing the module's location in the model hierarchy.
        rename (Optional[Dict[str, Union[str, List[str]]]]): Optional dictionary mapping module names to alias names.
            Example: {"layer1": "first_layer", "layer2": "second_layer"}
            Example: {".model.layers": ".layers"} <-- Mounts .layers to the root model.
            Example: {".transformer": ["model", "mdl"]} <-- Allows access of .transformer as .model or .mdl
    """
    self.path = path
    self._module = module
    # Record the hierarchy path on the module itself so injected/wrapped
    # operations can identify it later.
    self._module.__path__ = path
    self._source = None
    self._interleaver = interleaver if interleaver is not None else Interleaver()
    self._interleaver.wrap_module(module)
    self._default_mediators: List[List[str]] = []
    self._children: List[Envoy] = []
    # inspect._empty is the sentinel for "no fake value scanned in yet".
    self._fake_inputs = inspect._empty
    self._fake_output = inspect._empty
    self._alias: Optional[Aliaser] = Aliaser(rename) if rename is not None else None
    # FIX: use distinct loop variables so the `module` parameter is not
    # shadowed by the iteration over its children.
    for child_name, child_module in list(self._module.named_children()):
        setattr(self, child_name, child_module)
    if self._alias is not None:
        self._alias.build(self)
def __getitem__(self, key: str) -> Envoy:
"""
Access a child Envoy by index for Module Lists.
Args:
key: The index of the child Envoy to retrieve
Returns:
The child Envoy at the specified index
"""
return self._children[key]
@property
def interleaving(self) -> bool:
    """
    Check if the Envoy is currently interleaving.

    Returns:
        True if the Envoy is interleaving, False otherwise
    """
    return self._interleaver is not None and self._interleaver.interleaving
#### Properties ####
@property
def output(self) -> Object:
    """
    Get the output of the module's forward pass.

    This property allows access to the return values produced by the module
    during the forward pass.

    Example:
        >>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
        >>> with model.trace("Hello World"):
        >>>     attn = model.transformer.h[0].attn.output[0].save()
        >>> print(attn)

    Returns:
        The module's output values

    Raises:
        ValueError: If not interleaving and no fake output has been scanned in.
    """
    if self.interleaving:
        # Request the value for this module's (possibly iteration-qualified)
        # output key from the current mediator.
        return self._interleaver.current.request(
            self._interleaver.current.iterate(f"{self.path}.output")
        )
    elif self._fake_output is not inspect._empty:
        # Fall back to the fake output recorded by a previous .scan() pass.
        return self._fake_output
    else:
        raise ValueError(
            "Cannot return output of Envoy that is not interleaving nor has a fake output set."
        )
@output.setter
def output(self, value: Any):
    """
    Replace the module's output with a custom value during interleaving.

    Example:
        >>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
        >>> with model.trace("Hello World"):
        >>>     model.transformer.h[0].attn.output[0] *= 2

    Args:
        value: The new output value to use.

    Raises:
        ValueError: If the Envoy is not currently interleaving.
    """
    if not self.interleaving:
        raise ValueError("Cannot set output of Envoy that is not interleaving.")
    mediator = self._interleaver.current
    mediator.swap(mediator.iterate(f"{self.path}.output"), value)
@property
def inputs(self) -> Tuple[Tuple[Object], Dict[str, Object]]:
    """
    Get the inputs to the module's forward pass.

    This property provides access to all input values passed to the module
    during the forward pass.

    Example:
        >>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
        >>> with model.trace("Hello World"):
        >>>     args, kwargs = model.transformer.h[0].attn.inputs

    Returns:
        The module's input values as a tuple of positional and keyword arguments. i.e (args, kwargs)

    Raises:
        ValueError: If not interleaving and no fake inputs have been scanned in.
    """
    if self.interleaving:
        # Request the value for this module's (possibly iteration-qualified)
        # input key from the current mediator.
        return self._interleaver.current.request(
            self._interleaver.current.iterate(f"{self.path}.input")
        )
    elif self._fake_inputs is not inspect._empty:
        # Fall back to the fake inputs recorded by a previous .scan() pass.
        return self._fake_inputs
    else:
        raise ValueError(
            "Cannot return inputs of Envoy that is not interleaving nor has a fake inputs set."
        )
@inputs.setter
def inputs(self, value: Any):
    """
    Replace the module's inputs with custom values during interleaving.

    Example:
        >>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
        >>> with model.trace("Hello World"):
        >>>     model.transformer.h[0].attn.inputs = (torch.randn(1, 1024, 1024), {})

    Args:
        value: The new input value(s) to use, structured as a tuple of (args, kwargs)

    Raises:
        ValueError: If the Envoy is not currently interleaving.
    """
    if not self.interleaving:
        raise ValueError("Cannot set inputs of Envoy that is not interleaving.")
    mediator = self._interleaver.current
    mediator.swap(mediator.iterate(f"{self.path}.input"), value)
@property
def input(self) -> Object:
    """
    Get the first input to the module's forward pass.

    Convenience accessor returning just the first value passed to the module:
    the first positional argument, or the first keyword argument if there are
    no positional arguments.

    Example:
        >>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
        >>> with model.trace("Hello World"):
        >>>     hidden_states = model.transformer.h[0].attn.input.save()
        >>>     print(hidden_states)

    Returns:
        The first input value
    """
    args, kwargs = self.inputs
    flattened = [*args, *kwargs.values()]
    return flattened[0]
@input.setter
def input(self, value: Any):
    """
    Set a new value for the module's first positional input, preserving all
    other inputs.

    Example:
        >>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
        >>> with model.trace("Hello World"):
        >>>     model.transformer.h[0].attn.input = torch.randn(1, 1024, 1024)

    Args:
        value: The new value for the first input
    """
    args, kwargs = self.inputs
    replaced = (value, *tuple(args)[1:])
    self.inputs = (replaced, kwargs)
@property
def source(self) -> EnvoySource:
"""
Get the source code representation of the module.
This property provides access to the module's source code with operations
highlighted, allowing for inspection and intervention at specific points.
Example:
>>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
>>> # We can print to see the formward method of the module and names associated with the operations within.
>>> print(model.transformer.h[0].attn.source)
60
61 if using_eager and self.reorder_and_upcast_attn:
self__upcast_and_reordered_attn_0 -> 62 attn_output, attn_weights = self._upcast_and_reordered_attn(
63 query_states, key_states, value_states, attention_mask, head_mask
64 )
65 else:
attention_interface_0 -> 66 attn_output, attn_weights = attention_interface(
67 self,
68 query_states,
69 key_states,
70 value_states,
71 attention_mask,
72 head_mask=head_mask,
73 dropout=self.attn_dropout.p if self.training else 0.0,
74 is_causal=is_causal,
75 **kwargs,
76 )
77
attn_output_reshape_0 -> 78 attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
contiguous_0 -> + ...
self_c_proj_0 -> 79 attn_output = self.c_proj(attn_output)
self_resid_dropout_0 -> 80 attn_output = self.resid_dropout(attn_output)
81
82 return attn_output, attn_weights
83
>>> # We can print out one of these to see the only the operation and a few operations before and after.
>>> print(model.transformer.h[0].attn.source.attention_interface_0)
.transformer.h.0.attn.attention_interface_0:
....
if using_eager and self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(
query_states, key_states, value_states, attention_mask, head_mask
)
else:
--> attn_output, attn_weights = attention_interface( <--
self,
query_states,
key_states,
value_states,
attention_mask,
head_mask=head_mask,
....
>>> with model.trace("Hello World"):
>>> # Now we can access it like we would any other Envoy with .input or .output to grab the intermediate value.
>>> attn = model.transformer.h[0].attn.source.attention_interface_0.output.save()
>>> print(attn)
Returns:
An EnvoySource object containing the module's source code and operations
"""
if self._source is None:
def wrap(fn: Callable, **kwargs):
bound_obj = (
fn.__self__
if inspect.ismethod(fn) and fn.__name__ != "forward"
else None
)
if self.interleaving:
return self._interleaver.wrap_operation(
fn, **kwargs, bound_obj=bound_obj
)
else:
return fn
source, line_numbers, forward = inject(
self._module.forward, wrap, self._module.__path__
)
self._module.forward = MethodType(forward, self._module)
self._source = EnvoySource(self._module.__path__, source, line_numbers, interleaver=self._interleaver)
return self._source
def __call__(self, *args, hook: bool = False, **kwargs):
return (
self._module.forward(*args, **kwargs)
if self.interleaving and not hook
else self._module(*args, **kwargs)
)
#### Public methods ####
[docs]
def trace(self, *args, fn: Optional[Callable] = None, trace: Optional[bool] = None, tracer_cls: Type[InterleavingTracer] = InterleavingTracer, **kwargs):
    """
    Create a tracer for this module.

    This method returns a tracer that can be used to capture and modify
    the execution of the module.

    Example:
        >>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
        >>> with model.trace("Hello World"):
        >>>     model.transformer.h[0].attn.output[0][:] = 0
        >>>     output = model.output.save()
        >>> print(output)

    Args:
        *args: Arguments to pass to the tracer
        fn: Optional callable to trace instead of this Envoy's __call__.
        trace: Deprecated; accepted only for backwards compatibility.
        tracer_cls: Tracer class to instantiate. Defaults to InterleavingTracer.
        **kwargs: Keyword arguments to pass to the tracer

    Returns:
        An InterleavingTracer for this module
    """
    # TODO trace= is Legacy
    if trace is not None:
        deprecation_message = f"The `trace` argument {base_deprecation_message}\nJust call the method without a with context instead."
        warnings.warn(deprecation_message)
    if fn is None:
        fn = self.__call__
        # hook=True makes __call__ dispatch through the underlying module
        # (see __call__) so torch hooks fire during the trace.
        kwargs["hook"] = True
    return tracer_cls(fn, self, *args, **kwargs)
[docs]
def scan(self, *args, **kwargs):
"""
Just like .trace() but runs the model in fake tensor mode to validate operations and inspect tensor shapes.
This method returns a tracer that runs the model in fake tensor mode to validate operations
and inspect tensor shapes without performing actual computation. This is useful for:
- Validating that operations will work with given input shapes
- Inspecting the shapes and types of tensors that would flow through the model
- Debugging shape mismatches or other tensor-related issues.
Note this will not dispatch the model if not dispatched.
Example:
>>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
>>> # Value error as the fake inputs and outputs have not been scanned in.
>>> print(model.transformer.h[0].mlp.output.shape)
>>> # Scan the model to validate operations and inspect shapes
>>> with model.scan("Hello World"):
>>> # Access fake inputs/outputs to inspect shapes
>>> attn_input = model.transformer.h[0].attn.input.save()
>>> attn_output = model.transformer.h[0].attn.output[0].save()
>>> print(f"Attention input shape: {attn_input.shape}")
>>> print(f"Attention output shape: {attn_output.shape}")
>>> print(model.transformer.h[0].mlp.output.shape)
Args:
*args: Arguments to pass to the tracer
**kwargs: Keyword arguments to pass to the tracer
Returns:
A ScanningTracer for this module
"""
return ScanningTracer(self.__call__, self, *args, hook=True, **kwargs)
[docs]
def edit(self, *, inplace: bool = False):
"""
Create an editing tracer for this module. Allows for setting default interventions.
This means this tracer won't execute the module, but will instead set default interventions that are applied on all future executions.
Edits can be cleared with `Envoy.clear_edits()`.
Example:
>>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
>>> # Now the first layer attention output will always be 0.
>>> with model.edit() as edited_model:
>>> edited_model.transformer.h[0].attn.output[:] = 0
>>> with model.trace("Hello World"):
>>> output = model.output.save()
>>> # The orignal model will have the default output.
>>> print(output)
>>> with edited_model.trace("Hello World"):
>>> edited_output = edited_model.output.save()
>>> # The edited model will have the output after our intervention.
>>> print(edited_output)
Args:
inplace (bool, optional): Whether to edit in place. Defaults to False.
Returns:
(EditingTracer): An EditingTracer for this module
"""
return EditingTracer(self.__call__, self, inplace=inplace)
[docs]
def clear_edits(self):
    """
    Clear all edits for this Envoy.

    Rebinds (rather than mutates) the default-mediator list, so any existing
    references to the old list are unaffected.
    """
    self._default_mediators = []
[docs]
def export_edits(
    self, name: str, export_dir: Optional[str] = None, variant: str = "__default__"
):
    """
    Serialize this Envoy's default mediators (created via `.edit()`) to disk.

    The mediators are written to `<export_dir>/<name>/<variant>.dill`. When
    no directory is given, `<CONFIG.APP.CACHE_DIR>/exports` is used.

    Args:
        name (str): Name of the export; used as the subdirectory name.
        export_dir (Optional[str], optional): Base directory to export into.
            Defaults to `<CONFIG.APP.CACHE_DIR>/exports`.
        variant (str, optional): Filename (without extension) for this set of
            edits. Defaults to '__default__'.

    Raises:
        ValueError: If no edits have been made via `.edit()`.
    """
    if len(self._default_mediators) == 0:
        raise ValueError("Cannot export an Envoy before calling .edit().")
    if export_dir is None:
        export_dir = os.path.join(CONFIG.APP.CACHE_DIR, "exports")
    export_dir = os.path.expanduser(os.path.join(export_dir, name))
    os.makedirs(export_dir, exist_ok=True)
    # Imported lazily — presumably to avoid a circular import at module load
    # time; confirm against the package layout.
    from . import serialization
    serialization.save(
        self._default_mediators, os.path.join(export_dir, f"{variant}.dill")
    )
[docs]
def import_edits(
    self, name: str, export_dir: Optional[str] = None, variant: str = "__default__"
):
    """
    Load previously exported default mediators and append them to this Envoy.

    Reads `<export_dir>/<name>/<variant>.dill` (as written by `.export_edits()`)
    and extends this Envoy's default mediators with the result.

    Args:
        name (str): Name of the export to load; the subdirectory name.
        export_dir (Optional[str], optional): Base directory to load from.
            Defaults to `<CONFIG.APP.CACHE_DIR>/exports`.
        variant (str, optional): Filename (without extension) of the edits to
            load. Defaults to '__default__'.
    """
    if export_dir is None:
        export_dir = os.path.join(CONFIG.APP.CACHE_DIR, "exports")
    export_dir = os.path.expanduser(os.path.join(export_dir, name))
    # Imported lazily — presumably to avoid a circular import at module load
    # time; confirm against the package layout.
    from . import serialization
    imported_mediators = serialization.load(
        os.path.join(export_dir, f"{variant}.dill"), self
    )
    self._default_mediators.extend(imported_mediators)
# TODO legacy
def session(self, *args, tracer_cls: Type[Tracer] = Tracer, **kwargs):
    """
    Legacy entry point: create a bare Tracer with this Envoy attached as
    its `.model` attribute.
    """
    tracer = tracer_cls(*args, **kwargs)
    setattr(tracer, "model", self)
    return tracer
@property
@deprecated(message="Use `tracer.iter` instead.")
@trace_only
def iter(self):
    # Deprecated: proxy for indexing into trace iterations.
    return IteratorProxy(self._interleaver)

@deprecated(message="Use `tracer.all()` instead.")
@trace_only
def all(self):
    # Deprecated: full slice of .iter, i.e. every iteration.
    return self.iter[:]

@deprecated(message="Use `tracer.next()` instead.")
@trace_only
def next(self, step: int = 1):
    # Deprecated: advance the current mediator's iteration counter by `step`.
    self._interleaver.current.iteration += step
    return self
[docs]
@trace_only
def skip(self, replacement: Any):
    """Skips the execution of this module during execution / interleaving.

    The module will not be executed and will return a replacement value instead.

    Example:
        >>> model = LanguageModel("gpt2", device_map='auto', dispatch=True)
        >>> with model.trace("Hello World"):
        >>>     # Skip the first layer and replace it with the input to the layer.
        >>>     model.transformer.h[0].skip((model.transformer.h[0].input, None))
        >>>     output = model.output.save()
        >>> print(output)

    Args:
        replacement (Any): The replacement value to replace the module's output with.
    """
    # Skipping is keyed off the module's input event for the current iteration.
    requester = self._interleaver.current.iterate(f"{self.path}.input")
    self._interleaver.current.skip(requester, replacement)
[docs]
@trace_only
def wait_for_output(self):
    """
    Wait for the output of the module to be available.

    Accessing .output for its side effect is sufficient; the value is discarded.
    """
    self.output
[docs]
def to(self, device: torch.device):
    """
    Move the underlying PyTorch module to a specific device.

    Args:
        device: The device to move the module to

    Returns:
        Self, for method chaining
    """
    self._module.to(device)
    return self
[docs]
def cpu(self, *args, **kwargs):
    """
    Move the underlying module to the CPU.

    Returns:
        Self, for method chaining
    """
    self._module.cpu(*args, **kwargs)
    return self
[docs]
def cuda(self, *args, **kwargs):
    """
    Move the underlying module to the GPU.

    Returns:
        Self, for method chaining
    """
    self._module.cuda(*args, **kwargs)
    return self
@property
def device(self) -> Optional[torch.device]:
    """
    Get the device the module is on.

    Finds the first parameter of the underlying module and returns its device.

    Returns:
        The device of the first parameter, or None if the module has no parameters.
    """
    # FIX: narrowed from a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit. An empty parameter iterator raises
    # StopIteration from next().
    try:
        return next(self._module.parameters()).device
    except StopIteration:
        return None
[docs]
def modules(
    self,
    include_fn: Optional[Callable[[Envoy], bool]] = None,
    names: bool = False,
) -> List[Envoy]:
    """
    Get all modules in the Envoy tree.

    Performs a depth-first, post-order traversal: children are collected
    before this Envoy itself.

    Args:
        include_fn: Optional predicate; when given, only Envoys for which it
            returns True are included.
        names: Whether to return (path, Envoy) tuples instead of bare Envoys.

    Returns:
        A list of Envoys or (name, Envoy) tuples
    """
    result = []
    for envoy in self._children:
        result.extend(envoy.modules(include_fn=include_fn, names=names))
    if include_fn is None or include_fn(self):
        if names:
            result.append((self.path, self))
        else:
            result.append(self)
    return result
[docs]
def named_modules(self, *args, **kwargs) -> List[Tuple[str, Envoy]]:
    """
    Returns all Envoys in the Envoy tree along with their name/module_path.

    This is a convenience method that calls modules() with names=True.

    Args:
        include_fn (Callable, optional): Optional function to be ran against all Envoys to check if they should be included in the final collection of Envoys. Defaults to None.
        *args, **kwargs: Additional arguments to pass to modules()

    Returns:
        List[Tuple[str, Envoy]]: Included Envoys and their names/module_paths.
    """
    return self.modules(*args, **kwargs, names=True)
[docs]
def get(self, path: str) -> Object:
    """Gets the Envoy/Proxy via its path.

    Example:
        model = nnsight.LanguageModel("openai-community/gpt2")
        module = model.get('transformer.h.0.mlp')
        with model.trace("Hello"):
            value = model.get('transformer.h.0.mlp.output').save()

    Args:
        path (str): '.' separated path.

    Returns:
        Union[Envoy, InterventionProxyType]: Fetched Envoy/Proxy
    """
    return util.fetch_attr(self, path)
def interleave(self, fn: Union[Callable, str], *args, **kwargs):
    """
    Execute `fn` under this Envoy's interleaver so registered interventions run.

    Tensor arguments are first moved to this Envoy's device. `fn` may also be
    given as the name of an attribute on this Envoy, resolved before calling.

    Args:
        fn (Union[Callable, str]): Callable (or attribute name) to execute.
        *args: Positional arguments for `fn`; tensors are moved to `self.device`.
        **kwargs: Keyword arguments for `fn`; tensors are moved to `self.device`.
    """
    device = self.device
    (args, kwargs) = apply(
        (args, kwargs), lambda tensor: tensor.to(device), torch.Tensor
    )
    if isinstance(fn, str):
        fn = getattr(self, fn)
    try:
        with self._interleaver:
            fn(*args, **kwargs)
            # NOTE(review): presumably validates that all cached values and
            # mediators were consumed during the run — confirm against
            # Interleaver's implementation.
            self._interleaver.check_cache_full()
            self._interleaver.check_dangling_mediators()
    finally:
        # Always tear down interleaver state, even if fn raised.
        self._interleaver.cancel()
#### Private methods ####
def _add_envoy(self, module: torch.nn.Module, name: str) -> Envoy:
    """
    Adds a new Envoy for a given torch module under this Envoy.

    This method creates a new Envoy for a child module and adds it to
    this Envoy's children.

    Args:
        module: Module to create Envoy for.
        name: Name of envoy/attribute.

    Returns:
        The newly created child Envoy.
    """
    module_path = f"{self.path}.{name}"
    envoy = Envoy(
        module,
        path=module_path,
        rename=self._alias.rename if self._alias is not None else None,
        interleaver=self._interleaver
    )
    self._children.append(envoy)
    setattr(self._module, name, module)
    # If the name collides with an attribute the Envoy class already defines
    # (e.g. 'input' or 'output'), mount the proxy access at `nns_<name>` instead.
    if hasattr(Envoy, name):
        self._handle_overloaded_mount(envoy, name)
    else:
        super().__setattr__(name, envoy)
    return envoy
def _handle_overloaded_mount(self, envoy: Envoy, mount_point: str) -> None:
    """If a given module already has an attribute of the same name as something nnsight wants to add, we need to rename it.

    Directly edits the underlying class to accomplish this.

    Args:
        envoy (Envoy): Envoy to handle.
        mount_point (str): Overloaded attribute name.
    """
    warnings.warn(
        f"Module `{self.path}` of type `{type(self._module)}` has pre-defined a `{mount_point}` attribute. nnsight access for `{mount_point}` will be mounted at `.nns_{mount_point}` instead of `.{mount_point}` for this module only."
    )
    # If we already shifted a mount point dont create another new class.
    if "Preserved" in self.__class__.__name__:
        new_cls = self.__class__
    else:
        # Subclass on the fly so the shadowing affects only this instance.
        new_cls = type(
            f"{self.__class__.__name__}.Preserved",
            (self.__class__,),
            {},
        )
    self.__class__ = new_cls
    # Get the normal proxy mount point
    mount = getattr(Envoy, mount_point)
    # Re-expose nnsight's original accessor under the `nns_` prefix.
    setattr(new_cls, f"nns_{mount_point}", mount)
    if isinstance(mount, property):
        # Shadow the property's getter so plain access returns the child Envoy
        # stored in the instance dict, keeping the original setter/deleter.
        mount = property(
            lambda slf: slf.__dict__[mount_point],
            mount.fset,
            mount.fdel,
            mount.__doc__,
        )
    setattr(new_cls, mount_point, mount)
    # Move it to nns_<mount point>
    self.__dict__[mount_point] = envoy
def _update(self, module: torch.nn.Module) -> None:
    """Updates the ._module attribute using a new model of the same architecture.

    Used when loading the real weights (dispatching) and need to replace the underlying modules.
    """
    i = 0
    # Recursively update children in positional order.
    for i, child in enumerate(module.children()):
        self._children[i]._update(child)
    # Handle extra modules added after initialization: issues/376
    # NOTE(review): if `module` has no children, `i` stays 0 and the slice
    # below starts at index 1, skipping the first named child — confirm this
    # edge case is intended.
    for name, child in list(self._module.named_children())[i + 1 :]:
        setattr(module, name, child)
    self._module = module
    self._module.__path__ = self.path
    self._interleaver.wrap_module(module)
    if self._source is not None:

        def wrap(fn: Callable, **kwargs):
            # Only bind non-forward methods to their instance; forward is
            # handled by the module-level wrapping.
            bound_obj = (
                fn.__self__
                if inspect.ismethod(fn) and fn.__name__ != "forward"
                else None
            )
            if self.interleaving:
                return self._interleaver.wrap_operation(
                    fn, **kwargs, bound_obj=bound_obj
                )
            else:
                return fn

        # Re-inject the instrumented forward onto the new module, since the
        # old ._source instrumentation pointed at the replaced module.
        source, line_numbers, forward = inject(
            self._module.forward, wrap, self._module.__path__
        )
        self._module.forward = MethodType(forward, self._module)
def _update_alias(self, alias: Dict[str, str]):
    """
    Rebuild aliasing for this Envoy and propagate the mapping to all children.
    """
    if alias is not None:
        self._alias = Aliaser(alias)
        self._alias.build(self)
    for child in self._children:
        child._update_alias(alias)
def _shallow_copy(self) -> Envoy:
    """Create a new instance of the same class sharing all of this
    instance's attributes (shallow: attribute values are not copied).

    Returns:
        Self: NNsightModel
    """
    clone = self.__class__.__new__(self.__class__)
    clone.__dict__.update(self.__dict__)
    return clone
#### Dunder methods ####
def __len__(self):
    """
    Length of the underlying module (e.g. number of sub-modules in a ModuleList).
    """
    return len(self._module)

def __iter__(self):
    """
    Iterate over this Envoy's child Envoys.
    """
    return iter(self._children)

def __str__(self):
    """
    String representation of the Envoy.

    Returns:
        A string representation of the Envoy showing its path
    """
    return self.__repr__()
def __reprlist__(self):
    """
    Repr for ModuleList-backed Envoys: collapses consecutive children with
    identical reprs into `(start-end): N x <repr>` ranges, mirroring
    torch.nn.ModuleList's repr.
    """
    list_of_reprs = [repr(item) for item in self]
    if len(list_of_reprs) == 0:
        return self._module._get_name() + "()"
    start_end_indices = [[0, 0]]
    repeated_blocks = [list_of_reprs[0]]
    for i, r in enumerate(list_of_reprs[1:], 1):
        # Extend the current run while the child repr is unchanged.
        if r == repeated_blocks[-1]:
            start_end_indices[-1][1] += 1
            continue
        # Otherwise start a new run at index i.
        start_end_indices.append([i, i])
        repeated_blocks.append(r)
    lines = []
    main_str = self._module._get_name() + "("
    for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
        local_repr = f"({start_id}): {b}"  # default repr
        if start_id != end_id:
            n = end_id - start_id + 1
            local_repr = f"({start_id}-{end_id}): {n} x {b}"
        local_repr = _addindent(local_repr, 2)
        lines.append(local_repr)
    main_str += "\n " + "\n ".join(lines) + "\n"
    main_str += ")"
    return main_str
def __repr__(self):
    """
    Representation of the Envoy.

    Mirrors torch.nn.Module's repr, with aliased child names shown as
    `alias/original`.

    Returns:
        The string representation of the Envoy
    """
    if isinstance(self._module, torch.nn.ModuleList):
        return self.__reprlist__()
    # We treat the extra repr like the sub-module, one item per line
    extra_lines = []
    extra_repr = self._module.extra_repr()
    # empty string will be split into list ['']
    if extra_repr:
        extra_lines = extra_repr.split("\n")
    child_lines = []
    for envoy in self._children:
        key = envoy.path.split(".")[-1]
        mod_str = repr(envoy)
        mod_str = _addindent(mod_str, 2)
        # Prefix the child's aliases onto its display name.
        if self._alias is not None and key in self._alias.name_to_aliases:
            key = "/".join([*self._alias.name_to_aliases[key], key])
        child_lines.append("(" + key + "): " + mod_str)
    if self._alias is not None:
        # Also render "extra" aliases that mount deeper paths onto this Envoy.
        for extra in self._alias.extras:
            key = "/".join(self._alias.name_to_aliases[extra])
            envoy = self.get(extra)
            mod_str = repr(envoy)
            mod_str = _addindent(mod_str, 2)
            child_lines.append("(" + key + "): " + mod_str)
    lines = extra_lines + child_lines
    main_str = self._module._get_name() + "("
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += "\n " + "\n ".join(lines) + "\n"
    main_str += ")"
    return main_str
def __getattr__(self, name: str) -> Union[torch.nn.Module, Envoy, Any]:
    """
    Get an attribute from the underlying module.

    Aliases are resolved first. Bound methods of the module are wrapped so
    calling them starts a trace; child modules are wrapped in Envoys.

    Args:
        name: The name of the attribute to get

    Returns:
        The attribute value, possibly wrapped in a tracer or Envoy

    Raises:
        AttributeError: If the attribute doesn't exist
    """
    if self._alias is not None and name in self._alias.alias_to_name:
        return util.fetch_attr(self, self._alias.alias_to_name[name])
    if hasattr(self._module, name):
        value = getattr(self._module, name)
        # It's a method bound to the module, create an interleaver for it
        if isinstance(
            value,
            (FunctionType, MethodType, BuiltinFunctionType, BuiltinMethodType),
        ):
            # If the Envoy defines a method with __nnsight_{name}__, use it instead to override
            value = getattr(self, f"__nnsight_{name}__", value)

            def trace(*args, **kwargs):
                try:
                    return self.trace(*args, fn=value, **kwargs)
                # FIX: dropped the unused `as e` binding.
                except WithBlockNotFoundError:
                    # Not used as a `with` context: fall back to a direct call.
                    return value(*args, **kwargs)

            return trace
        elif isinstance(value, torch.nn.Module):
            # If the _module has a module in its __dict__ but wasn't picked up when creating the Envoy,
            # hopefully it is already an Envoy somewhere in the tree.
            # https://github.com/ndif-team/nnsight/issues/479
            # This happened because some transformers models set this class attr: _checkpoint_conversion_mapping
            if hasattr(value, "__path__"):
                return util.fetch_attr(self, value.__path__[len(self.path) :])
            return self._add_envoy(value, name)
        else:
            return value
    else:
        raise AttributeError(f"{self} has no attribute {name}")
def __setattr__(self, key: Any, value: Any) -> None:
    """
    Set an attribute on the Envoy.

    If the value is a PyTorch module (and the slot is not the internal
    `_module` itself), it is wrapped in a child Envoy to enable intervention
    during execution.

    Args:
        key: The attribute name
        value: The attribute value
    """
    if key != "_module" and isinstance(value, torch.nn.Module):
        self._add_envoy(value, key)
    else:
        super().__setattr__(key, value)
# TODO extend Envoy
[docs]
class OperationEnvoy:
"""
Represents a specific operation within a module's forward pass.
This class provides access to the inputs and outputs of individual
operations within a module's execution, allowing for fine-grained
inspection and intervention at the operation level.
"""
def __init__(
    self,
    name: str,
    source: str,
    line_number: int,
    interleaver: Optional[Interleaver] = None,
):
    """
    Initialize an OperationEnvoy.

    Args:
        name: The fully qualified name of the operation
        source: The source code of the module containing the operation
        line_number: The line number of the operation in the source
        interleaver: Optional interleaver for managing execution flow
    """
    self.name = name
    self.source_code = source
    self.line_number = line_number
    self._interleaver = interleaver
    # Both are lazily populated by the .source property.
    self._source: Optional[EnvoySource] = None
    self._fn: Optional[Callable] = None
def __str__(self):
    """
    String representation showing the operation in context.

    Returns the operation's source with a few lines of surrounding context
    and the operation line highlighted with arrows.

    Returns:
        A formatted string showing the operation's source code with context
    """
    source_lines = self.source_code.split("\n")
    # Window: up to 5 lines before and 8 after the recorded line number.
    start_idx = max(0, self.line_number - 5)
    end_idx = min(len(source_lines) - 1, self.line_number + 8)
    highlighted_lines = [self.name + ":\n"]
    if start_idx != 0:
        highlighted_lines.append(" ....")
    for i in range(start_idx, end_idx):
        line = source_lines[i]
        # NOTE(review): the `line_number + 1` offset and the `line[4:]`
        # de-indent assume a fixed layout produced by the injection step —
        # confirm against .inject.convert.
        if i == self.line_number + 1:
            highlighted_lines.append(f" --> {line[4:]} <--")
        else:
            highlighted_lines.append(" " + line)
    if end_idx != len(source_lines) - 1:
        highlighted_lines.append(" ....")
    return "\n".join(highlighted_lines)
@property
def output(self) -> Union[Any, torch.Tensor]:
    """
    Get the output of this operation.

    Provides access to the return value(s) produced by the operation during
    execution, requested from the current mediator by key.

    Returns:
        The operation's output value(s)
    """
    return self._interleaver.current.request(f"{self.name}.output")

@output.setter
def output(self, value: Any) -> None:
    """
    Set a new value for the operation's output.

    This allows for intervention by replacing the operation's output with
    a custom value during execution.

    Args:
        value: The new output value
    """
    self._interleaver.current.swap(f"{self.name}.output", value)
@property
def inputs(
    self,
) -> Tuple[Tuple[Any, torch.Tensor], Dict[str, Union[torch.Tensor, Any]]]:
    """
    Get the inputs to this operation.

    Provides access to all input value(s) passed to the operation during
    execution, structured as a tuple of positional and keyword arguments.

    Returns:
        The operation's input value(s)
    """
    return self._interleaver.current.request(f"{self.name}.input")

@inputs.setter
def inputs(self, value: Any) -> None:
    """
    Set new values for the operation's inputs.

    This allows for intervention by replacing the operation's inputs with
    custom values during execution.

    Args:
        value: The new input value(s)
    """
    self._interleaver.current.swap(f"{self.name}.input", value)

@inputs.deleter
def inputs(self):
    """
    Clear the cached input value.

    NOTE(review): `_input` is not initialized anywhere in this class, so this
    deleter appears to be legacy — verify it is still reachable/needed.
    """
    self._input = None
@property
def input(self) -> Union[Any, torch.Tensor]:
    """
    Get the first input to the operation.

    Convenience accessor returning just the first of all values passed to
    the operation.

    Returns:
        The first input value
    """
    args, kwargs = self.inputs
    flattened = [*args, *kwargs.values()]
    return flattened[0]

@input.setter
def input(self, value: Any) -> None:
    """
    Replace just the operation's first positional input, preserving all
    other inputs.

    Args:
        value: The new value for the first input
    """
    args, kwargs = self.inputs
    self.inputs = ((value, *tuple(args)[1:]), kwargs)
@property
def source(self) -> EnvoySource:
    """
    Get the source code of the operation.

    This property provides access to the operation's source code with nested
    operations highlighted, allowing for inspection and intervention at specific points.

    Returns:
        An EnvoySource object containing the operation's source code and nested operations

    Raises:
        ValueError: If the resolved callable is itself a Module; call
            `.source` on that module's own Envoy instead.
    """
    if self._source is None:
        # Resolve the callable actually executed for this operation.
        fn = self._interleaver.current.request(f"{self.name}.fn")
        # TODO maybe do something else here
        if isinstance(fn, torch.nn.Module):
            msg = f"Don't call .source on a module ({getattr(fn, '__path__', '')}) from within another .source. Call it directly with: {getattr(fn, '__path__', '')}.source"
            raise ValueError(msg)

        def wrap(fn: Callable, **kwargs):
            # Only bind non-forward methods to their instance.
            bound_obj = (
                fn.__self__
                if getattr(fn, "__name__", None) != "forward"
                and inspect.ismethod(fn)
                else None
            )
            return self._interleaver.wrap_operation(
                fn, **kwargs, bound_obj=bound_obj
            )

        # Inject instrumentation into the operation's own source.
        source, line_numbers, fn = inject(fn, wrap, self.name)
        self._source = EnvoySource(
            self.name, source, line_numbers, interleaver=self._interleaver
        )
        self._fn = fn
    # Swap in the instrumented fn unless a swap is already recorded.
    if f"{self.name}.fn" not in self._interleaver.current.history:
        self._interleaver.current.swap(f"{self.name}.fn", self._fn)
    return self._source
# @input.setter
# def input(self, value: Any):
# #TODO would need await...
# inputs = self._input
# self._input = ((value, *inputs[0]), inputs[1])
# self.interleaver.set_swap(self._input, (self.module, self.name), Events.INPUT)
# @input.deleter
# def input(self):
# self._input = None
[docs]
class EnvoySource:
    """
    Represents the source code of a module with operations highlighted.

    This class provides access to the individual operations within a module's
    source code, allowing for inspection and intervention at specific points
    in the code. It serves as a bridge between the source code representation
    and the runtime execution of operations.
    """
    def __init__(
        self,
        name: str,
        source: str,
        line_numbers: dict,
        interleaver: Optional[Interleaver] = None,
    ):
        """
        Initialize an EnvoySource.

        Builds one :class:`OperationEnvoy` per entry in ``line_numbers`` and
        exposes each both as an attribute named after the operation and via
        the ``operations`` list.

        Args:
            name: The fully qualified name of the module or operation
            source: The source code string
            line_numbers: A dictionary mapping operation names to line numbers
            interleaver: Optional interleaver for managing execution flow
        """
        self.source = source
        self.line_numbers = line_numbers
        self.operations: List[OperationEnvoy] = []
        for _name, line_number in line_numbers.items():
            operation = OperationEnvoy(
                f"{name}.{_name}", source, line_number, interleaver=interleaver
            )
            # NOTE(review): setattr would silently shadow `source`,
            # `line_numbers`, or `operations` if an operation shares one of
            # those names — confirm `inject` cannot produce such names.
            setattr(self, _name, operation)
            self.operations.append(operation)
    def __str__(self):
        """
        String representation showing the source code with operations highlighted.

        This method returns a formatted string showing the source code with
        operation names and line numbers, making it easy to identify
        intervention points. Operation-bearing lines are prefixed with the
        operation name and line number; extra operations on the same line are
        listed on continuation rows.

        Returns:
            A formatted string showing the source code with operation names
            and line numbers
        """
        # Find the longest name for proper alignment
        max_name_length = (
            max(len(name) for name in self.line_numbers.keys())
            if self.line_numbers
            else 0
        )
        source_lines = self.source.split("\n")
        formatted_lines = [
            " " * (max_name_length + 6) + "* " + source_lines[0]
        ]  # Keep the function definition unchanged
        # Group operations by line number
        operations_by_line = {}
        for name, line_number in self.line_numbers.items():
            if line_number not in operations_by_line:
                operations_by_line[line_number] = []
            operations_by_line[line_number].append(name)
        # NOTE(review): `i` is 0-based relative to the line AFTER the
        # definition line — assumes `line_numbers` uses that same convention;
        # confirm against `inject`.
        for i, line in enumerate(source_lines[1:]):
            line_number = i
            # Check if this line has operations
            if line_number in operations_by_line:
                # Handle multiple operations on the same line
                operations = operations_by_line[line_number]
                # First operation gets the line number
                first_op = operations[0]
                line_prefix = f"  {first_op:{max_name_length}} ->{line_number:3d}  "
                formatted_lines.append(f"{line_prefix}{line}")
                # For nested operations, unwrap them onto separate lines
                if len(operations) > 1:
                    for op in operations[1:]:
                        continuation_prefix = f"  {op:{max_name_length}} ->  +  "
                        # Instead of just showing a vertical line, show the operation on its own line
                        formatted_lines.append(
                            f"{continuation_prefix}{' ' * (len(line) - len(line.lstrip()))}..."
                        )
            else:
                # Regular line with no operations
                line_prefix = " " * (max_name_length + 4) + f"{line_number:3d}  "
                formatted_lines.append(f"{line_prefix}{line}")
        source = "\n".join(formatted_lines)
        return source
    def __getattribute__(self, name: str) -> "OperationEnvoy":
        # Plain pass-through to the default lookup; kept (rather than removed)
        # so the annotated return type documents that attribute access yields
        # OperationEnvoy objects for operation names.
        return object.__getattribute__(self, name)
class Aliaser:
def __init__(self, rename: Dict[str, Union[str, List[str]]]):
    """
    Create an Aliaser from a rename mapping.

    Args:
        rename (Dict[str, Union[str, List[str]]]): Maps module names to one
            or more alias names.
            Example: {"layer1": "first_layer", "layer2": "second_layer"}
            Example: {".model.layers": ".layers"} <-- Mounts .layers to the root model.
            Example: {".transformer": ["model", "mdl"]} <-- Allows access of .transformer as .model or .mdl

    Attributes:
        rename (Dict[str, Union[str, List[str]]]): The raw mapping as provided.
        alias_to_name (Dict[str, str]): Reverse lookup from alias to module name.
        name_to_aliases (Dict[str, List[str]]): Forward lookup from module name
            to its list of aliases.
        extras (Dict[str, List[str]]): Dotted attribute paths (e.g.
            "transformer.h") mapped to their aliases; used to show
            dot-separated attributes in the Envoy's string representation.
    """
    self.rename = rename
    self.extras: Dict[str, List[str]] = {}
    self.alias_to_name: Dict[str, str] = {}
    self.name_to_aliases: Dict[str, List[str]] = {}
def build(self, envoy: Envoy):
    """
    Populate the alias lookup tables for ``envoy``.

    Each entry in ``self.rename`` is validated by attempting to fetch the
    named attribute on ``envoy``; entries that do not resolve are skipped.

    Args:
        envoy: The root Envoy whose module tree the aliases refer to.
    """
    for name, aliases in self.rename.items():
        # Skip aliases whose target does not exist on this model.
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit; narrowed to Exception.
        try:
            util.fetch_attr(envoy, name)
        except Exception:
            continue
        if isinstance(aliases, str):
            aliases = [aliases]
        name = name.removeprefix(".")
        # Dotted paths (e.g. "transformer.h") are tracked separately so the
        # string representation can display them.
        if "." in name:
            self.extras[name] = aliases
        self.name_to_aliases[name] = aliases
        for alias in aliases:
            self.alias_to_name[alias] = name