Source code for nnsight.tracing.protocols

import inspect
import weakref
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import torch
from torch._subclasses.fake_tensor import FakeCopyMode, FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

from nnsight.tracing.Node import Node

from .. import util
from ..contexts.Conditional import ConditionalManager
from .util import validate

if TYPE_CHECKING:
    from ..contexts.backends.LocalBackend import LocalMixin
    from ..contexts.Conditional import Conditional
    from ..intervention import InterventionProxy
    from .Bridge import Bridge
    from .Graph import Graph
    from .Node import Node


class Protocol:
    """Base class for complex actions that get their own `Node` and `Proxy` in a `Graph`.

    A Protocol is used as the `target` of a Node. Ordinary Node targets are
    plain functions or methods. A Protocol's execution logic instead receives
    the Node itself, and through it the whole Graph, which allows behavior
    that a plain callable cannot express.
    """

    # Class-level flags; subclasses override them (LockProtocol sets
    # redirect=False, BridgeProtocol sets condition=False). They are consumed
    # by graph/proxy machinery outside this module.
    redirect: bool = True
    condition: bool = True

    @classmethod
    def add(cls, *args, **kwargs) -> "InterventionProxy":
        """Add a Node that targets this Protocol to a Graph.

        Subclasses are expected to override this.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError()

    @classmethod
    def execute(cls, node: "Node"):
        """Run this Protocol's execution logic for `node`.

        The base implementation does nothing.

        Args:
            node (Node): Node to execute using this Protocol's execution logic.
        """

    @classmethod
    def compile(cls, node: "Node") -> None:
        """Compile-time hook for `node`; the base implementation does nothing."""

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        node_display = {"color": "black", "shape": "ellipse"}
        # Style for non-node arguments, keyed by argument position.
        arg_display = defaultdict(lambda: {"color": "gray", "shape": "box"})
        # Optional label keyword per argument position.
        arg_labels = defaultdict(lambda: None)
        # Edge line style per argument position.
        edge_display = defaultdict(lambda: "solid")

        return {
            "node": node_display,
            "label": cls.__name__,
            "arg": arg_display,
            "arg_kname": arg_labels,
            "edge": edge_display,
        }
class ApplyModuleProtocol(Protocol):
    """Protocol that references the root model and calls one of its (sub)modules on some input.

    By default the module's `.forward()` is called rather than `.__call__()`,
    so module hooks are not triggered. Passing ``hook=True`` to `add` makes
    execution go through `.__call__()` instead. The root model is stored as
    a Graph attachment (see `set_module`).
    """

    attachment_name = "nnsight_root_module"

    @staticmethod
    def _get_device(module: torch.nn.Module) -> Optional[torch.device]:
        """Return the device of `module`'s first parameter, or None if it has no parameters."""
        # next() with a default avoids a bare try/except around StopIteration.
        first_param = next(module.parameters(), None)
        return None if first_param is None else first_param.device

    @classmethod
    def add(
        cls, graph: "Graph", module_path: str, *args, hook=False, **kwargs
    ) -> "InterventionProxy":
        """Creates and adds an ApplyModuleProtocol to the Graph.

        Assumes the attachment has already been added via ApplyModuleProtocol.set_module().

        Args:
            graph (Graph): Graph to add the Protocol to.
            module_path (str): Module path (model.module1.module2 etc), of module to apply from the root module.
            *args: Positional inputs for the module.
            hook (bool): If True, execution calls the module via `__call__` (hooks fire);
                otherwise `.forward()` is called directly.
            **kwargs: Keyword inputs for the module.

        Returns:
            InterventionProxy: ApplyModule Proxy.
        """
        value = inspect._empty

        # If the Graph is validating, compute the proxy_value for this node.
        if graph.validate:
            module: torch.nn.Module = util.fetch_attr(
                cls.get_module(graph), module_path
            )

            # Compute proxy_value in fake mode (see `validate`).
            value = validate(module.forward, *args, **kwargs)

        # `hook` is stored as a kwarg so `execute` can decide how to call the module.
        kwargs["hook"] = hook

        # Create and attach Node. The module path is always the first argument.
        return graph.create(
            target=cls,
            proxy_value=value,
            args=[module_path] + list(args),
            kwargs=kwargs,
        )

    @classmethod
    def execute(cls, node: "Node") -> None:
        """Executes the ApplyModuleProtocol on Node.

        Args:
            node (Node): ApplyModule Node.
        """
        module: torch.nn.Module = util.fetch_attr(
            cls.get_module(node.graph), node.args[0]
        )

        # Move input tensors to the module's device (if it has parameters).
        device = cls._get_device(module)

        args, kwargs = node.prepare_inputs(
            (node.args, node.kwargs), device=device
        )

        # The first positional argument is the module path; the rest are module inputs.
        _module_path, *args = args

        # Default matches `add`'s default so nodes created without the key still execute.
        hook = kwargs.pop("hook", False)

        if hook:
            output = module(*args, **kwargs)
        else:
            output = module.forward(*args, **kwargs)

        node.set_value(output)

    @classmethod
    def set_module(cls, graph: "Graph", module: torch.nn.Module) -> None:
        """Sets the nnsight root module as an attachment on a Graph.

        Args:
            graph (Graph): Graph.
            module (torch.nn.Module): Root module.
        """
        graph.attachments[cls.attachment_name] = module

    @classmethod
    def get_module(cls, graph: "Graph") -> torch.nn.Module:
        """Returns the nnsight root module from an attachment on a Graph.

        Args:
            graph (Graph): Graph

        Returns:
            torch.nn.Module: Root Module.
        """
        return graph.attachments[cls.attachment_name]

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        return {
            "node": {
                "color": "blue",
                "shape": "polygon",
                "sides": 6,
            },  # Node display
            "label": cls.__name__,
            "arg": defaultdict(
                lambda: {"color": "gray", "shape": "box"}
            ),  # Non-node argument display
            "arg_kname": defaultdict(lambda: None),  # Argument label word
            "edge": defaultdict(lambda: "solid"),
        }  # Argument edge display
