nnsight.tracing

The nnsight.tracing module traces operations in order to form a computation graph.

The Graph class adds and stores operations as Nodes.

The Node class represents an individual operation in the Graph.

The Proxy class handles interactions from the user in order to create new Nodes. There is a Proxy for each Node.
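To make the division of labour above concrete, here is a minimal plain-Python sketch of the same idea. It does not use nnsight: ToyGraph, ToyNode, ToyProxy, load_input, and the <target>_<index> naming scheme are hypothetical stand-ins, and the real classes additionally handle validation, Protocols, and dependency tracking. Operator overloads on the proxy record Nodes in the graph; nothing is computed until execute() is called.

```python
import operator
from typing import Any, Callable, Dict, List


def _unwrap(value: Any) -> Any:
    # Proxies are replaced by their Nodes when recorded as arguments.
    return value.node if isinstance(value, ToyProxy) else value


class ToyNode:
    """Hypothetical Node: one recorded operation and, later, its value."""

    def __init__(self, graph: "ToyGraph", name: str, target: Callable, args: List[Any]) -> None:
        self.graph, self.name, self.target, self.args = graph, name, target, args
        self.value: Any = None

    def execute(self) -> None:
        # Swap Node arguments for their computed values, then call the target.
        inputs = [a.value if isinstance(a, ToyNode) else a for a in self.args]
        self.value = self.target(*inputs)


class ToyProxy:
    """Hypothetical Proxy: operator overloads record Nodes instead of computing."""

    def __init__(self, node: ToyNode) -> None:
        self.node = node

    def __add__(self, other: Any) -> "ToyProxy":
        return self.node.graph.create(operator.add, self.node, _unwrap(other))

    def __mul__(self, other: Any) -> "ToyProxy":
        return self.node.graph.create(operator.mul, self.node, _unwrap(other))


class ToyGraph:
    """Hypothetical Graph: Nodes keyed by unique name, in insertion order."""

    def __init__(self) -> None:
        self.nodes: Dict[str, ToyNode] = {}
        self.name_idx: Dict[str, int] = {}

    def create(self, target: Callable, *args: Any) -> ToyProxy:
        base = target.__name__
        idx = self.name_idx.get(base, 0)
        self.name_idx[base] = idx + 1  # keeps names unique per target name
        node = ToyNode(self, f"{base}_{idx}", target, list(args))
        self.nodes[node.name] = node
        return ToyProxy(node)  # one Proxy per Node

    def execute(self) -> None:
        for node in self.nodes.values():  # sequential: insertion order
            node.execute()


def load_input() -> int:
    return 3


graph = ToyGraph()
x = graph.create(load_input)
y = (x + 2) * x                      # recorded, not computed yet
print(list(graph.nodes))             # ['load_input_0', 'add_0', 'mul_0']
graph.execute()
print(y.node.value)                  # 15
```

Because each operator call returns a new proxy wrapping a new Node, arbitrary expressions can be recorded without special syntax; executing in insertion order is safe here because a Node can only reference Nodes created before it.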

class nnsight.tracing.Graph.Graph(proxy_class: Type[Proxy] = Proxy, validate: bool = False, sequential: bool = True, graph_id: int = None)

Represents a computation graph composed of Nodes.

nodes

Mapping of Node name to node. Order is preserved and important when executing the graph sequentially.

Type:

Dict[str, Node]

attachments

Dictionary object used to add extra functionality to this Graph. Used by Protocols.

Type:

Dict[str, Any]

proxy_class

Proxy class to use. Defaults to Proxy (nnsight.tracing.Proxy.Proxy).

Type:

Type[Proxy]

alive

Whether this Graph should be considered alive (still tracing) and can therefore be added to. Used by Nodes.

Type:

bool

name_idx

Mapping of a node’s target name to the number of previous nodes with the same target name. Used to keep node names unique.

Type:

Dict[str, int]

validate

Whether to execute nodes as they are added, using their arguments’ proxy values, in order to check whether the executions are possible and to create a new proxy_value. Defaults to False.

When Nodes are added to the Graph and the Graph’s validate attribute is True, the Graph executes the Node’s target on its arguments’ .proxy_value attributes (essentially executing the Node with FakeTensors in FakeTensorMode). This (1) checks whether the operation is valid for the tensor shapes within the .proxy_values (for example, it would catch an indexing error) and (2) populates the new Node’s .proxy_value attribute with the result.
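The validation idea can be illustrated without torch. In the sketch below (hypothetical, not nnsight's implementation), plain Python lists stand in for FakeTensors: a placeholder has the right length but dummy contents, and each new node runs its target on its arguments' placeholders as soon as it is created, so an out-of-range index fails during tracing instead of during the real run.

```python
import operator
from typing import Any, Callable, List, Optional


class SketchNode:
    """Hypothetical node that can validate itself when it is created."""

    def __init__(self, target: Optional[Callable], args: List[Any],
                 validate: bool, proxy_value: Any = None) -> None:
        self.target, self.args = target, args
        if validate and target is not None:
            # Run the target on the arguments' placeholder values, not on real
            # data; an invalid operation raises here, at trace time.
            fake_args = [a.proxy_value if isinstance(a, SketchNode) else a for a in args]
            proxy_value = target(*fake_args)
        self.proxy_value = proxy_value


# A root whose placeholder has the real length but dummy contents, loosely
# analogous to a FakeTensor that carries shape information but no data.
hidden = SketchNode(None, [], validate=True, proxy_value=[0.0] * 4)

first = SketchNode(operator.getitem, [hidden, 3], validate=True)
print(first.proxy_value)  # 0.0

try:
    SketchNode(operator.getitem, [hidden, 10], validate=True)
except IndexError as err:
    print("caught while tracing:", err)  # caught while tracing: list index out of range
```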

Type:

bool

sequential

Whether to run nodes sequentially when executing this graph.

When this is set to True, calling .execute() attempts to execute Nodes in the order they were added to the Graph. Otherwise, all nodes are checked to see whether they are fulfilled (that is, they have no dependencies). These are the root nodes, and they are then executed in the order they were added.
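A minimal sketch of the difference between the two modes, assuming the graph is described by a plain dict of node names to dependency names (entry_points and the node names are hypothetical):

```python
from typing import Dict, List


def entry_points(dependencies: Dict[str, List[str]], sequential: bool) -> List[str]:
    """Hypothetical helper: the Nodes that execute() would start directly.

    `dependencies` maps each node name to the names it depends on, in the
    order the nodes were added (dicts preserve insertion order).
    """
    if sequential:
        # Every node is attempted, in insertion order.
        return list(dependencies)
    # Otherwise only root nodes (no dependencies) are started directly.
    return [name for name, deps in dependencies.items() if not deps]


deps = {
    "input_0": [],
    "const_0": [],
    "add_0": ["input_0", "const_0"],
    "save_0": ["add_0"],
}
print(entry_points(deps, sequential=True))   # ['input_0', 'const_0', 'add_0', 'save_0']
print(entry_points(deps, sequential=False))  # ['input_0', 'const_0']
```

In the non-sequential case, the remaining nodes are expected to run as their dependencies complete, through the listener mechanism described under Node.set_value().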

Type:

bool

add(node: Node) → None

Adds a Node to this Graph. Called by Nodes on __init__.

When Nodes are added to the Graph and the Graph’s validate attribute is True, the Graph executes the Node’s target on its arguments’ .proxy_value attributes (essentially executing the Node with FakeTensors in FakeTensorMode). This (1) checks whether the operation is valid for the tensor shapes within the .proxy_values (for example, it would catch an indexing error) and (2) populates the new Node’s .proxy_value attribute with the result.

Parameters:

node (Node) – Node to add.

copy()

Copy constructs a new Graph and then recursively creates new Nodes on the graph.

create(*args, **kwargs) → Proxy

Creates a Node directly on this Graph and returns its Proxy.

Returns:

Proxy for newly created Node.

Return type:

Proxy

execute() → None

Executes the operations of the Graph.

Executes all Nodes sequentially if Graph.sequential is True. Otherwise, executes only the root Nodes sequentially.

reset() → None

Resets the Graph to prepare for a new execution of the Graph. Calls .reset() on all Nodes.

vis(title: str = 'graph', path: str = '.', display: bool = True, save: bool = False, recursive: bool = False)

Generates and saves a graphical visualization of the Intervention Graph using the pygraphviz library.

Parameters:
  • title (str) – Name of the Intervention Graph. Defaults to “graph”.

  • path (str) – Directory path to save the graphic in. If None, saves content to the current directory.

  • display (bool) – If True, shows the graph image.

  • save (bool) – If True, saves the graph to the specified path.

  • recursive (bool) – If True, recursively visualizes sub-graphs.

class nnsight.tracing.Node.Node(target: Union[Callable, str], graph: Graph = None, proxy_value: Any, args: List[Any] = None, kwargs: Dict[str, Any] = None, name: str = None)

A Node represents some action that should be carried out during execution of a Graph.

The class represents both the operations being traced (and their resulting outputs) and the nodes that actually execute those operations when the Graph is executed. The Nodes you trace are the same objects as the ones that are executed.

  • Nodes have a .proxy_value attribute that is the result of the tracing operation. It is a FakeTensor, which lets you view the shape and data type of the actual value that will be populated when the node’s operation is executed.

  • Nodes carry out their operation in .execute(), where their arguments are pre-processed, and their value is set in .set_value().

  • Arguments passed to a node can be other nodes, which forms a bi-directional dependency graph. During execution pre-processing, the arguments that are nodes are converted to their values.

  • Nodes are responsible for notifying their listeners when one of their dependencies has completed and, once all of a listener’s dependencies have completed, for triggering its execution. Similarly, nodes must inform their dependencies when one of their listeners has ceased “listening.” If a node has no remaining listeners, its value is destroyed by calling .destroy() in order to free memory. When the same graph (and therefore the same nodes) is re-executed, the remaining listeners and dependencies are reset on each node.
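The bookkeeping described in these bullets can be sketched in plain Python. CountingNode is a hypothetical stand-in that reuses the attribute names from this page (listeners, arg_dependencies, remaining_dependencies, remaining_listeners) but is not nnsight's code: a root is executed, listeners run as soon as they are fulfilled, and dependencies whose listeners have all finished are destroyed.

```python
import inspect
from typing import Any, Callable, List


class CountingNode:
    """Hypothetical sketch of listener/dependency bookkeeping (not nnsight's code)."""

    def __init__(self, name: str, target: Callable, args: List[Any]) -> None:
        self.name, self.target, self.args = name, target, args
        self.arg_dependencies = [a for a in args if isinstance(a, CountingNode)]
        self.listeners: List["CountingNode"] = []
        for dep in self.arg_dependencies:
            dep.listeners.append(self)  # bi-directional link
        self._value: Any = inspect._empty
        self.reset()

    def reset(self) -> None:
        # Called again before each (re-)execution, once the graph is complete.
        self.remaining_dependencies = len(self.arg_dependencies)
        self.remaining_listeners = len(self.listeners)

    def fulfilled(self) -> bool:
        return self.remaining_dependencies == 0

    def redundant(self) -> bool:
        return self.remaining_listeners == 0

    def execute(self) -> None:
        inputs = [a._value if isinstance(a, CountingNode) else a for a in self.args]
        self.set_value(self.target(*inputs))

    def set_value(self, value: Any) -> None:
        self._value = value
        for listener in self.listeners:
            listener.remaining_dependencies -= 1
            if listener.fulfilled():  # all of its inputs are ready
                listener.execute()
        for dep in self.arg_dependencies:
            dep.remaining_listeners -= 1
            if dep.redundant():  # nobody needs its value any more
                dep.destroy()

    def destroy(self) -> None:
        self._value = inspect._empty  # drop the reference to free memory


x = CountingNode("x", lambda: 3, [])
y = CountingNode("y", lambda a: a + 1, [x])
z = CountingNode("z", lambda a, b: a * b, [x, y])
for node in (x, y, z):
    node.reset()
x.execute()  # only the root is started; y and z run via propagation
print(z._value)                                                # 12
print(x._value is inspect._empty, y._value is inspect._empty)  # True True
```

Note that reset() has to be called once the graph is complete, because listeners are appended to a node after it has been constructed; this mirrors the reset step mentioned above for re-executing a graph.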

name

Unique name of node.

Type:

str

graph

Weak reference to parent Graph object.

Type:

Graph

proxy

Weak reference to Proxy created from this Node.

Type:

Proxy

proxy_value

FakeTensor version of the value. Used when the graph has validate = True to check whether Node actions are possible.

Type:

Any

target

Function to execute or name of Protocol.

Type:

Union[Callable, str]

args

Positional arguments. Defaults to None.

Type:

List[Any], optional

kwargs

Keyword arguments. Defaults to None.

Type:

Dict[str, Any], optional

listeners

Nodes that depend on this node.

Type:

List[Node]

arg_dependencies

Nodes that this node depends on.

Type:

List[Node]

cond_dependency

ConditionalProtocol node if this node was defined within a Conditional context.

Type:

Optional[Node]

value

Actual value to be populated during execution.

Type:

Any

attached() → bool

Checks to see if the weakref to the Graph is alive or dead.

Returns:

Is Node attached.

Return type:

bool

clean() → None

Cleans up dependencies when execution is stopped early.

create(target: Callable | str, proxy_value: Any, args: List[Any] = None, kwargs: Dict[str, Any] = None, name: str = None) → Proxy | Any

Node.add is used instead of Graph.add in case the graph is dead. If the graph is dead, the new node is assumed to be ready to execute, so an attempt is made to execute it immediately and its value is returned.

Returns:

Proxy or value

Return type:

Union[Proxy, Any]

destroy() → None

Removes the reference to the node’s value and logs its destruction.

done() → bool

Returns True if the value of this node has been set.

Returns:

If done.

Return type:

bool

execute() → None

Actually executes this node. If target is a str, lets the corresponding Protocol execute it. Otherwise, prepares args and kwargs, passes them to target, and sets the Node’s value to the target’s output.

executed() → bool

Returns True if remaining_dependencies is less than 0.

Returns:

If executed.

Return type:

bool

fulfilled() → bool

Returns True if remaining_dependencies is 0.

Returns:

If fulfilled.

Return type:

bool

classmethod prepare_inputs(inputs: Any, device: device = None, proxy: bool = False) → Any

Prepares arguments for executing this node’s target. Converts Nodes in args and kwargs to their values and moves tensors to the correct device.

Returns:

Prepared inputs.

Return type:

Any
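A sketch of the Node-to-value conversion, assuming nested lists, tuples, and dicts are walked recursively. ValueHolder and prepare_inputs_sketch are hypothetical names, and the device movement performed by the real method is deliberately left out so the example does not depend on torch:

```python
from typing import Any


class ValueHolder:
    """Hypothetical stand-in for a Node that already has a value."""

    def __init__(self, value: Any) -> None:
        self.value = value


def prepare_inputs_sketch(inputs: Any) -> Any:
    """Recursively replace ValueHolder objects with their values."""
    if isinstance(inputs, ValueHolder):
        return inputs.value
    if isinstance(inputs, dict):
        return {k: prepare_inputs_sketch(v) for k, v in inputs.items()}
    if isinstance(inputs, list):
        return [prepare_inputs_sketch(v) for v in inputs]
    if isinstance(inputs, tuple):
        return tuple(prepare_inputs_sketch(v) for v in inputs)
    return inputs  # plain values pass through unchanged


args = (ValueHolder(2), [ValueHolder(3), 4])
kwargs = {"scale": ValueHolder(0.5)}
print(prepare_inputs_sketch(args))    # (2, [3, 4])
print(prepare_inputs_sketch(kwargs))  # {'scale': 0.5}
```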

preprocess() → None

Preprocesses Node.args and Node.kwargs.

redundant() → bool

Returns True if remaining_listeners is 0.

Returns:

If redundant.

Return type:

bool

reset() → None

Resets this Node’s remaining_listeners and remaining_dependencies.

set_value(value: Any) → None

Sets the value of this Node and logs the event. Updates remaining_dependencies on each listener and executes any listener that is now fulfilled. It also updates remaining_listeners on each dependency and destroys any dependency that is now redundant.

Parameters:

value (Any) – Value.

property value: Any

Property to return the value of this node.

Returns:

The stored value of the node, populated during execution of the model.

Return type:

Any

Raises:

ValueError – If the underlying ._value is inspect._empty (meaning it was never set or has been destroyed).
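The inspect._empty sentinel is used rather than None because None can be a perfectly valid node value. A minimal sketch of that pattern (SentinelNode and its error message are hypothetical, not nnsight's):

```python
import inspect
from typing import Any


class SentinelNode:
    """Hypothetical sketch of a value slot guarded by the inspect._empty sentinel."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._value: Any = inspect._empty  # marks "never set"

    def set_value(self, value: Any) -> None:
        self._value = value

    def destroy(self) -> None:
        self._value = inspect._empty  # drop the reference; the slot reads as unset again

    @property
    def value(self) -> Any:
        if self._value is inspect._empty:
            raise ValueError(f"Value of node {self.name!r} was never set or has been destroyed")
        return self._value


node = SentinelNode("hidden_0")
node.set_value(None)   # None is a legitimate value, so it cannot be the "unset" marker
print(node.value)      # None
node.destroy()
try:
    node.value
except ValueError as err:
    print(err)  # Value of node 'hidden_0' was never set or has been destroyed
```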

visualize(viz_graph: AGraph, recursive: bool, backend_name: str = '') → str

Adds this node to the visualization graph and recursively visualizes its arguments and adds edges between them.

Parameters:
  • viz_graph (AGraph) – Visualization graph.

  • recursive (bool) – If True, recursively visualizes all sub-graphs.

  • backend_name (str) – Inherited parent graph name, used to keep node names unique in recursive visualization.

Returns:

Name of this node.

Return type:

str

class nnsight.tracing.Proxy.Proxy(node: Node)

Proxy objects are the actual objects that interact with operations in order to update the graph to create new Nodes.

The operations traceable on base Proxy objects include many Python built-in and magic methods; Proxy also implements __torch_function__ to trace torch operations.

node

This proxy’s node.

Type:

Node

property value: Any

Property to return the value of this proxy’s node.

Returns:

The stored value of the proxy, populated during execution of the model.

Return type:

Any

nnsight.tracing.Proxy.proxy_wrapper(fn) → Callable

Wraps problematic functions (sometimes torch functions). When the wrapped function is called, it checks whether any of its arguments are Proxies. If so, it returns a Proxy of the function call; otherwise, it simply runs the function.

Parameters:

fn (function) – Function to wrap.

Returns:

Wrapped function.

Return type:

function
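A minimal sketch of the wrapping logic, with hypothetical FakeProxy and DeferredCall classes standing in for Proxy and for the Node that would be created. Unlike the real wrapper, this one only inspects top-level positional and keyword arguments:

```python
import functools
from typing import Any, Callable


class FakeProxy:
    """Hypothetical stand-in for Proxy."""

    def __init__(self, label: str) -> None:
        self.label = label


class DeferredCall(FakeProxy):
    """Records a call that should be traced instead of run."""

    def __init__(self, fn: Callable, args: tuple, kwargs: dict) -> None:
        super().__init__(f"{fn.__name__}(...)")
        self.fn, self.args, self.kwargs = fn, args, kwargs


def proxy_wrapper_sketch(fn: Callable) -> Callable:
    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Only top-level arguments are inspected in this sketch.
        if any(isinstance(a, FakeProxy) for a in (*args, *kwargs.values())):
            return DeferredCall(fn, args, kwargs)  # trace: return a proxy
        return fn(*args, **kwargs)                  # no proxies: just run it
    return wrapper


@proxy_wrapper_sketch
def clamp(x: float, low: float = 0.0, high: float = 1.0) -> float:
    return max(low, min(high, x))


print(clamp(5.0))                            # 1.0
traced = clamp(FakeProxy("hidden"), high=2.0)
print(type(traced).__name__, traced.kwargs)  # DeferredCall {'high': 2.0}
```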

class nnsight.tracing.protocols.ApplyModuleProtocol

Protocol that references some root model and calls its .forward() method with some input. Using .forward() instead of .__call__() means it won’t trigger hooks. Uses an attachment on the Graph to store the model.

classmethod add(graph: Graph, module_path: str, *args, hook=False, **kwargs) → InterventionProxy

Creates and adds an ApplyModuleProtocol to the Graph. Assumes the attachment has already been added via ApplyModuleProtocol.set_module().

Parameters:
  • graph (Graph) – Graph to add the Protocol to.

  • module_path (str) – Module path (e.g. model.module1.module2) of the module to apply, starting from the root module.

Returns:

ApplyModule Proxy.

Return type:

InterventionProxy
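The note about .forward() versus .__call__() mirrors torch.nn.Module, where forward hooks are run by __call__ but not by a direct forward() call. The sketch below models that with a hypothetical ToyModule and resolves a module path with getattr; treating the first path component as the root itself is an assumption made for illustration.

```python
import functools
from types import SimpleNamespace
from typing import Any, Callable, List


class ToyModule:
    """Hypothetical module mimicking torch.nn.Module: __call__ runs hooks, forward() does not."""

    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn
        self.hooks: List[Callable[[Any], Any]] = []

    def forward(self, x: Any) -> Any:
        return self.fn(x)

    def __call__(self, x: Any) -> Any:
        out = self.forward(x)
        for hook in self.hooks:  # e.g. hooks used for tracing/interventions
            out = hook(out)
        return out


def fetch_module(root: Any, module_path: str) -> Any:
    """Resolve a dotted path such as "model.module1.module2".

    Assumption: the first component names the root itself and is skipped.
    """
    return functools.reduce(getattr, module_path.split(".")[1:], root)


root = SimpleNamespace(module1=SimpleNamespace(module2=ToyModule(lambda x: x * 2)))
target = fetch_module(root, "model.module1.module2")
target.hooks.append(lambda out: out + 100)

print(target(5))          # 110: __call__ runs the hook
print(target.forward(5))  # 10: forward() bypasses it
```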

classmethod execute(node: Node) → None

Executes the ApplyModuleProtocol on Node.

Parameters:

node (Node) – ApplyModule Node.

classmethod get_module(graph: Graph) → Module

Returns the nnsight root module from an attachment on a Graph.

Parameters:

graph (Graph) – Graph

Returns:

Root Module.

Return type:

torch.nn.Module

classmethod set_module(graph: Graph, module: Module) → None

Sets the nnsight root module as an attachment on a Graph.

Parameters:
  • graph (Graph) – Graph.

  • module (torch.nn.Module) – Root module.

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict

class nnsight.tracing.protocols.BridgeProtocol

Protocol to connect two Graphs by grabbing a value from one and injecting it into another. Uses an attachment to store a Bridge object which references all relevant Graphs and their ordering.

exception BridgeException

classmethod add(node: Node) → InterventionProxy

Class method to be implemented in order to add a Node of this Protocol to a Graph.

classmethod execute(node: Node) → None

Class method to be implemented which contains the actual execution logic of the Protocol. By default, does nothing.

Parameters:

node (Node) – Node to execute using this Protocol’s execution logic.

classmethod get_bridge(graph: Graph) → Bridge

Gets Bridge object from a Graph. Assumes Bridge has been set as an attachment on this Graph via BridgeProtocol.set_bridge().

Parameters:

graph (Graph) – Graph.

Returns:

Bridge.

Return type:

Bridge

classmethod has_bridge(graph: Graph) → bool

Checks to see if a Bridge was added as an attachment on this Graph via BridgeProtocol.set_bridge().

Parameters:

graph (Graph) – Graph

Returns:

If Graph has Bridge attachment.

Return type:

bool

classmethod peek_graph(graph: Graph) → Graph

Returns the current Intervention Graph.

Parameters:

graph (Graph) – Graph.

Returns:

Graph.

Return type:

Graph

classmethod set_bridge(graph: Graph, bridge: Bridge) → None

Sets a Bridge object as an attachment on a Graph.

Parameters:
  • graph (Graph) – Graph to attach the Bridge to.

  • bridge (Bridge) – Bridge object to attach.

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict

class nnsight.tracing.protocols.ConditionalProtocol

Protocol operating as a conditional statement. Uses the ConditionalManager attachment to handle all visited Conditional contexts within a single Intervention Graph. Evaluates the condition value of the Conditional as a boolean.

Example

Setup:

Ex 1: The .save() on the model output will only be executed if the condition (x > 0) evaluates to True.

Ex 2: The condition is a tensor boolean operation on the Envoy’s output InterventionProxy.
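The gating behaviour described above can be sketched in plain Python. This is not nnsight's conditional syntax: CondNode and run are hypothetical, and the sketch only shows that a node carrying a cond_dependency runs when the condition evaluates to True and is skipped otherwise.

```python
from typing import Any, Callable, Optional


class CondNode:
    """Hypothetical node that may be gated by a condition node (cf. Node.cond_dependency)."""

    def __init__(self, target: Callable[..., Any], *args: Any,
                 cond: Optional["CondNode"] = None) -> None:
        self.target, self.args, self.cond_dependency = target, args, cond
        self.value: Any = None
        self.skipped = False

    def execute(self) -> None:
        if self.cond_dependency is not None and not self.cond_dependency.value:
            self.skipped = True  # condition evaluated to False: do not run
            return
        inputs = [a.value if isinstance(a, CondNode) else a for a in self.args]
        self.value = self.target(*inputs)


def run(x_value: float) -> CondNode:
    x = CondNode(lambda: x_value)
    cond = CondNode(lambda v: bool(v > 0), x)         # condition, evaluated as a boolean
    saved = CondNode(lambda v: v * 10, x, cond=cond)  # only runs if cond is True
    for node in (x, cond, saved):                     # sequential execution
        node.execute()
    return saved


print(run(3).value)     # 30
print(run(-1).skipped)  # True
```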

classmethod add(graph: Graph, condition: Node | Any) → InterventionProxy

Class method to be implemented in order to add a Node of this Protocol to a Graph.

classmethod add_conditioned_node(node: Node) → None

Adds a conditioned Node to the ConditionalManager attached to its graph.

Parameters:

node (Node) – Conditioned Node.

classmethod execute(node: Node) → None

Evaluates the node condition to a boolean.

Parameters:

node (Node) – ConditionalProtocol node.

classmethod get_conditional(graph: Graph, cond_node_name: str) → Conditional

Gets the ConditionalProtocol node by its name.

Parameters:
  • graph (Graph) – Intervention Graph.

  • cond_node_name (str) – ConditionalProtocol name.

Returns:

ConditionalProtocol Node.

Return type:

Node

classmethod has_conditional(graph: Graph) → bool

Checks if the Intervention Graph has a ConditionalManager attached to it.

Parameters:

graph (Graph) – Intervention Graph.

Returns:

If graph has a ConditionalManager attachment.

Return type:

bool

classmethod is_node_conditioned(node: Node) → bool

Checks if the Node is conditioned by the current Conditional context.

Parameters:

node (Node) – Conditioned Node.

Returns:

Whether the Node is conditioned.

Return type:

bool

classmethod peek_conditional(graph: Graph) → Node

Gets the ConditionalProtocol node of the current Conditional context.

Parameters:

graph (Graph) – Graph.

Returns:

ConditionalProtocol of the current Conditional context.

Return type:

Node

classmethod pop_conditional(graph: Graph) → None

Pops the latest ConditionalProtocol from the ConditionalManager attached to the graph.

Parameters:

graph (Graph) – Intervention Graph.

classmethod push_conditional(node: Node) → None

Attaches a Conditional context to its graph.

Parameters:

node (Node) – ConditionalProtocol of the current Conditional context.

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict

class nnsight.tracing.protocols.EarlyStopProtocol

Protocol to stop the execution of a model early.
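Stopping early is naturally expressed with an exception, because a hook running deep inside a model's forward pass cannot simply return from the whole call. A minimal sketch under that assumption, with hypothetical EarlyStop, forward, and run_with_early_stop names (not nnsight's classes):

```python
from typing import Any, Callable, Dict, List


class EarlyStop(Exception):
    """Hypothetical stand-in for EarlyStopProtocol.EarlyStopException."""


def forward(layers: List[Callable[[Any], Any]], x: Any,
            hooks: Dict[int, Callable[[Any], None]]) -> Any:
    for i, layer in enumerate(layers):
        x = layer(x)
        if i in hooks:
            hooks[i](x)  # a hook deep inside the forward pass
    return x


def run_with_early_stop(layers: List[Callable[[Any], Any]], x: Any,
                        hooks: Dict[int, Callable[[Any], None]]) -> Any:
    try:
        return forward(layers, x, hooks)
    except EarlyStop:
        return None  # the model output was never produced; this is expected


def stop(_: Any) -> None:
    raise EarlyStop()  # unwinds out of forward() from inside a hook


saved: Dict[str, Any] = {}
layers = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 100]
hooks = {0: lambda out: saved.update(layer0=out), 1: stop}
print(run_with_early_stop(layers, 1, hooks), saved)  # None {'layer0': 2}
```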

exception EarlyStopException

classmethod add(graph: Graph, stop_point_node: Node | None = None) → InterventionProxy

Class method to be implemented in order to add a Node of this Protocol to a Graph.

classmethod execute(node: Node) → None

Class method to be implemented which contains the actual execution logic of the Protocol. By default, does nothing.

Parameters:

node (Node) – Node to execute using this Protocol’s execution logic.

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict

class nnsight.tracing.protocols.GradProtocol

Protocol which adds a backward hook to a Tensor via .register_hook(). When the hook executes, it injects the gradients into the node’s value. Nodes created via this protocol are relative to the next time .backward() is called during tracing, allowing separate .grads to reference separate backward passes.

Uses an attachment to store the number of times .backward() has been called during tracing, so that a given .grad hook only injects its value at the appropriate backward pass.
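A sketch of how a per-pass counter can bind each .grad request to a particular backward pass. GradTracker and its methods are hypothetical; strings stand in for gradient tensors, and no real hooks are registered:

```python
from typing import Any, Dict, List


class GradTracker:
    """Hypothetical sketch of binding .grad requests to backward passes."""

    def __init__(self) -> None:
        self.backward_idx = 0             # number of .backward() calls traced so far
        self.requests: List[int] = []     # backward pass each .grad request belongs to
        self.values: Dict[int, Any] = {}  # request id -> injected gradient

    def trace_grad(self) -> int:
        # A .grad traced now refers to the *next* backward pass.
        self.requests.append(self.backward_idx)
        return len(self.requests) - 1

    def trace_backward(self) -> None:
        self.backward_idx += 1  # analogous to GradProtocol.increment()

    def run_backward(self, pass_idx: int, grad: Any) -> None:
        # Hooks fire on every pass, but only inject on their own pass.
        for request_id, wanted in enumerate(self.requests):
            if wanted == pass_idx:
                self.values[request_id] = grad


tracker = GradTracker()
g0 = tracker.trace_grad()      # bound to the first backward pass
tracker.trace_backward()
g1 = tracker.trace_grad()      # bound to the second backward pass
tracker.trace_backward()

tracker.run_backward(0, "grad from pass 0")
tracker.run_backward(1, "grad from pass 1")
print(tracker.values[g0])  # grad from pass 0
print(tracker.values[g1])  # grad from pass 1
```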

classmethod add(node: Node) → InterventionProxy

Class method to be implemented in order to add a Node of this Protocol to a Graph.

classmethod execute(node: Node) → None

Class method to be implemented which contains the actual execution logic of the Protocol. By default, does nothing.

Parameters:

node (Node) – Node to execute using this Protocol’s execution logic.

classmethod increment(graph: Graph)

Increments the backward_idx attachment to track the number of times .backward() is called in tracing for this Graph.

Parameters:

graph (Graph) – Graph.

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict

class nnsight.tracing.protocols.LocalBackendExecuteProtocol

classmethod add(object: LocalMixin, graph: Graph) → InterventionProxy

Class method to be implemented in order to add a Node of this Protocol to a Graph.

classmethod execute(node: Node) → None

Class method to be implemented which contains the actual execution logic of the Protocol. By default, does nothing.

Parameters:

node (Node) – Node to execute using this Protocol’s execution logic.

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict

class nnsight.tracing.protocols.LockProtocol

Simple Protocol whose .execute() method does nothing. Because .set_value() is never called on the Node, the Node won’t be destroyed.

classmethod add(node: Node) → InterventionProxy

Class method to be implemented in order to add a Node of this Protocol to a Graph.

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict

class nnsight.tracing.protocols.Protocol

A Protocol represents some complex action for which a user might want to create a Node and Proxy and add them to a Graph. Unlike normal Node target execution, Protocols have access to the Node itself, and therefore to the Graph, which enables more powerful functionality than plain functions and methods provide.

classmethod add(*args, **kwargs) → InterventionProxy

Class method to be implemented in order to add a Node of this Protocol to a Graph.

classmethod execute(node: Node)

Class method to be implemented which contains the actual execution logic of the Protocol. By default, does nothing.

Parameters:

node (Node) – Node to execute using this Protocol’s execution logic.

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict
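The division between add(), execute(), and style() can be sketched with a small hypothetical protocol. MiniGraph, MiniNode, MiniProtocol, and MiniSwapProtocol are illustrative stand-ins (the last is only loosely modelled on SwapProtocol, documented below); the point is that a string target is dispatched to a Protocol that receives the whole Node and can therefore reach the Graph's attachments.

```python
from typing import Any, Dict, List


class MiniGraph:
    def __init__(self) -> None:
        self.nodes: List["MiniNode"] = []
        self.attachments: Dict[str, Any] = {}  # extra state used by protocols


class MiniNode:
    def __init__(self, graph: MiniGraph, target: Any, args: List[Any]) -> None:
        self.graph, self.target, self.args = graph, target, args
        self.value: Any = None
        graph.nodes.append(self)

    def execute(self) -> None:
        if isinstance(self.target, str):
            # A string target names a Protocol, which receives the whole Node.
            PROTOCOLS[self.target].execute(self)
        else:
            self.value = self.target(*self.args)


class MiniProtocol:
    """Hypothetical base class: subclasses override these classmethods."""

    @classmethod
    def add(cls, *args: Any, **kwargs: Any) -> MiniNode:
        raise NotImplementedError

    @classmethod
    def execute(cls, node: MiniNode) -> None:
        pass  # by default, does nothing

    @classmethod
    def style(cls) -> Dict[str, Any]:
        return {"shape": "box"}  # hypothetical visualization attributes


class MiniSwapProtocol(MiniProtocol):
    """Loosely modelled on SwapProtocol: stores a value in a Graph attachment."""

    @classmethod
    def add(cls, graph: MiniGraph, value: Any) -> MiniNode:
        return MiniNode(graph, "swap", [value])

    @classmethod
    def execute(cls, node: MiniNode) -> None:
        # Unlike a plain function target, a Protocol can reach the Graph via the Node.
        node.graph.attachments["swap"] = node.args[0]

    @classmethod
    def get_swap(cls, graph: MiniGraph, value: Any) -> Any:
        return graph.attachments.get("swap", value)  # swap if present, else default


PROTOCOLS = {"swap": MiniSwapProtocol}

graph = MiniGraph()
MiniSwapProtocol.add(graph, "patched output")
for node in graph.nodes:
    node.execute()
print(MiniSwapProtocol.get_swap(graph, "original output"))        # patched output
print(MiniSwapProtocol.get_swap(MiniGraph(), "original output"))  # original output
```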

class nnsight.tracing.protocols.SwapProtocol

Protocol which adds an attachment to the Graph which can store some value. Used to replace (‘swap’) a value with another value.

classmethod add(node: Node, value: Any) → InterventionProxy

Class method to be implemented in order to add a Node of this Protocol to a Graph.

classmethod execute(node: Node) → None

Class method to be implemented which contains the actual execution logic of the Protocol. By default, does nothing.

Parameters:

node (Node) – Node to execute using this Protocol’s execution logic.

classmethod get_swap(graph: Graph, value: Any) → Any

Checks if a swap exists on a Graph. If so, returns it; otherwise, returns the given value.

Parameters:
  • graph (Graph) – Graph

  • value (Any) – Default value.

Returns:

Default value or swap value.

Return type:

Any

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict

class nnsight.tracing.protocols.UpdateProtocol

Protocol to update the value of an InterventionProxy node.

classmethod add(node: Node, new_value: Node | Any) → InterventionProxy

Creates an UpdateProtocol node.

Parameters:
  • node (Node) – Original node.

  • new_value (Union[Node, Any]) – The update value.

Returns:

proxy.

Return type:

InterventionProxy

classmethod execute(node: Node) → None

Sets the value of the original node to the new value.

If the original node is defined outside the current context, the Bridge is used to get it.

Parameters:

node (Node) – UpdateProtocol node.

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict

class nnsight.tracing.protocols.ValueProtocol

classmethod add(graph: Graph, default: Any = None) → InterventionProxy

Class method to be implemented in order to add a Node of this Protocol to a Graph.

classmethod execute(node: Node) → None

Class method to be implemented which contains the actual execution logic of the Protocol. By default, does nothing.

Parameters:

node (Node) – Node to execute using this Protocol’s execution logic.

classmethod style() → Dict[str, Any]

Visualization style for this protocol node.

Returns:

Dictionary style.

Return type:

Dict

class nnsight.tracing.Bridge.Bridge

A Bridge object collects and tracks multiple Graphs in order to facilitate interaction between them. The order in which Graphs are added matters, as Graphs can only get values from previous Graphs.

id_to_graph

Mapping of graph id to Graph.

Type:

Dict[int, Graph]

graph_stack

Stack of visited Intervention Graphs.

Type:

List[Graph]

bridged_nodes

Mapping of bridged Nodes to the BridgeProtocol nodes representing them on different graphs.

Type:

defaultdict[Node, defaultdict[int, Optional[InterventionProxy]]]

locks

Count of how many entities depend on the ties between graphs not being released.

Type:

int

add(graph: Graph) → None

Adds Graph to Bridge.

Parameters:

graph (Graph) – Graph to add.

add_bridge_proxy(node: Node, bridge_proxy: Node) → None

Adds a BridgeProtocol Proxy to the bridged_nodes attribute.

Parameters:
  • node (Node) – Bridged Node.

  • bridge_proxy (Node) – BridgeProtocol node proxy corresponding to the bridged node.

get_bridge_proxy(node: Node, graph_id: int) → InterventionProxy | None

Checks if the argument Node is bridged within the specified graph and returns its corresponding BridgeProtocol node proxy.

Parameters:
  • node (Node) – Node.

  • graph_id (int) – Graph id.

Returns:

BridgeProtocol node proxy if it exists.

Return type:

Optional[InterventionProxy]

get_graph(id: int) → Graph

Returns graph from Bridge given the Graph’s id.

Parameters:

id (int) – Id of Graph to get.

Returns:

Graph.

Return type:

Graph

peek_graph() → Graph

Gets the current hierarchical Graph in the Bridge.

Returns:

Graph of current context.

Return type:

Graph

pop_graph() → None

Pops the last Graph in the graph stack.
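A minimal sketch of the Bridge's bookkeeping, assuming graphs are identified by integer ids and that add order is what determines which graphs may read from which. SketchGraph, SketchBridge, and can_read are hypothetical; the real Bridge also tracks bridged nodes and locks.

```python
from typing import Dict, List


class SketchGraph:
    def __init__(self, graph_id: int) -> None:
        self.id = graph_id


class SketchBridge:
    """Hypothetical sketch of Bridge's graph bookkeeping (not nnsight's code)."""

    def __init__(self) -> None:
        self.id_to_graph: Dict[int, SketchGraph] = {}  # insertion order = add order
        self.graph_stack: List[SketchGraph] = []

    def add(self, graph: SketchGraph) -> None:
        self.id_to_graph[graph.id] = graph
        self.graph_stack.append(graph)

    def get_graph(self, graph_id: int) -> SketchGraph:
        return self.id_to_graph[graph_id]

    def peek_graph(self) -> SketchGraph:
        return self.graph_stack[-1]  # graph of the current context

    def pop_graph(self) -> None:
        self.graph_stack.pop()

    def can_read(self, reader_id: int, source_id: int) -> bool:
        # Hypothetical check: a graph may only take values from earlier graphs.
        order = list(self.id_to_graph)
        return order.index(source_id) < order.index(reader_id)


bridge = SketchBridge()
outer, inner = SketchGraph(0), SketchGraph(1)
bridge.add(outer)
bridge.add(inner)
print(bridge.peek_graph().id)                        # 1
print(bridge.can_read(1, 0), bridge.can_read(0, 1))  # True False
bridge.pop_graph()
print(bridge.peek_graph().id)                        # 0
print(bridge.get_graph(1) is inner)                  # True
```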