"""This module contains logic to interleave a computation graph (an intervention graph) with the computation graph of a model.
The :class:`InterventionProxy <nnsight.intervention.InterventionProxy>` class extends the functionality of a base nnsight.tracing.Proxy.Proxy object and makes it easier for users to interact with.
:func:`intervene() <nnsight.intervention.InterventionProtocol.intervene>` is the entry hook into the model's computation graph, used to interleave an intervention graph with it.
The :class:`HookHandler <nnsight.intervention.HookHandler>` provides a context manager for adding input and output hooks to modules and removing them upon context exit.
"""
from __future__ import annotations
import inspect
from collections import defaultdict
from contextlib import AbstractContextManager
from typing import Any, Callable, Collection, Dict, List, Tuple, Union
import torch
from torch.utils.hooks import RemovableHandle
from typing_extensions import Self
from . import util
from .contexts.Conditional import Conditional
from .tracing import protocols
from .tracing.Graph import Graph
from .tracing.Node import Node
from .tracing.protocols import Protocol
from .tracing.Proxy import Proxy
class InterventionProxy(Proxy):
    """Sub-class for Proxy that adds additional user functionality to proxies.

    Examples:

        Saving a proxy so it is not deleted at the completion of its listeners is enabled with ``.save()``:

        .. code-block:: python

            with model.trace('The Eiffel Tower is in the city of'):
                hidden_states = model.lm_head.input.save()
                logits = model.lm_head.output.save()

            print(hidden_states)
            print(logits)
    """

    def __init__(self, node: Node) -> None:
        super().__init__(node)

        # ``_grad`` caches the proxy created by the ``grad`` property. It is written
        # straight into the instance dict everywhere in this class instead of via
        # normal attribute assignment (presumably because Proxy customises
        # attribute assignment -- confirm against Proxy).
        self.__dict__["_grad"] = None

        self._grad: InterventionProxy

    def save(self) -> InterventionProxy:
        """Indicate to the intervention graph that the value of this proxy must not be deleted.

        Returns:
            InterventionProxy: This proxy.
        """

        # Add a 'lock' node with the save proxy as an argument to ensure the values are never deleted.
        # This is because 'lock' nodes never actually get set and therefore there will always be a
        # dependency for the save proxy.
        protocols.LockProtocol.add(self.node)

        return self

    def stop(self) -> InterventionProxy:
        """Indicate to the intervention graph to stop the execution of the model after this Proxy/Node is completed.

        Returns:
            InterventionProxy: This proxy.
        """

        protocols.EarlyStopProtocol.add(self.node.graph, self.node)

        return self

    def update(self, value: Union[Node, Any]) -> InterventionProxy:
        """Updates the value of the Proxy via the creation of an UpdateProtocol node.

        Args:
            value (Union[Node, Any]): New proxy value.

        Returns:
            InterventionProxy: Whatever ``UpdateProtocol.add`` returns for the update node.

        .. code-block:: python

            with model.trace(input) as tracer:
                num = tracer.apply(int, 0)
                num.update(5)
        """

        return protocols.UpdateProtocol.add(self.node, value)

    @property
    def grad(self) -> InterventionProxy:
        """Proxy for the gradient of this proxy's tensor.

        The grad proxy is only created (via ``GradProtocol``) the first time it is
        referenced; later accesses return the cached one until the cache is reset.

        Returns:
            InterventionProxy: Grad proxy.
        """
        if self._grad is None:
            self.__dict__["_grad"] = protocols.GradProtocol.add(self.node)

        return self._grad

    @grad.setter
    def grad(self, value: Union[InterventionProxy, Any]) -> None:
        """Request that the grad of this proxy's tensor be replaced, via a SwapProtocol node.

        Args:
            value (Union[InterventionProxy, Any]): Value to set the grad to.
        """
        # ``self.grad`` creates the grad proxy if needed; the swap targets its node.
        protocols.SwapProtocol.add(self.grad.node, value)

        # Drop the cached grad proxy so the next ``.grad`` access creates a fresh one.
        self.__dict__["_grad"] = None

    def __call__(self, *args, **kwargs) -> Self:
        """Call the proxied value, with special handling when the call is ``.backward(...)``."""

        # We don't want to call backward on fake tensors.
        # We also want to track the number of times .backward() has been called so .grad on a Proxy refers to the right backward pass.
        if (
            self.node.target is util.fetch_attr
            and isinstance(self.node.args[1], str)
            and self.node.args[1] == "backward"
        ):
            # Clear all cached .grad proxies to allow users to get the .grad of the next backward pass.
            for node in self.node.graph.nodes.values():
                try:
                    if node.proxy._grad is not None:
                        node.proxy.__dict__["_grad"] = None
                except ReferenceError:
                    # Deliberately ignored. Presumably ``node.proxy`` is a weak reference
                    # whose proxy may already have been collected -- confirm against Node.
                    pass

            # Use GradProtocol to increment the tracking of the number of times .backward() has been called.
            protocols.GradProtocol.increment(self.node.graph)

            # proxy_value=None: the call node is created without a proxy value.
            return self.node.create(
                proxy_value=None,
                target=Proxy.proxy_call,
                args=[self.node] + list(args),
                kwargs=kwargs,
            )

        return super().__call__(*args, **kwargs)

    def __setattr__(self, key: str, value: Union[Self, Any]) -> None:
        # We catch setting .grad as that is a special Protocol vs. setting attributes generally.
        # The property setter is invoked explicitly on the class-level property object.
        if key == "grad":
            return getattr(self.__class__, key).fset(self, value)

        return super().__setattr__(key, value)

    def _tensor_attr(self, name: str) -> Any:
        """Shared implementation of the ``shape``/``device``/``dtype`` properties.

        Args:
            name (str): Name of the tensor attribute to read.

        Returns:
            Any: ``util.apply`` of the attribute lookup over ``self.value`` when the node
            is not attached, a proxy for the attribute when no proxy value is
            available, otherwise ``util.apply`` over ``self.node.proxy_value``.
        """

        def read(tensor: torch.Tensor) -> Any:
            return getattr(tensor, name)

        if not self.node.attached():
            return util.apply(self.value, read, torch.Tensor)

        # If we haven't scanned in a proxy_value, just return a proxy to get the attribute.
        if self.node.proxy_value is inspect._empty:
            return super().__getattr__(name)

        return util.apply(self.node.proxy_value, read, torch.Tensor)

    @property
    def shape(self) -> Collection[torch.Size]:
        """Property to retrieve the shape of the traced proxy value or real value.

        Returns:
            Union[torch.Size,Collection[torch.Size]]: Proxy value shape or collection of shapes.
        """
        return self._tensor_attr("shape")

    @property
    def device(self) -> Collection[torch.device]:
        """Property to retrieve the device of the traced proxy value or real value.

        Returns:
            Union[torch.device,Collection[torch.device]]: Proxy value device or collection of devices.
        """
        return self._tensor_attr("device")

    @property
    def dtype(self) -> Collection[torch.dtype]:
        """Property to retrieve the dtype of the traced proxy value or real value.

        Returns:
            Union[torch.dtype,Collection[torch.dtype]]: Proxy value dtype or collection of dtypes.
        """
        return self._tensor_attr("dtype")