class LockProtocol(Protocol):
    """Protocol whose execution does nothing.

    Because `.execute()` (inherited no-op) never calls `.set_value()` on the
    Lock Node, that Node is never destroyed. This keeps the Node it wraps
    (its only argument) from being released.
    """

    redirect: bool = False

    @classmethod
    def add(cls, node: "Node") -> "InterventionProxy":
        """Create a Lock Node whose single argument is `node`.

        Args:
            node (Node): Node to lock.

        Returns:
            InterventionProxy: Proxy for the new Lock Node.
        """
        locked = [node]
        return node.create(proxy_value=None, target=cls, args=locked)

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        node_display = {"color": "brown", "shape": "ellipse"}
        return {
            "node": node_display,
            "label": cls.__name__,
            # Non-node argument display.
            "arg": defaultdict(lambda: {"color": "gray", "shape": "box"}),
            # Argument label keyword.
            "arg_kname": defaultdict(lambda: None),
            # Argument edge display.
            "edge": defaultdict(lambda: "solid"),
        }
class GradProtocol(Protocol):
    """Protocol that registers a backward hook on a Tensor via `.register_hook()`.

    When the hook fires on the matching backward pass, the gradient is
    injected as the Node's value. Each Node is tied to the number of
    `.backward()` calls made during tracing before it was created, so separate
    `.grad` accesses can refer to separate backward passes:

    .. code-block:: python

        with model.trace(...):

            grad1 = model.module.output.grad.save()

            model.output.sum().backward(retain_graph=True)

            grad2 = model.module.output.grad.save()

            model.output.sum().backward()

    A Graph attachment counts `.backward()` calls made during tracing. A given
    hook therefore only injects its value on the appropriate backward pass.
    """

    attachment_name = "nnsight_backward_idx"

    @classmethod
    def add(cls, node: "Node") -> "InterventionProxy":
        """Create a grad Node for `node`, tagged with the current backward-call count.

        Args:
            node (Node): Node whose tensor's gradient is wanted.

        Returns:
            InterventionProxy: Proxy for the grad Node.
        """
        # Number of .backward() calls seen so far during tracing (0 if none).
        passes_so_far = node.graph.attachments.get(cls.attachment_name, 0)
        return node.create(
            proxy_value=node.proxy_value,
            target=cls,
            args=[node, passes_so_far],
        )

    @classmethod
    def execute(cls, node: "Node") -> None:
        """Register a tensor hook that sets this Node's value on the right backward pass.

        Args:
            node (Node): Grad Node; args are (tensor, backward pass index).
        """
        prepared, _ = node.prepare_inputs((node.args, node.kwargs))

        # Tensor to hook, and how many backward passes to skip before acting.
        tensor: torch.Tensor = prepared[0]
        passes_left: int = prepared[1]

        # Handle returned by register_hook; removed once the hook has acted.
        handle = None

        def on_grad(value):
            nonlocal passes_left

            # Not our backward pass yet: count it down and leave the gradient unchanged.
            if passes_left != 0:
                passes_left -= 1
                return None

            node.set_value(value)

            # A SwapProtocol node may have executed while this Node's value was
            # being resolved; if so, its value replaces the gradient.
            if node.attached():
                value = SwapProtocol.get_swap(node.graph, value)

            # Never act again, and drop the hook (leaving it causes memory issues).
            passes_left = -1
            handle.remove()

            return value

        handle = tensor.register_hook(on_grad)

    @classmethod
    def increment(cls, graph: "Graph"):
        """Increments the backward_idx attachment to track the number of times .backward() is called in tracing for this Graph.

        Args:
            graph (Graph): Graph.
        """
        attachments = graph.attachments
        attachments[cls.attachment_name] = attachments.get(cls.attachment_name, 0) + 1

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        node_display = {"color": "green4", "shape": "box"}
        return {
            "node": node_display,
            "label": cls.__name__,
            # Non-node argument display.
            "arg": defaultdict(lambda: {"color": "gray", "shape": "box"}),
            # Argument label keyword.
            "arg_kname": defaultdict(lambda: None),
            # Argument edge display.
            "edge": defaultdict(lambda: "solid"),
        }
