nnsight.tracing#

The nnsight.tracing module involves tracking operations on a torch.nn.Module in order to form a computation graph.

The Graph class adds and stores each operation as a node. It has a ‘module’ node which acts as the root object on which the computation graph operates, and ‘argument’ nodes which act as entry points for data to flow into the graph.
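
A minimal sketch of constructing a Graph over a toy module (the Linear module and variable names are illustrative; any torch.nn.Module can act as the root object):

```python
import torch

from nnsight.tracing.Graph import Graph

# Illustrative root module.
module = torch.nn.Linear(4, 4)

# Build a graph rooted at the module. With validate=True, operations are
# executed on 'meta' proxy values as they are added, to catch shape errors
# and the like before the graph is ever run for real.
graph = Graph(module, validate=True)

# The root 'module' node is exposed through a proxy:
root = graph.module_proxy
```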

The Node class represents nodes in the graph. A Node both represents the operation (and the resulting output of said operation) being traced AND actually executes the operation when the graph is run on a model.

  • Nodes have a .proxy_value attribute that is the result of the tracing operation; it is a ‘meta’ tensor, allowing you to view the shape and datatype of the actual value that will be populated when the node’s operation is executed.

  • Nodes carry out their operation in .execute() where their arguments are pre-processed and their value is set in .set_value().

  • Arguments passed to a node may be other nodes, forming a bi-directional dependency graph. During execution pre-processing, arguments that are nodes are converted to their value.

  • Nodes are responsible for informing their listeners when one of their dependencies has completed, and that they should execute once all dependencies are completed. Similarly, nodes must inform their dependencies when one of their listeners has ceased “listening.” If a node has no remaining listeners, its value is destroyed by calling .destroy() in order to free memory. When re-executing the same graph (and therefore the same nodes), the remaining listeners and dependencies are reset on each node. A simplified sketch of this counting scheme follows.
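
A self-contained sketch of the counting scheme described above (this is an illustration, not nnsight's actual implementation):

```python
class ToyNode:
    # Simplified stand-in for Node: executes once all dependencies are done,
    # and frees its value once no listeners remain.

    def __init__(self, fn, dependencies=()):
        self.fn = fn
        self.dependencies = list(dependencies)
        self.listeners = []
        self.value = None
        for dependency in self.dependencies:
            dependency.listeners.append(self)

    def compile(self):
        # Reset counters so the same graph can be executed again.
        self.remaining_dependencies = len(self.dependencies)
        self.remaining_listeners = len(self.listeners)
        self.value = None

    def set_value(self, value):
        self.value = value
        # Tell listeners that one of their dependencies completed; a listener
        # with no remaining dependencies is ready to execute.
        for listener in self.listeners:
            listener.remaining_dependencies -= 1
            if listener.remaining_dependencies == 0:
                listener.execute()
        # Tell dependencies that one of their listeners is satisfied; a
        # dependency with no remaining listeners frees its value.
        for dependency in self.dependencies:
            dependency.remaining_listeners -= 1
            if dependency.remaining_listeners == 0:
                dependency.destroy()

    def execute(self):
        self.set_value(self.fn(*(d.value for d in self.dependencies)))

    def destroy(self):
        self.value = None


entry = ToyNode(lambda: 2)                  # plays the role of an 'argument' node
result = ToyNode(lambda x: x + 3, [entry])

for node in (entry, result):
    node.compile()

entry.execute()      # cascades: 'result' executes once its dependency is met
print(result.value)  # 5
print(entry.value)   # None -- destroyed once its only listener was satisfied
```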

Proxy class objects are the actual objects that interact with operations in order to update the graph and create new nodes. Each Node has its own Proxy object. The operations that are traceable on base Proxy objects include many Python built-in and magic methods, and __torch_function__ is implemented to trace torch operations. When an operation is traced, arguments are converted into their ‘meta’ tensor values and run through the operation in order to find out the shapes and data types of the result.
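
Continuing the Graph sketch above, a hedged example of how operations on proxies are traced (the 'hidden_states' name, the shapes, and the use of a meta tensor as the proxy value are illustrative assumptions):

```python
import torch

# An 'argument' node acts as an entry point; its first arg is the argument
# name and its value here is a meta tensor carrying only shape and dtype.
hidden = graph.add(
    target="argument",
    value=torch.empty(2, 4, device="meta"),
    args=["hidden_states"],
)

# Python magic methods and torch functions on the proxy are intercepted,
# adding new nodes whose shapes and dtypes are worked out on the meta values.
scaled = hidden * 2
pooled = torch.sum(scaled, dim=-1)

# pooled.node.proxy_value is a meta tensor of shape (2,)
```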

class nnsight.tracing.Graph.Graph(module: ~torch.nn.modules.module.Module, proxy_class: ~typing.Type[~nnsight.tracing.Proxy.Proxy] = <class 'nnsight.tracing.Proxy.Proxy'>, validate: bool = True)[source]#

Represents a computation graph involving a torch.nn.module.

Reserved target names:

  • ‘argument’ : There can be multiple argument nodes. Their first argument needs to be the argument name, which acts as a key in graph.argument_node_names mapping to a list of names of nodes that depend on its value. These nodes’ values need to be set outside of the computation graph; they are the entry points that kick off execution of the graph (see the sketch after this list).

  • ‘swap’ : Swap nodes indicate populating the graph’s swap attribute. When executed, a swap node’s value is not set; logic involving the swap value should set its value after using it.

  • ‘null’ : Null nodes never get executed and therefore their listeners never get destroyed.

  • ‘grad’ : Grad nodes indicate adding a .register_hook() to a tensor proxy.
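
Continuing the sketch from the module introduction, the argument name keys graph.argument_node_names:

```python
# Names of nodes that depend on the 'hidden_states' value; setting that value
# from outside the graph is what kicks off their execution.
dependent_names = graph.argument_node_names["hidden_states"]
```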

validate#

Whether to execute nodes as they are added, using their proxy values, in order to check if the executions are possible (e.g. shape errors). Defaults to True.

Type:

bool

proxy_class#

Proxy class to use. Defaults to Proxy.

Type:

Type[Proxy]

tracing#

Whether operations are currently being traced.

Type:

bool

nodes#

Mapping of node name to node.

Type:

Dict[str, Node]

name_idx#

Mapping of node target_name to number of previous names with the same target_name. Used so names are unique.

Type:

Dict[str, int]

module_proxy#

Proxy for given root meta module.

Type:

Proxy

argument_node_names#

Mapping of argument name to the names of nodes that depend on its value.

Type:

Dict[str, List[str]]

generation_idx#

Current generation index.

Type:

int

swap#

Attribute to store swap values from ‘swap’ nodes.

Type:

Node

add(target: Callable | str, value: Any, args: List[Any] | None = None, kwargs: Dict[str, Any] | None = None, name: str | None = None) Proxy[source]#

Adds a node to the graph and returns its proxy.

Parameters:
  • value (Any) – ‘meta’ proxy value used for tracing the shapes and values.

  • target (Union[Callable, str]) – Either the function to call for this node, or a string of a reserved target name.

  • args (List[Any], optional) – Positional arguments of node. Defaults to None.

  • kwargs (Dict[str, Any], optional) – Keyword arguments of node. Defaults to None.

  • name (str, optional) – Unique name of node. Otherwise the name is pulled from the target. Defaults to None.

Returns:

Proxy for the added node.

Return type:

Proxy

Raises:

ValueError – If more than one reserved “module” node is added to the graph.
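
A hedged sketch of add() with a callable target, reusing the graph and the hidden proxy from the module-level sketches (arguments that are nodes are resolved to their values at execution time):

```python
import torch

# At execution time torch.relu is called on the argument node's populated
# value; the meta value given here is only used for tracing/validation.
relu = graph.add(
    target=torch.relu,
    value=torch.empty(2, 4, device="meta"),
    args=[hidden.node],
)
```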

compile(module: Module) None[source]#

Re-compile graph to prepare for a new execution of the graph.

Compiles all nodes.

Finally, sets the “nnsight_root_module” node’s value to the module that is being interleaved.

Parameters:

module (torch.nn.Module) – Module to be considered the root module of the graph.
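
For example, before running the same graph again (a sketch reusing the module and graph from the earlier sketches):

```python
# Resets every node's counters and values, then points the root 'module'
# node at the module the graph will be interleaved with.
graph.compile(module)
```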

class nnsight.tracing.Node.Node(name: str, graph: Graph, value: Any, target: Callable | str, args: List[Any] = None, kwargs: Dict[str, Any] = None)[source]#

A Node represents some action that should be carried out during execution of a Graph.

name#

Unique name of node.

Type:

str

graph#

Reference to parent Graph object.

Type:

Graph

proxy_value#

Meta version of value. Used when graph has validate = True.

Type:

Any

target#

Function to execute or reserved string name.

Type:

Union[Callable, str]

args#

Positional arguments. Defaults to None.

Type:

List[Any], optional

kwargs#

Keyword arguments. Defaults to None.

Type:

Dict[str, Any], optional

listeners#

Nodes that depend on this node.

Type:

List[Node]

dependencies#

Nodes that this node depends on.

Type:

List[Node]

value#

Actual value to be populated during execution.

Type:

Any

add(target: Callable | str, value: Any, args: List[Any] | None = None, kwargs: Dict[str, Any] | None = None, name: str | None = None) Proxy[source]#

We use Node.add instead of Graph.add in case the weakref to the Graph is gone.

Returns:

Proxy

Return type:

Proxy

compile() None[source]#

Resets this Node’s remaining_listeners and remaining_dependencies and sets its value to None.

destroy() None[source]#

Removes the reference to the node’s value and logs its destruction.

done() bool[source]#

Returns true if the value of this node has been set.

Returns:

If done.

Return type:

bool

execute() None[source]#

Actually executes this node. If the target is ‘null’, does nothing. Otherwise prepares args and kwargs and passes them to the target.

fulfilled() bool[source]#

Returns true if remaining_dependencies is 0.

Returns:

If fulfilled.

Return type:

bool

is_tracing() bool[source]#

Checks whether the weakref to the Graph is still alive and the Graph is tracing.

Returns:

Is Graph tracing.

Return type:

bool

classmethod prepare_inputs(inputs: Any, device: device | None = None, proxy: bool = False) Any[source]#

Prepare arguments for executing this node’s target. Converts Nodes in args and kwargs to their value and moves tensors to the correct device.

Returns:

Prepared inputs.

Return type:

Any
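
A hedged sketch (some_node is a hypothetical Node whose value has already been set):

```python
import torch

from nnsight.tracing.Node import Node

# Nodes inside the inputs are replaced by their populated .value, and
# tensors are moved to the requested device.
prepared = Node.prepare_inputs(
    [some_node, torch.ones(2, 4)],
    device=torch.device("cpu"),
)
```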

redundant() bool[source]#

Returns true if remaining_listeners is 0.

Returns:

If redundant.

Return type:

bool

set_value(value: Any)[source]#

Sets the value of this Node and logs the event. Updates remaining_dependencies of listeners. If they are now fulfilled, execute them. Updates remaining_listeners of dependencies. If they are now redundant, destroy them.

Parameters:

value (Any) – Value.

class nnsight.tracing.Proxy.Proxy(node: Node)[source]#

Proxy objects are the actual objects that interact with operations in order to update the graph to create new nodes.

The operations that are traceable on base Proxy objects are many python built-in and magic methods, as well as implementing __torch_function__ to trace torch operations.

node#

This proxy’s node.

Type:

Node

property value: Any#

Property to return the value of this proxy’s node.

Returns:

The stored value of the proxy, populated during execution of the model.

Return type:

Any
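
For instance, reusing the pooled proxy from the module-level sketch (accessing .value is only meaningful once the node has been executed):

```python
# done() reports whether the node's value has been set during execution.
if pooled.node.done():
    result = pooled.value
```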

nnsight.tracing.Proxy.proxy_wrapper(fn) Callable[source]#

Wraps problematic functions (sometimes torch functions). Checks if any of its args are Proxies; if so, a proxy of the function call is returned, otherwise the function is simply run.

Parameters:

fn (function) – Function to wrap.

Returns:

Wrapped function.

Return type:

function
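
A hedged usage sketch (assuming torch.nn.functional.softmax is one such problematic function, and reusing the hidden proxy from the module-level sketch):

```python
import torch

from nnsight.tracing.Proxy import proxy_wrapper

softmax = proxy_wrapper(torch.nn.functional.softmax)

# With a Proxy argument the call is traced and a Proxy is returned;
# with plain tensors the original function simply runs.
traced = softmax(hidden, dim=-1)
eager = softmax(torch.ones(2, 4), dim=-1)
```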