class InterventionProtocol(Protocol):
    """Primary Protocol that handles tracking and injecting inputs and outputs from a torch model into the overall intervention Graph.

    Uses an attachment on the Graph to store the names of nodes that need to be injected with data from inputs or outputs of modules.
    """

    # Key under which the module_path -> node names mapping is stored in ``graph.attachments``.
    attachment_name = "nnsight_module_nodes"
    condition: bool = False

    @classmethod
    def add(
        cls,
        graph: "Graph",
        proxy_value: Any,
        args: List[Any] = None,
        kwargs: Dict[str, Any] = None,
    ) -> Proxy:
        """Adds an InterventionProtocol Node to a Graph and registers it in the graph attachment.

        Args:
            graph (Graph): Graph to add to.
            proxy_value (Any): Proxy value.
            args (List[Any], optional): Node args. ``compile`` reads the first element as the
                module path key (ex. model.module1.module2.output). Defaults to None.
            kwargs (Dict[str, Any], optional): Kwargs. Defaults to None.

        Returns:
            Proxy: Proxy of the newly created InterventionProtocol node.
        """

        # Creates the InterventionProtocol Node.
        proxy = graph.create(
            proxy_value=proxy_value, target=cls, args=args, kwargs=kwargs
        )

        cls.compile(proxy.node)

        return proxy

    @classmethod
    def compile(cls, node: Node) -> None:
        """Record ``node`` in its graph's attachment under the module path found in ``node.args[0]``.

        Args:
            node (Node): InterventionProtocol node to register.
        """

        graph = node.graph

        module_path, *_ = node.args

        # Add attachment if it does not exist.
        interventions = graph.attachments.setdefault(cls.attachment_name, {})

        # More than one Node can depend on a given input or output, therefore we store a list of node names.
        interventions.setdefault(module_path, []).append(node.name)

    @classmethod
    def get_interventions(cls, graph: "Graph") -> Dict:
        """Returns mapping from module_paths to InterventionNode names added to the given Graph.

        Args:
            graph (Graph): Graph.

        Returns:
            Dict: Interventions. An empty dict if the attachment does not exist.
        """

        return graph.attachments.get(cls.attachment_name, dict())

    @classmethod
    def concat(
        cls,
        activations: Any,
        value: Any,
        batch_start: int,
        batch_size: int,
        total_batch_size: int,
    ) -> Any:
        """Re-assemble full activations from the rows before a batch group, ``value``, and the rows after it.

        Tensors in ``activations`` whose first dimension equals ``total_batch_size`` are split
        into the rows before ``batch_start`` and the rows from ``batch_start + batch_size`` on;
        other tensors are passed through unchanged. The pieces are then merged with ``value``
        structure-wise (lists, tuples and dicts are walked recursively).

        Args:
            activations (Any): Original activations (possibly nested).
            value (Any): Replacement data for the batch group, with the same structure.
            batch_start (int): First row of the batch group.
            batch_size (int): Number of rows in the batch group.
            total_batch_size (int): Size of the first dimension identifying batched tensors.

        Returns:
            Any: Structure matching ``activations`` with the batch group rows replaced by ``value``.
        """

        def _concat(values):
            # ``values`` is [pre, value, post, orig_size] at the current nesting level.
            # Dispatch is on the exact type of the first element, as in the original code.
            data_type = type(values[0])

            if data_type == torch.Tensor:
                orig_size = values[-1]
                pieces = values[:-1]
                new_size = sum(piece.shape[0] for piece in pieces)

                # Only concatenate when the pieces add up to the original first dimension;
                # otherwise fall back to the first piece.
                if new_size == orig_size:
                    return torch.concatenate(pieces)

                return values[0]
            elif data_type == list:
                return [
                    _concat([item[idx] for item in values])
                    for idx in range(len(values[0]))
                ]
            elif data_type == tuple:
                return tuple(
                    [
                        _concat([item[idx] for item in values])
                        for idx in range(len(values[0]))
                    ]
                )
            elif data_type == dict:
                return {
                    key: _concat([item[key] for item in values])
                    for key in values[0].keys()
                }

            return values[0]

        def narrow_pre(acts: torch.Tensor):
            # Rows [0, batch_start) of batched tensors.
            if total_batch_size == acts.shape[0]:
                return acts.narrow(0, 0, batch_start)

            return acts

        pre = util.apply(activations, narrow_pre, torch.Tensor)

        post_batch_start = batch_start + batch_size

        def narrow_post(acts: torch.Tensor):
            # Rows [batch_start + batch_size, end) of batched tensors.
            if total_batch_size == acts.shape[0]:
                return acts.narrow(
                    0, post_batch_start, acts.shape[0] - post_batch_start
                )

            return acts

        post = util.apply(
            activations,
            narrow_post,
            torch.Tensor,
        )

        orig_sizes = util.apply(activations, lambda x: x.shape[0], torch.Tensor)

        return _concat([pre, value, post, orig_sizes])

    @classmethod
    def intervene(
        cls,
        activations: Any,
        module_path: str,
        key: str,
        intervention_handler: InterventionHandler,
    ):
        """Entry to intervention graph. This should be hooked to all modules involved in the intervention graph.

        Forms the current module_path key in the form of <module path>.<output/input>
        Checks the graphs InterventionProtocol attachment attribute for this key.
        If exists, value is a list of node names to iterate through.
        Node args for intervention type nodes should be ``[module_path, batch_group_idx, call_iter]``;
        ``batch_group_idx`` indexes ``intervention_handler.batch_groups`` to obtain ``(batch_start, batch_size)``.
        Checks and updates the counter for the given intervention node. If counter is not ready yet continue.
        Using batch_size and batch_start, apply torch.narrow to tensors in activations to select
        only batch indexed tensors relevant to this intervention node. Sets the value of a node
        using the indexed values. Using torch.narrow returns a view of the tensors as opposed to a copy allowing
        subsequent downstream nodes to make edits to the values only in the relevant tensors, and have it update the original
        tensors. This both prevents interventions from affecting batches outside their purview and allows edits
        to the output from downstream intervention nodes in the graph.

        Args:
            activations (Any): Either the inputs or outputs of a torch module.
            module_path (str): Module path of the current relevant module relative to the root model.
            key (str): Key denoting either "input" or "output" of module.
            intervention_handler (InterventionHandler): Handler object that stores the intervention graph and keeps track of module call count.

        Returns:
            Any: The activations, potentially modified by the intervention graph.
        """

        # Key to module activation intervention nodes has format: <module path>.<output/input>
        module_path = f"{module_path}.{key}"

        interventions = cls.get_interventions(intervention_handler.graph)

        if module_path in interventions:
            intervention_node_names = interventions[module_path]

            # Multiple intervention nodes can have same module_path if there are multiple invocations.
            for intervention_node_name in intervention_node_names:
                node = intervention_handler.graph.nodes[intervention_node_name]

                # Args for intervention nodes are (module_path, batch_group_idx, call_iter).
                _, batch_group_idx, call_iter = node.args

                batch_start, batch_size = intervention_handler.batch_groups[
                    batch_group_idx
                ]

                # Updates the count of intervention node calls.
                # If count matches call_iter, time to inject value into node.
                if call_iter != intervention_handler.count(
                    intervention_node_name
                ):
                    continue

                value = activations

                # Set by ``narrow`` below if at least one tensor was actually narrowed.
                narrowed = False

                # Narrowing is only needed when there is more than one batch group.
                if len(intervention_handler.batch_groups) > 1:

                    def narrow(acts: torch.Tensor):
                        # Only tensors whose first dimension is the full batch are narrowed.
                        if acts.shape[0] == intervention_handler.batch_size:
                            nonlocal narrowed

                            narrowed = True

                            return acts.narrow(0, batch_start, batch_size)

                        return acts

                    value = util.apply(
                        activations,
                        narrow,
                        torch.Tensor,
                    )

                # Value injection.
                node.set_value(value)

                # Check if through the previous value injection, there was a 'swap' intervention.
                # This would mean we want to replace activations for this batch with some other ones.
                value = protocols.SwapProtocol.get_swap(
                    intervention_handler.graph, value
                )

                # If we narrowed any data, we need to concat it with data before and after it.
                if narrowed:
                    activations = cls.concat(
                        activations,
                        value,
                        batch_start,
                        batch_size,
                        intervention_handler.batch_size,
                    )
                # Otherwise just return the whole value as the activations.
                else:
                    activations = value

        return activations

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            Dict: dictionary style.
        """

        return {
            "node": {"color": "green4", "shape": "box"},  # Node display
            "label": cls.__name__,
            "arg": defaultdict(
                lambda: {"color": "gray", "shape": "box"}
            ),  # Non-node argument display
            # NOTE(review): ``intervene`` unpacks node args as
            # (module_path, batch_group_idx, call_iter); the labels below look stale -- confirm.
            "arg_kname": defaultdict(
                lambda: None, {0: "key", 1: "batch_size", 2: "batch_start"}
            ),  # Argument label key word
            "edge": defaultdict(lambda: "solid"),  # Argument Edge display
        }
class HookHandler(AbstractContextManager):
    """Context manager that applies input and/or output hooks to modules in a model.

    Registers provided hooks on __enter__ and removes them on __exit__.

    Attributes:
        model (torch.nn.Module): Root model to access modules and apply hooks to.
        module_keys (List[str]): Keys of the form ``<module path>.<input|output>`` naming the
            module to hook and which side of it to hook. Keys without a ``.`` are skipped.
        input_hook (Callable): Function to apply to inputs of designated modules.
            Called as ``input_hook((args, kwargs), module_path)``; its return value is
            returned from the forward pre-hook.
        output_hook (Callable): Function to apply to outputs of designated modules.
            Called as ``output_hook(output, module_path)``; its return value is
            returned from the forward hook.
        handles (List[RemovableHandle]): Handles returned from registering hooks as to be used when removing hooks on __exit__.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        module_keys: List[str],
        input_hook: Callable = None,
        output_hook: Callable = None,
    ) -> None:
        self.model = model
        self.module_keys = module_keys

        self.input_hook = input_hook
        self.output_hook = output_hook

        self.handles: List[RemovableHandle] = []

    def _remove_hooks(self) -> None:
        """Remove every registered hook and forget its handle."""
        for handle in self.handles:
            handle.remove()

        # Cleared so a re-used handler does not accumulate stale handles.
        self.handles.clear()

    def __enter__(self) -> "HookHandler":
        """Registers input and output hooks to modules if they are defined.

        If registration fails part-way, hooks registered so far are removed before the
        error propagates (``__exit__`` is not called when ``__enter__`` raises).

        Returns:
            HookHandler: This HookHandler object.
        """

        try:
            for module_key in self.module_keys:
                *module_atoms, hook_type = module_key.split(".")

                # A key with no "." carries no module path; nothing to hook.
                if not module_atoms:
                    continue

                module_path = ".".join(module_atoms)

                module: torch.nn.Module = util.fetch_attr(self.model, module_path)

                if hook_type == "input" and self.input_hook is not None:

                    # ``module_path`` is bound as a default so each hook keeps its own path
                    # rather than the loop variable's final value.
                    def input_hook(module, input, kwargs, module_path=module_path):
                        return self.input_hook((input, kwargs), module_path)

                    self.handles.append(
                        module.register_forward_pre_hook(
                            input_hook, with_kwargs=True, prepend=True
                        )
                    )

                elif hook_type == "output" and self.output_hook is not None:

                    def output_hook(module, input, output, module_path=module_path):
                        return self.output_hook(output, module_path)

                    self.handles.append(
                        module.register_forward_hook(output_hook, prepend=True)
                    )
        except BaseException:
            self._remove_hooks()
            raise

        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Removes all handles added during __enter__.

        Returns None, so any exception raised inside the ``with`` body propagates
        unchanged; it is not re-raised here.
        """

        self._remove_hooks()
class InterventionHandler:
    """Object passed to InterventionProtocol.intervene to store information about the current interleaving execution run.

    Holds the intervention graph, the batch groups, the total batch size that is being
    executed, and a counter for how many times each intervention node has been attempted to be executed.
    """

    def __init__(
        self, graph: "Graph", batch_groups: List[Tuple[int, int]], batch_size: int
    ) -> None:
        """Store the run state.

        Args:
            graph (Graph): Intervention graph.
            batch_groups (List[Tuple[int, int]]): Batch groups; ``InterventionProtocol.intervene``
                unpacks each entry as ``(batch_start, batch_size)``.
            batch_size (int): Total batch size.
        """

        self.graph = graph
        self.batch_groups = batch_groups
        self.batch_size = batch_size

        # Node name -> zero-based index of its most recent execution attempt.
        self.call_counter: Dict[str, int] = {}

    def count(self, name: str) -> int:
        """Record one more execution attempt for an intervention node and return its zero-based index.

        The first call for a given ``name`` returns 0, the second returns 1, and so on.

        Args:
            name (str): Name of intervention node to return count for.

        Returns:
            int: Count.
        """

        # A missing name starts at -1 so that the first attempt yields 0.
        current = self.call_counter.get(name, -1) + 1
        self.call_counter[name] = current

        return current