class SwapProtocol(Protocol):
    """Protocol that stores a pending replacement ("swap") value as a Graph attachment.

    `get_swap` consumes the pending swap, if any, in place of a default value.
    """

    attachment_name = "nnsight_swap"

    @classmethod
    def add(cls, node: "Node", value: Any) -> "InterventionProxy":
        """Create a swap Node replacing `node`'s value with `value`.

        Args:
            node (Node): Node whose value is being swapped.
            value (Any): Replacement value (may contain Nodes).

        Returns:
            InterventionProxy: Proxy for the swap Node.
        """
        swap_args = [node, value]
        return node.create(target=cls, args=swap_args, proxy_value=True)

    @classmethod
    def execute(cls, node: "Node") -> None:
        """Register `node` as the Graph's pending swap, discarding any previous one.

        Args:
            node (Node): Swap Node.
        """
        attachments = node.graph.attachments

        # An earlier pending swap is superseded: set it to False so it is destroyed.
        previous: "Node" = attachments.get(cls.attachment_name, None)
        if previous is not None:
            previous.set_value(False)

        attachments[cls.attachment_name] = node

    @classmethod
    def get_swap(cls, graph: "Graph", value: Any) -> Any:
        """Return the pending swap value on `graph`, or `value` if there is none.

        If a swap is consumed, its tensors are moved to the device of the
        (last) tensor found in `value`. The swap Node is then set to True and
        the attachment is cleared.

        Args:
            graph (Graph): Graph
            value (Any): Default value.

        Returns:
            Any: Default value or swap value.
        """
        swap: "Node" = graph.attachments.get(cls.attachment_name, None)

        # Nothing pending: the default value passes through unchanged.
        if swap is None:
            return value

        target_device = None

        def record_device(tensor: torch.Tensor):
            nonlocal target_device
            target_device = tensor.device

        # Find the device of the default value's tensors.
        util.apply(value, record_device, torch.Tensor)

        # Resolve the swap Node's replacement argument into concrete values.
        swapped = util.apply(swap.args[1], lambda swap_arg: swap_arg.value, type(swap))

        if target_device is not None:
            swapped = util.apply(
                swapped, lambda tensor: tensor.to(target_device), torch.Tensor
            )

        # Setting the swap Node's value lets it (and its listeners) be destroyed.
        swap.set_value(True)

        # Clear the pending swap.
        graph.attachments[cls.attachment_name] = None

        return swapped

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        node_display = {"color": "green4", "shape": "ellipse"}
        return {
            "node": node_display,
            "label": cls.__name__,
            # Non-node argument display.
            "arg": defaultdict(lambda: {"color": "gray", "shape": "box"}),
            # Argument label keyword.
            "arg_kname": defaultdict(lambda: None),
            # Argument edge display.
            "edge": defaultdict(lambda: "solid"),
        }
class BridgeProtocol(Protocol):
    """Protocol to connect two Graphs by grabbing a value from one and injecting it into another.

    Uses an attachment to store a Bridge object which references all relevant Graphs and their ordering.
    """

    attachment_name = "nnsight_bridge"
    condition: bool = False

    class BridgeException(Exception):
        """Raised when a Bridge is required but none is attached to the Graph."""

        def __init__(self):
            # Was `super.__init__(...)` (missing call parentheses), which raised a
            # TypeError instead of this exception's message.
            super().__init__(
                "Must define a Session context to make use of the Bridge"
            )

    @classmethod
    def add(cls, node: "Node") -> "InterventionProxy":
        """Return a proxy in the current Graph that bridges to `node`'s value.

        Reuses the existing bridge proxy for (`node`, current graph) if there is
        one; otherwise it creates a Lock Node plus a Bridge Node.

        Args:
            node (Node): Node (in some Graph) whose value is to be bridged.

        Returns:
            InterventionProxy: Bridge proxy.

        Raises:
            BridgeProtocol.BridgeException: If `node`'s Graph has no Bridge attached.
        """
        bridge = cls.get_bridge(node.graph)
        curr_graph = bridge.peek_graph()

        # A bridged node has a unique bridge proxy per graph reference.
        bridge_proxy = bridge.get_bridge_proxy(node, curr_graph.id)

        # If the bridge proxy does not exist yet, create one.
        if bridge_proxy is None:
            # The Lock Node serves two purposes. It keeps the from-node's value alive
            # until the to-nodes are done with it. It also gives an easy handle to the
            # from-node, which is its only argument.
            lock_node = LockProtocol.add(node).node

            # Args for a Bridge Node are the id of the Graph and the name of the Lock Node.
            bridge_proxy = node.create(
                target=cls,
                proxy_value=node.proxy_value,
                args=[node.graph.id, lock_node.name],
            )

            bridge.add_bridge_proxy(node, bridge_proxy)

        return bridge_proxy

    @classmethod
    def execute(cls, node: "Node") -> None:
        """Copy the bridged value into `node` once the source Node is done.

        Args:
            node (Node): Bridge Node; args are (source graph id, lock node name).
        """
        # Gets Bridge object from the Node's Graph.
        bridge = cls.get_bridge(node.graph)

        # Args are Graph's id and name of the Lock Node on it.
        from_graph_id, lock_node_name = node.args

        # Gets the from_node's Graph via its id with the Bridge and get the Lock Node.
        lock_node = bridge.get_graph(from_graph_id).nodes[lock_node_name]

        # Value node is Lock Node's only arg.
        value_node: "Node" = lock_node.args[0]

        if value_node.done():
            # Set value to that of the value Node.
            node.set_value(value_node.value)

            # Bridge.release tells this Protocol when to release all Lock Nodes because
            # the data is no longer needed. This matters when running a Graph in a loop:
            # release only on the last iteration.
            if bridge.release:
                lock_node.set_value(None)

    @classmethod
    def set_bridge(cls, graph: "Graph", bridge: "Bridge") -> None:
        """Sets Bridge object as an attachment on a Graph.

        Only a weak proxy is stored, so the attachment does not keep the Bridge alive.

        Args:
            graph (Graph): Graph.
            bridge (Bridge): Bridge.
        """
        graph.attachments[cls.attachment_name] = weakref.proxy(bridge)

    @classmethod
    def get_bridge(cls, graph: "Graph") -> "Bridge":
        """Gets Bridge object from a Graph.

        Assumes Bridge has been set as an attachment on this Graph via BridgeProtocol.set_bridge().

        Args:
            graph (Graph): Graph.

        Returns:
            Bridge: Bridge.

        Raises:
            BridgeProtocol.BridgeException: If no Bridge is attached.
        """
        if not cls.has_bridge(graph):
            raise cls.BridgeException()

        return graph.attachments[cls.attachment_name]

    @classmethod
    def has_bridge(cls, graph: "Graph") -> bool:
        """Checks to see if a Bridge was added as an attachment on this Graph via BridgeProtocol.set_bridge().

        Args:
            graph (Graph): Graph

        Returns:
            bool: If Graph has Bridge attachment.
        """
        return cls.attachment_name in graph.attachments

    @classmethod
    def peek_graph(cls, graph: "Graph") -> "Graph":
        """Returns current Intervention Graph.

        Falls back to `graph` itself when it has no Bridge attached.

        Args:
            - graph (Graph): Graph.

        Returns:
            Graph: Graph.
        """
        if not cls.has_bridge(graph):
            return graph

        bridge = cls.get_bridge(graph)
        return bridge.peek_graph()

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        return {
            "node": {"color": "brown", "shape": "box"},  # Node display
            "label": cls.__name__,
            # Non-node argument display; argument 0 (the graph id) is dashed.
            "arg": defaultdict(
                lambda: {
                    "color": "gray",
                    "shape": "box",
                },
                {0: {"color": "gray", "shape": "box", "style": "dashed"}},
            ),
            # Argument label keyword.
            "arg_kname": defaultdict(lambda: None, {0: "graph_id"}),
            # Argument edge display.
            "edge": defaultdict(lambda: "solid", {0: "dashed"}),
        }
class EarlyStopProtocol(Protocol):
    """Protocol to stop the execution of a model early.

    Executing its Node sets the Node's value to True and then raises
    `EarlyStopProtocol.EarlyStopException`.
    """

    class EarlyStopException(Exception):
        """Raised by `EarlyStopProtocol.execute` to abort model execution."""

    @classmethod
    def add(
        cls, graph: "Graph", stop_point_node: Optional["Node"] = None
    ) -> "InterventionProxy":
        """Create an early-stop Node on `graph`.

        Args:
            graph (Graph): Graph to add the Node to.
            stop_point_node (Optional[Node]): If given, becomes the Node's only argument.

        Returns:
            InterventionProxy: Proxy for the early-stop Node.
        """
        stop_args = [] if stop_point_node is None else [stop_point_node]
        return graph.create(target=cls, proxy_value=None, args=stop_args)

    @classmethod
    def execute(cls, node: "Node") -> None:
        """Mark `node` as done (value True), then abort by raising EarlyStopException."""
        node.set_value(True)
        raise cls.EarlyStopException()

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        node_display = {"color": "red", "shape": "polygon", "sides": 6}
        return {
            "node": node_display,
            "label": cls.__name__,
            # Non-node argument display.
            "arg": defaultdict(lambda: {"color": "gray", "shape": "box"}),
            # Argument label keyword.
            "arg_kname": defaultdict(lambda: None),
            # Argument edge display.
            "edge": defaultdict(lambda: "solid"),
        }
class LocalBackendExecuteProtocol(Protocol):
    """Protocol whose Node calls `local_backend_execute()` on a stored LocalMixin object."""

    @classmethod
    def add(cls, object: "LocalMixin", graph: "Graph") -> "InterventionProxy":
        """Create a Node on `graph` that will execute `object` on the local backend.

        Args:
            object (LocalMixin): Object whose `local_backend_execute()` is called.
            graph (Graph): Graph to add the Node to.

        Returns:
            InterventionProxy: Proxy for the new Node.
        """
        # The parameter name `object` is public API (keyword callers), so it is kept.
        return graph.create(target=cls, proxy_value=None, args=[object])

    @classmethod
    def execute(cls, node: Node) -> None:
        """Run the stored object's local backend execution, then mark `node` done with None."""
        backend_target: "LocalMixin" = node.args[0]
        backend_target.local_backend_execute()
        node.set_value(None)

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        node_display = {"color": "purple", "shape": "polygon", "sides": 6}
        return {
            "node": node_display,
            "label": "ExecuteProtocol",
            # Non-node argument display.
            "arg": defaultdict(lambda: {"color": "gray", "shape": "box"}),
            # Argument label keyword.
            "arg_kname": defaultdict(lambda: None),
            # Argument edge display.
            "edge": defaultdict(lambda: "solid"),
        }
class ValueProtocol(Protocol):
    """Protocol whose Node holds a plain value, stored as the Node's first argument."""

    @classmethod
    def add(cls, graph: "Graph", default: Any = None) -> "InterventionProxy":
        """Create a value Node on `graph` initialised to `default`.

        Args:
            graph (Graph): Graph to add the Node to.
            default (Any): Initial value; also used as the proxy value.

        Returns:
            InterventionProxy: Proxy for the value Node.
        """
        initial_args = [default]
        return graph.create(target=cls, proxy_value=default, args=initial_args)

    @classmethod
    def execute(cls, node: Node) -> None:
        """Set `node`'s value to the value currently stored in its first argument."""
        stored_value = node.args[0]
        node.set_value(stored_value)

    @classmethod
    def set(cls, node: Node, value: Any) -> None:
        """Overwrite the stored value (first argument) of a value Node in place."""
        node.args[0] = value

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        node_display = {"color": "blue", "shape": "box"}
        return {
            "node": node_display,
            "label": cls.__name__,
            # Non-node argument display.
            "arg": defaultdict(lambda: {"color": "gray", "shape": "box"}),
            # Argument label keyword.
            "arg_kname": defaultdict(lambda: None),
            # Argument edge display.
            "edge": defaultdict(lambda: "solid"),
        }
class ConditionalProtocol(Protocol):
    """Protocol operating as a conditional statement.

    Uses the ConditionalManager attachment to handle all visited Conditional
    contexts within a single Intervention Graph. Evaluates the condition value
    of the Conditional as a boolean.

    Example:

        Setup:
            .. code-block:: python

                import torch
                from collections import OrderedDict

                input_size = 5
                hidden_dims = 10
                output_size = 2

                model = torch.nn.Sequential(OrderedDict([
                    ('layer1', torch.nn.Linear(input_size, hidden_dims)),
                    ('layer2', torch.nn.Linear(hidden_dims, output_size)),
                ]))

                input = torch.rand((1, input_size))

        Ex 1: The .save() on the model output will only be executed if the condition (num > 0) evaluates to True.

        .. code-block:: python

            with model.trace(input) as tracer:
                num = 5
                with tracer.cond(num > 0):
                    out = model.output.save()

        Ex 2: The condition is a tensor boolean operation on the Envoy's output InterventionProxy.

        .. code-block:: python

            with model.trace(input) as tracer:
                l1_out = model.layer1.output
                with tracer.cond(l1_out[:, 0] > 0):
                    out = model.output.save()
    """

    attachment_name = "nnsight_conditional_manager"

    @classmethod
    def add(
        cls, graph: "Graph", condition: Union["Node", Any]
    ) -> "InterventionProxy":
        """Create a ConditionalProtocol Node on `graph` for `condition`.

        Args:
            graph (Graph): Intervention Graph.
            condition (Union[Node, Any]): Condition, a Node or a concrete value.

        Returns:
            InterventionProxy: Proxy for the ConditionalProtocol Node.
        """
        return graph.create(target=cls, proxy_value=True, args=[condition])

    @classmethod
    def execute(cls, node: "Node") -> None:
        """Evaluate the node condition to a boolean.

        If the condition is truthy, the Node's value is set to True. Otherwise,
        the listener counts of every Node downstream of this one are
        decremented, and arguments that become done and redundant are destroyed.

        Args:
            node (Node): ConditionalProtocol node.
        """
        cond_value = Node.prepare_inputs(node.args[0])

        if cond_value:
            node.set_value(True)
            return

        def update_conditioned_nodes(conditioned_node: "Node") -> None:
            """Recursively decrement the remaining listeners count of all the dependencies of conditioned nodes.

            Args:
                - conditioned_node (Node): Conditioned Node
            """
            for listener in conditioned_node.listeners:
                for listener_arg in listener.arg_dependencies:
                    # The listener will never run, so it will never consume this argument.
                    listener_arg.remaining_listeners -= 1
                    if listener_arg.done() and listener_arg.redundant():
                        listener_arg.destroy()
                # NOTE(review): a Node reachable via several paths is visited (and its
                # args decremented) once per path; confirm this is intended.
                update_conditioned_nodes(listener)

        # The condition evaluated to False: update conditioned nodes.
        update_conditioned_nodes(node)

    @classmethod
    def has_conditional(cls, graph: "Graph") -> bool:
        """Checks if the Intervention Graph has a ConditionalManager attached to it.

        Args:
            graph (Graph): Intervention Graph.

        Returns:
            bool: If graph has a ConditionalManager attachment.
        """
        return cls.attachment_name in graph.attachments

    @classmethod
    def get_conditional(
        cls, graph: "Graph", cond_node_name: str
    ) -> "Conditional":
        """Looks up a Conditional entry by its ConditionalProtocol node name.

        Args:
            graph (Graph): Intervention Graph.
            cond_node_name (str): ConditionalProtocol name.

        Returns:
            Conditional: Whatever the graph's ConditionalManager returns from
            `.get(cond_node_name)`. NOTE(review): the return annotation says
            Conditional, but an earlier docstring described it as the
            ConditionalProtocol Node; confirm against ConditionalManager.get.
        """
        return graph.attachments[cls.attachment_name].get(cond_node_name)

    @classmethod
    def push_conditional(cls, node: "Node") -> None:
        """Attaches a Conditional context to its graph.

        Args:
            node (Node): ConditionalProtocol of the current Conditional context.
        """
        # All ConditionalProtocols associated with a graph are stored and managed by the
        # ConditionalManager. Create that attachment when the first Conditional context is entered.
        if cls.attachment_name not in node.graph.attachments:
            node.graph.attachments[cls.attachment_name] = ConditionalManager()

        # Push the ConditionalProtocol node to the ConditionalManager.
        node.graph.attachments[cls.attachment_name].push(node)

    @classmethod
    def pop_conditional(cls, graph: "Graph") -> None:
        """Pops latest ConditionalProtocol from the ConditionalManager attached to the graph.

        Args:
            graph (Graph): Intervention Graph.
        """
        graph.attachments[cls.attachment_name].pop()

    @classmethod
    def peek_conditional(cls, graph: "Graph") -> "Node":
        """Gets the ConditionalProtocol node of the current Conditional context.

        Args:
            - graph (Graph): Graph.

        Returns:
            Node: ConditionalProtocol of the current Conditional context.
        """
        return graph.attachments[cls.attachment_name].peek()

    @classmethod
    def add_conditioned_node(cls, node: "Node") -> None:
        """Adds a conditioned Node to the ConditionalManager attached to its graph.

        Args:
            - node (Node): Conditioned Node.
        """
        node.graph.attachments[cls.attachment_name].add_conditioned_node(node)

    @classmethod
    def is_node_conditioned(cls, node: "Node") -> bool:
        """Checks if the Node is conditioned by the current Conditional context.

        Args:
            - node (Node): Conditioned Node.

        Returns:
            bool: Whether the Node is conditioned.
        """
        return node.graph.attachments[cls.attachment_name].is_node_conditioned(
            node
        )

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        return {
            "node": {
                "color": "#FF8C00",
                "shape": "polygon",
                "sides": 6,
            },  # Node display
            "label": cls.__name__,
            "arg": defaultdict(
                lambda: {"color": "gray", "shape": "box"}
            ),  # Non-node argument
            "arg_kname": defaultdict(lambda: None),  # Argument label key word
            "edge": defaultdict(lambda: "solid"),
        }  # Argument edge display
class UpdateProtocol(Protocol):
    """Protocol to update the value of an InterventionProxy node.

    .. code-block:: python

        with model.trace(input) as tracer:
            num = tracer.apply(int, 0)
            num.update(5)
    """

    @classmethod
    def add(
        cls, node: "Node", new_value: Union[Node, Any]
    ) -> "InterventionProxy":
        """Creates an UpdateProtocol node.

        Args:
            node (Node): Original node.
            new_value (Union[Node, Any]): The update value.

        Returns:
            InterventionProxy: proxy.
        """
        return node.create(
            target=cls,
            proxy_value=node.proxy_value,
            args=[
                node,
                new_value,
            ],
        )

    @classmethod
    def execute(cls, node: "Node") -> None:
        """Sets the value of the original node to the new value.

        If the original node was defined outside the current context (i.e. it
        is a BridgeProtocol node), both the bridge node and the source node it
        bridges to are updated.

        Args:
            node (Node): UpdateProtocol node.
        """
        value_node, new_value = node.args

        # Resolve any Nodes inside the new value into concrete values.
        new_value = Node.prepare_inputs(new_value)

        if value_node.target == BridgeProtocol:
            # Update the local bridge node...
            value_node._value = new_value
            # ...then follow (graph id, lock node name) to the Lock Node in the source
            # graph, whose only argument is the original Node.
            bridge = BridgeProtocol.get_bridge(value_node.graph)
            lock_node = bridge.id_to_graph[value_node.args[0]].nodes[
                value_node.args[1]
            ]
            value_node = lock_node.args[0]

        value_node._value = new_value

        node.set_value(new_value)

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        return {
            "node": {"color": "blue", "shape": "ellipse"},  # Node display
            "label": cls.__name__,
            "arg": defaultdict(
                lambda: {"color": "gray", "shape": "box"}
            ),  # Non-node argument
            "arg_kname": defaultdict(lambda: None),  # Argument label key word
            "edge": defaultdict(lambda: "solid"),
        }  # Argument edge display