nnsight.tracing
The nnsight.tracing module involves tracing operations in order to form a computation graph.
The Graph class adds and stores operations as Nodes.
The Node class represents an individual operation in the Graph.
The Proxy class handles interactions from the user in order to create new Nodes. There is a Proxy for each Node.
- class nnsight.tracing.Graph.Graph(proxy_class: Type[nnsight.tracing.Proxy.Proxy] = nnsight.tracing.Proxy.Proxy, validate: bool = False, sequential: bool = True, graph_id: int = None)
Represents a computation graph composed of Nodes.
- nodes
Mapping of Node name to node. Order is preserved and important when executing the graph sequentially.
- Type:
Dict[str, Node]
- attachments
Dictionary object used to add extra functionality to this Graph. Used by Protocols.
- Type:
Dict[str, Any]
- proxy_class
Proxy class to use. Defaults to nnsight.tracing.Proxy.Proxy.
- Type:
Type[Proxy]
- alive
Whether this Graph should be considered alive (still tracing) and can therefore still be added to. Used by Nodes.
- Type:
bool
- name_idx
Mapping of a Node's target name to the number of previous Nodes with the same target name. Used to make names unique.
- Type:
Dict[str, int]
- validate
Whether to execute Nodes as they are added, using their arguments' proxy values, in order to check that the operations are possible and to create each new Node's proxy_value. Defaults to False.
When Nodes are added to the Graph and the Graph's validate attribute is set to True, the Graph executes each Node's target on its arguments' .proxy_value attributes (essentially executing the Node with FakeTensors in FakeTensorMode). This 1) checks whether the operation is valid for the tensor shapes within the .proxy_values (for example, it would catch an indexing error) and 2) populates the new Node's .proxy_value attribute with the result.
- Type:
bool
- sequential
Whether to run Nodes sequentially when executing this Graph.
When this is set to True, calling .execute() attempts to execute Nodes in the order they were added to the Graph. Otherwise, every Node is checked to see whether it is fulfilled (has no remaining dependencies). These root Nodes are then executed in the order they were added.
- Type:
bool
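As a rough illustration of the name_idx bookkeeping, the sketch below derives unique Node names from a per-target counter. The unique_name helper and the "<target>_<count>" naming format are hypothetical choices for this example, not nnsight's actual implementation.

```python
from typing import Dict


def unique_name(name_idx: Dict[str, int], target_name: str) -> str:
    """Hypothetical helper: build a unique name from a per-target counter.

    name_idx plays the role of Graph.name_idx, mapping each target name
    to how many Nodes have already used it.
    """
    count = name_idx.get(target_name, 0)
    name_idx[target_name] = count + 1
    return f"{target_name}_{count}"


name_idx: Dict[str, int] = {}
names = [unique_name(name_idx, t) for t in ["add", "add", "getitem", "add"]]
print(names)     # ['add_0', 'add_1', 'getitem_0', 'add_2']
print(name_idx)  # {'add': 3, 'getitem': 1}
```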
- add(node: Node) → None
Adds a Node to this Graph. Called by Nodes on __init__.
When Nodes are added to the Graph and the Graph's validate attribute is set to True, the Graph executes each Node's target on its arguments' .proxy_value attributes (essentially executing the Node with FakeTensors in FakeTensorMode). This 1) checks whether the operation is valid for the tensor shapes within the .proxy_values (for example, it would catch an indexing error) and 2) populates the new Node's .proxy_value attribute with the result.
- Parameters:
node (Node) – Node to add.
- create(*args, **kwargs) → Proxy
Creates a Node directly on this Graph and returns its Proxy.
- Returns:
Proxy for newly created Node.
- Return type:
Proxy
- execute() → None
Executes the operations of the Graph.
If Graph.sequential is True, executes all Nodes sequentially. Otherwise, executes only root Nodes sequentially.
- reset() → None
Resets the Graph to prepare for a new execution of the Graph. Calls .reset() on all Nodes.
- vis(title: str = 'graph', path: str = '.', display: bool = True, save: bool = False, recursive: bool = False)
Generates and saves a graphical visualization of the Intervention Graph using the pygraphviz library.
- Parameters:
title (str) – Name of the Intervention Graph. Defaults to "graph".
path (str) – Directory path to save the graphic in. If None, saves content to the current directory.
display (bool) – If True, shows the graph image.
save (bool) – If True, saves the graph to the specified path.
recursive (bool) – If True, recursively visualizes sub-graphs.
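The difference between the two execution modes can be sketched with a small standalone toy. ToyNode and ToyGraph are hypothetical stand-ins, not nnsight classes, and the rule that listeners are only pushed in non-sequential mode is a simplification chosen here so that the two orderings differ visibly; it is not a claim about how nnsight combines the two mechanisms.

```python
from typing import Any, Callable, Dict, List


class ToyNode:
    """Hypothetical Node stand-in: a function plus Node or literal arguments."""

    def __init__(self, graph: "ToyGraph", name: str, fn: Callable, *args: Any) -> None:
        self.graph, self.name, self.fn, self.args = graph, name, fn, args
        self.value: Any = None
        self.done = False
        self.listeners: List["ToyNode"] = []
        deps = [a for a in args if isinstance(a, ToyNode)]
        for dep in deps:
            dep.listeners.append(self)
        self.remaining_dependencies = len(deps)
        graph.nodes[name] = self  # dicts preserve insertion order

    def fulfilled(self) -> bool:
        return self.remaining_dependencies == 0

    def execute(self) -> None:
        # Swap Node arguments for their computed values, then run fn.
        inputs = [a.value if isinstance(a, ToyNode) else a for a in self.args]
        self.graph.order.append(self.name)
        self.set_value(self.fn(*inputs))

    def set_value(self, value: Any) -> None:
        self.value, self.done = value, True
        for listener in self.listeners:
            listener.remaining_dependencies -= 1
            # Toy simplification: listeners are only pushed in root mode.
            if not self.graph.sequential and listener.fulfilled():
                listener.execute()


class ToyGraph:
    def __init__(self, sequential: bool) -> None:
        self.sequential = sequential
        self.nodes: Dict[str, ToyNode] = {}
        self.order: List[str] = []

    def execute(self) -> None:
        if self.sequential:
            # Visit every Node in the order it was added.
            for node in self.nodes.values():
                if not node.done:
                    node.execute()
        else:
            # Start from root Nodes only; the rest run as inputs complete.
            roots = [n for n in self.nodes.values() if n.fulfilled()]
            for node in roots:
                node.execute()


def build(sequential: bool) -> ToyGraph:
    g = ToyGraph(sequential)
    a = ToyNode(g, "a", lambda: 2)
    b = ToyNode(g, "b", lambda: 3)
    c = ToyNode(g, "c", lambda x: x * 10, a)
    ToyNode(g, "d", lambda x, y: x + y, b, c)
    g.execute()
    return g


print(build(sequential=True).order)   # ['a', 'b', 'c', 'd']
print(build(sequential=False).order)  # ['a', 'c', 'b', 'd']
```

In root mode, Node c runs as soon as its only dependency a finishes, before b is reached; in sequential mode, Nodes simply run in insertion order. Both modes compute the same final value for d.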
- class nnsight.tracing.Node.Node(target: Union[Callable, str], graph: Graph = None, proxy_value: Any, args: List[Any] = None, kwargs: Dict[str, Any] = None, name: str = None)
A Node represents some action that should be carried out during execution of a Graph.
The class represents both the operations being traced (and the outputs of those operations) and the objects that actually execute the operations when the Graph is executed. The Nodes you create while tracing are the same objects as the ones that are executed.
Nodes have a .proxy_value attribute that results from the tracing operation. These values are FakeTensors, which let you view the shape and data type of the actual value that will be populated when the Node's operation is executed.
Nodes carry out their operation in .execute(), where their arguments are pre-processed, and their value is set in .set_value().
Arguments passed to a Node can be other Nodes, forming a bidirectional dependency graph. During execution pre-processing, arguments that are Nodes are converted to their values.
Nodes are responsible for informing their listeners that one of their dependencies has completed and, once all of a listener's dependencies are complete, that the listener should execute. Similarly, Nodes must inform their dependencies when one of their listeners has ceased "listening." If a Node has no remaining listeners, its value is destroyed by calling .destroy() in order to free memory. When re-executing the same Graph, and therefore the same Nodes, the remaining listeners and dependencies are reset on each Node.
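The listener bookkeeping described above can be illustrated with a minimal sketch. ToyNode is hypothetical: it tracks only remaining_listeners, its Nodes are run by hand, and it uses None as the destroyed marker, whereas the real Node also tracks dependencies for execution and uses inspect._empty as its empty marker (see the value property).

```python
from typing import Any, List


class ToyNode:
    """Hypothetical Node-like object that frees its value once unneeded."""

    def __init__(self, name: str, *dependencies: "ToyNode") -> None:
        self.name = name
        self.dependencies = list(dependencies)
        self.listeners: List["ToyNode"] = []
        for dep in dependencies:
            dep.listeners.append(self)
        self.value: Any = None
        self.remaining_listeners = 0

    def reset(self) -> None:
        # Re-arm the listener count before each (re-)execution.
        self.remaining_listeners = len(self.listeners)

    def set_value(self, value: Any) -> None:
        self.value = value
        # This Node has consumed its inputs: tell each dependency.
        for dep in self.dependencies:
            dep.remaining_listeners -= 1
            if dep.remaining_listeners == 0:  # the dependency is now redundant
                dep.destroy()

    def destroy(self) -> None:
        # Drop the reference so the value's memory can be reclaimed.
        self.value = None


x = ToyNode("x")
y = ToyNode("y", x)
z = ToyNode("z", x, y)
for node in (x, y, z):
    node.reset()

x.set_value(4)
y.set_value(x.value + 1)        # x still has one listener (z) left
print(x.value)                  # 4
z.set_value(x.value * y.value)  # z was the last listener of x and y
print(x.value, y.value, z.value)  # None None 20
```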
- name
Unique name of node.
- Type:
str
- proxy_value
FakeTensor version of value. Used when the Graph has validate = True to check whether Node actions are possible.
- Type:
Any
- target
Function to execute or name of Protocol.
- Type:
Union[Callable, str]
- args
Positional arguments. Defaults to None.
- Type:
List[Any], optional
- kwargs
Keyword arguments. Defaults to None.
- Type:
Dict[str, Any], optional
- cond_dependency
ConditionalProtocol node if this node was defined within a Conditional context.
- Type:
Optional[Node]
- value
Actual value to be populated during execution.
- Type:
Any
- attached() → bool
Checks to see if the weakref to the Graph is alive or dead.
- Returns:
Is Node attached.
- Return type:
bool
- create(target: Callable | str, proxy_value: Any, args: List[Any] = None, kwargs: Dict[str, Any] = None, name: str = None) → Proxy | Any
Node.add is used instead of Graph.add in case the Graph is dead. If the Graph is dead, this Node is assumed to be ready to execute, so an attempt is made to execute it and its value is returned.
- Returns:
Proxy or value
- Return type:
Union[Proxy, Any]
- done() → bool
Returns true if the value of this node has been set.
- Returns:
If done.
- Return type:
bool
- execute() → None
Executes this Node. If target is a str, the corresponding Protocol handles execution. Otherwise, args and kwargs are prepared and passed to target, and the Node's value is set to the output of target.
- executed() → bool
Returns true if remaining_dependencies is less than 0.
- Returns:
If executed.
- Return type:
bool
- fulfilled() → bool
Returns true if remaining_dependencies is 0.
- Returns:
If fulfilled.
- Return type:
bool
- classmethod prepare_inputs(inputs: Any, device: device = None, proxy: bool = False) → Any
Prepares arguments for executing this Node's target. Converts Nodes in args and kwargs to their values and moves tensors to the correct device.
- Returns:
Prepared inputs.
- Return type:
Any
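The core of this conversion, walking nested args and kwargs and swapping each Node for its value, might look like the following sketch. DoneNode and this prepare_inputs function are illustrative only, and the device handling described above (moving tensors with something like tensor.to(device)) is deliberately left out so the example needs no PyTorch install.

```python
from typing import Any


class DoneNode:
    """Hypothetical placeholder for a Node whose value is already computed."""

    def __init__(self, value: Any) -> None:
        self.value = value


def prepare_inputs(inputs: Any) -> Any:
    """Recursively replace DoneNode objects inside nested containers."""
    if isinstance(inputs, DoneNode):
        return inputs.value
    if isinstance(inputs, dict):
        return {key: prepare_inputs(val) for key, val in inputs.items()}
    if isinstance(inputs, list):
        return [prepare_inputs(val) for val in inputs]
    if isinstance(inputs, tuple):
        return tuple(prepare_inputs(val) for val in inputs)
    return inputs  # literals pass through unchanged


args = [DoneNode(2), (DoneNode(3), "x")]
kwargs = {"scale": DoneNode(0.5), "bias": 1}
print(prepare_inputs(args))    # [2, (3, 'x')]
print(prepare_inputs(kwargs))  # {'scale': 0.5, 'bias': 1}
```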
- redundant() → bool
Returns true if remaining_listeners is 0.
- Returns:
If redundant.
- Return type:
bool
- set_value(value: Any) → None
Sets the value of this Node and logs the event. Updates remaining_dependencies of listeners and executes any that are now fulfilled. Updates remaining_listeners of dependencies and destroys any that are now redundant.
- Parameters:
value (Any) – Value.
- property value: Any
Property to return the value of this node.
- Returns:
The stored value of the node, populated during execution of the model.
- Return type:
Any
- Raises:
ValueError – If the underlying ._value is inspect._empty (therefore never set or destroyed).
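The behavior of the value property can be approximated with a sentinel-guarded attribute, as in this sketch. ValueSlot and its error message are hypothetical; the point shown is why a dedicated sentinel such as inspect._empty is used rather than None, which could be a legitimate Node value.

```python
import inspect
from typing import Any


class ValueSlot:
    """Hypothetical sketch of a Node value slot guarded by inspect._empty."""

    def __init__(self) -> None:
        self._value: Any = inspect._empty

    @property
    def value(self) -> Any:
        if self._value is inspect._empty:
            raise ValueError("value has not been set or was destroyed")
        return self._value

    def set_value(self, value: Any) -> None:
        self._value = value

    def done(self) -> bool:
        return self._value is not inspect._empty

    def destroy(self) -> None:
        # Return to the sentinel so the old value can be garbage collected.
        self._value = inspect._empty


slot = ValueSlot()
slot.set_value(None)            # None is a real value, distinct from "unset"
print(slot.done(), slot.value)  # True None
slot.destroy()
try:
    slot.value
except ValueError as err:
    print("error:", err)        # error: value has not been set or was destroyed
```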
- visualize(viz_graph: AGraph, recursive: bool, backend_name: str = '') → str
Adds this Node to the visualization graph, recursively visualizes its arguments, and adds edges between them.
- Parameters:
viz_graph (AGraph) – Visualization graph.
recursive (bool) – If True, recursively visualizes all sub-graphs.
backend_name (str) – Name inherited from the parent graph, used for unique differentiation in recursive visualization.
- Returns:
name of this node.
- Return type:
str
- class nnsight.tracing.Proxy.Proxy(node: Node)
Proxy objects are the actual objects that interact with operations in order to update the graph and create new Nodes.
The operations traceable on base Proxy objects include many Python built-in and magic methods; Proxy also implements __torch_function__ to trace torch operations.
- property value: Any
Property to return the value of this proxy’s node.
- Returns:
The stored value of the proxy, populated during execution of the model.
- Return type:
Any
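The tracing mechanism is roughly the following: each overloaded operator on a proxy records a new Node rather than computing a result. ToyGraph and ToyProxy are hypothetical and much simpler than nnsight's classes (nodes are plain tuples, arguments are referenced by name, and __torch_function__ is not implemented).

```python
from typing import Any, List, Tuple


class ToyGraph:
    """Hypothetical graph that records (name, op, args) for each operation."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, str, tuple]] = []

    def create(self, op: str, *args: Any) -> "ToyProxy":
        name = f"{op}_{len(self.nodes)}"
        self.nodes.append((name, op, args))
        return ToyProxy(self, name)


class ToyProxy:
    """Operations on a ToyProxy add Nodes to the graph instead of computing."""

    def __init__(self, graph: ToyGraph, name: str) -> None:
        self.graph = graph
        self.name = name

    @staticmethod
    def _ref(other: Any) -> Any:
        # Refer to other proxies by Node name; keep literals as-is.
        return other.name if isinstance(other, ToyProxy) else other

    def __add__(self, other: Any) -> "ToyProxy":
        return self.graph.create("add", self.name, self._ref(other))

    def __mul__(self, other: Any) -> "ToyProxy":
        return self.graph.create("mul", self.name, self._ref(other))

    def __getitem__(self, key: Any) -> "ToyProxy":
        return self.graph.create("getitem", self.name, key)


graph = ToyGraph()
x = graph.create("input")
z = ((x + 1) * x)[0]
for node in graph.nodes:
    print(node)
# ('input_0', 'input', ())
# ('add_1', 'add', ('input_0', 1))
# ('mul_2', 'mul', ('add_1', 'input_0'))
# ('getitem_3', 'getitem', ('mul_2', 0))
```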
- nnsight.tracing.Proxy.proxy_wrapper(fn) → Callable
Wraps problematic functions (sometimes torch functions). Checks whether any of the arguments are Proxies. If so, returns a Proxy of the function call. Otherwise, just runs the function.
- Parameters:
fn (function) – Function to wrap.
- Returns:
Wrapped function.
- Return type:
function
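A decorator in the spirit of proxy_wrapper can be sketched as follows. ToyProxy, the shared log list, and the string labels are hypothetical simplifications; a real implementation would create a Node on the proxy's Graph rather than appending to a list.

```python
import functools
from typing import Any, Callable, List


class ToyProxy:
    """Hypothetical proxy: deferred calls are recorded in a shared log."""

    def __init__(self, log: List[str], label: str) -> None:
        self.log = log
        self.label = label


def _label(arg: Any) -> str:
    return arg.label if isinstance(arg, ToyProxy) else repr(arg)


def proxy_wrapper(fn: Callable) -> Callable:
    """Defer fn when any argument is a ToyProxy; otherwise just call it."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        proxies = [a for a in (*args, *kwargs.values()) if isinstance(a, ToyProxy)]
        if not proxies:
            return fn(*args, **kwargs)  # no proxies: run eagerly
        label = f"{fn.__name__}({', '.join(_label(a) for a in args)})"
        proxies[0].log.append(label)    # record the call instead of running it
        return ToyProxy(proxies[0].log, label)

    return wrapper


@proxy_wrapper
def clamp(x: Any, lo: Any, hi: Any) -> Any:
    return max(lo, min(x, hi))


print(clamp(15, 0, 10))     # 10
log: List[str] = []
deferred = clamp(ToyProxy(log, "x"), 0, 10)
print(deferred.label, log)  # clamp(x, 0, 10) ['clamp(x, 0, 10)']
```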
- class nnsight.tracing.protocols.ApplyModuleProtocol
Protocol that references some root model and calls the .forward() method of a module within it, given some input. Calling .forward() instead of .__call__() means hooks are not triggered. Uses an attachment on the Graph to store the model.
- classmethod add(graph: Graph, module_path: str, *args, hook=False, **kwargs) → InterventionProxy
Creates and adds an ApplyModuleProtocol to the Graph. Assumes the attachment has already been added via ApplyModuleProtocol.set_module().
- Parameters:
graph (Graph) – Graph to add the Protocol to.
module_path (str) – Module path (model.module1.module2, etc.) of the module to apply, starting from the root module.
- Returns:
ApplyModule Proxy.
- Return type:
InterventionProxy
- classmethod execute(node: Node) → None
Executes the ApplyModuleProtocol on Node.
- Parameters:
node (Node) – ApplyModule Node.
- classmethod get_module(graph: Graph) → Module
Returns the nnsight root module from an attachment on a Graph.
- Parameters:
graph (Graph) – Graph
- Returns:
Root Module.
- Return type:
torch.nn.Module
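Resolving a module_path such as model.module1.module2 against the root module is essentially a chain of attribute lookups, as in this sketch. resolve_module is a hypothetical helper, and treating the leading "model" component as the root itself is an assumption; SimpleNamespace objects stand in for torch modules so the example runs without PyTorch.

```python
import functools
from types import SimpleNamespace
from typing import Any


def resolve_module(root: Any, module_path: str) -> Any:
    """Hypothetical helper: resolve 'model.a.b' by walking attributes.

    Assumes the first path component names the root itself, so it is skipped.
    """
    parts = module_path.split(".")[1:]
    return functools.reduce(getattr, parts, root)


# Stand-in for a model with a nested module exposing forward().
root = SimpleNamespace(
    encoder=SimpleNamespace(proj=SimpleNamespace(forward=lambda x: x * 2))
)
print(resolve_module(root, "model.encoder.proj").forward(21))  # 42
print(resolve_module(root, "model") is root)                    # True
```

With a real torch.nn.Module, Module.get_submodule("encoder.proj") performs a similar lookup, and invoking module.forward(x) directly skips the forward hooks that module(x) would run, which is the distinction described above.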
- class nnsight.tracing.protocols.BridgeProtocol
Protocol to connect two Graphs by grabbing a value from one and injecting it into another. Uses an attachment to store a Bridge object which references all relevant Graphs and their ordering.
- classmethod add(node: Node) → InterventionProxy
Class method to be implemented in order to add a Node of this Protocol to a Graph.
- classmethod execute(node: Node) → None
Class method to be implemented which contains the actual execution logic of the Protocol. By default, does nothing.
- Parameters:
node (Node) – Node to execute using this Protocol's execution logic.
- classmethod get_bridge(graph: Graph) → Bridge
Gets the Bridge object from a Graph. Assumes the Bridge has been set as an attachment on this Graph via BridgeProtocol.set_bridge().
- classmethod has_bridge(graph: Graph) → bool
Checks to see if a Bridge was added as an attachment on this Graph via BridgeProtocol.set_bridge().
- Parameters:
graph (Graph) – Graph
- Returns:
If Graph has Bridge attachment.
- Return type:
bool
- classmethod peek_graph(graph: Graph) → Graph
Returns the current Intervention Graph.
- Parameters:
graph (Graph) – Graph.
- Returns:
Graph.
- Return type:
Graph
- class nnsight.tracing.protocols.ConditionalProtocol
Protocol operating as a conditional statement. Uses the ConditionalManager attachment to handle all visited Conditional contexts within a single Intervention Graph. Evaluates the condition value of the Conditional as a boolean.
Example
- Setup:
Ex 1: The .save() on the model output will only be executed if the condition (x > 0) is evaluated to True.
Ex 2: The condition is a tensor boolean operation on the Envoy’s output InterventionProxy.
- classmethod add(graph: Graph, condition: Node | Any) → InterventionProxy
Class method to be implemented in order to add a Node of this Protocol to a Graph.
- classmethod add_conditioned_node(node: Node) → None
Adds a conditioned Node to the ConditionalManager attached to its Graph.
- Parameters:
node (Node) – Conditioned Node.
- classmethod execute(node: Node) → None
Evaluates the Node's condition to a boolean.
- Parameters:
node (Node) – ConditionalProtocol node.
- classmethod get_conditional(graph: Graph, cond_node_name: str) → Conditional
Gets the ConditionalProtocol node by its name.
- classmethod has_conditional(graph: Graph) → bool
Checks if the Intervention Graph has a ConditionalManager attached to it.
- Parameters:
graph (Graph) – Intervention Graph.
- Returns:
If graph has a ConditionalManager attachment.
- Return type:
bool
- classmethod is_node_conditioned(node: Node) → bool
Checks if the Node is conditioned by the current Conditional context.
- Parameters:
node (Node) – Conditioned Node.
- Returns:
Whether the Node is conditioned.
- Return type:
bool
- classmethod peek_conditional(graph: Graph) → Node
Gets the ConditionalProtocol node of the current Conditional context.
- Parameters:
graph (Graph) – Graph.
- Returns:
ConditionalProtocol of the current Conditional context.
- Return type:
Node
- classmethod pop_conditional(graph: Graph) → None
Pops the latest ConditionalProtocol from the ConditionalManager attached to the Graph.
- Parameters:
graph (Graph) – Intervention Graph.
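Tracking which Conditional context a Node was defined in amounts to a stack of active conditions, as this sketch shows. ToyConditionalManager, its cond() context manager, and the string condition names are hypothetical; in nnsight the corresponding entries would be ConditionalProtocol Nodes (see peek_conditional), and the recorded condition corresponds to a Node's cond_dependency.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional


class ToyConditionalManager:
    """Hypothetical stand-in for the ConditionalManager attachment."""

    def __init__(self) -> None:
        self.stack: List[str] = []                       # active conditions
        self.conditioned: Dict[str, Optional[str]] = {}  # node -> condition

    def peek(self) -> Optional[str]:
        # Innermost active condition, like peek_conditional().
        return self.stack[-1] if self.stack else None

    @contextmanager
    def cond(self, cond_name: str) -> Iterator[None]:
        self.stack.append(cond_name)  # entering a Conditional context
        try:
            yield
        finally:
            self.stack.pop()          # leaving it, like pop_conditional()

    def add_node(self, node_name: str) -> None:
        # Record which condition (if any) the new node depends on.
        self.conditioned[node_name] = self.peek()


mgr = ToyConditionalManager()
mgr.add_node("a")
with mgr.cond("cond_0"):
    mgr.add_node("b")
    with mgr.cond("cond_1"):
        mgr.add_node("c")
    mgr.add_node("d")
mgr.add_node("e")
print(mgr.conditioned)
# {'a': None, 'b': 'cond_0', 'c': 'cond_1', 'd': 'cond_0', 'e': None}
```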
- class nnsight.tracing.protocols.EarlyStopProtocol
Protocol to stop the execution of a model early.
- classmethod add(graph: Graph, stop_point_node: Node | None = None) → InterventionProxy
Class method to be implemented in order to add a Node of this Protocol to a Graph.
- class nnsight.tracing.protocols.GradProtocol
Protocol which adds a backward hook to a Tensor via .register_hook(). When the hook executes, it injects the gradient into the Node's value. Nodes created via this Protocol are relative to the next time .backward() is called during tracing, which allows separate .grad accesses to reference separate backward passes.
Uses an attachment to store the number of times .backward() has been called during tracing, so that a given .grad hook only injects its value at the appropriate backward pass.
- classmethod add(node: Node) → InterventionProxy
Class method to be implemented in order to add a Node of this Protocol to a Graph.
- classmethod execute(node: Node) → None
Class method to be implemented which contains the actual execution logic of the Protocol. By default, does nothing.
- Parameters:
node (Node) – Node to execute using this Protocol's execution logic.
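The pass-counting idea can be illustrated without PyTorch. In this hypothetical sketch, ToyGradTracker plays the role of the attachment that counts .backward() calls during tracing, and plain callbacks stand in for the hooks that would be registered with torch.Tensor.register_hook(); each request fires only on the backward pass it was traced against.

```python
from typing import Callable, Dict, List, Tuple


class ToyGradTracker:
    """Hypothetical sketch of tying .grad requests to backward passes."""

    def __init__(self) -> None:
        self.backward_calls = 0  # .backward() calls seen so far while tracing
        self.hooks: List[Tuple[int, Callable[[float], None]]] = []

    def request_grad(self, callback: Callable[[float], None]) -> None:
        # A request refers to the next backward pass at this point in tracing.
        self.hooks.append((self.backward_calls, callback))

    def trace_backward(self) -> None:
        self.backward_calls += 1

    def run_backward(self, pass_index: int, grad: float) -> None:
        # Simulate backward pass number pass_index firing its hooks.
        for target_pass, callback in self.hooks:
            if target_pass == pass_index:
                callback(grad)


grads: Dict[str, float] = {}
tracker = ToyGradTracker()
tracker.request_grad(lambda g: grads.__setitem__("first", g))
tracker.trace_backward()
tracker.request_grad(lambda g: grads.__setitem__("second", g))
tracker.trace_backward()

tracker.run_backward(0, 1.5)
tracker.run_backward(1, -0.25)
print(grads)  # {'first': 1.5, 'second': -0.25}
```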
- class nnsight.tracing.protocols.LocalBackendExecuteProtocol
- classmethod add(object: LocalMixin, graph: Graph) → InterventionProxy
Class method to be implemented in order to add a Node of this Protocol to a Graph.
- class nnsight.tracing.protocols.LockProtocol
Simple Protocol whose .execute() method does nothing. Because .set_value() is never called on the Node, the Node is never destroyed.
- classmethod add(node: Node) → InterventionProxy
Class method to be implemented in order to add a Node of this Protocol to a Graph.
- class nnsight.tracing.protocols.Protocol
A Protocol represents some complex action for which a user might want to create a Node and Proxy and add them to a Graph. Unlike normal Node target execution, Protocols have access to the Node itself and therefore the Graph, enabling more powerful functionality than plain functions and methods.
- classmethod add(*args, **kwargs) → InterventionProxy
Class method to be implemented in order to add a Node of this Protocol to a Graph.
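The split between plain function targets and Protocol targets can be sketched as a name-based dispatch. ToyNode, ToyProtocol, the PROTOCOLS registry, and the CountNodes protocol are hypothetical; Lock mimics the LockProtocol behavior described above by inheriting a do-nothing execute().

```python
from typing import Any, Callable, Dict, List, Type, Union


class ToyNode:
    """Hypothetical Node: target is a function or the name of a protocol."""

    def __init__(self, graph: List["ToyNode"], target: Union[Callable, str], *args: Any) -> None:
        self.graph, self.target, self.args = graph, target, args
        self.value: Any = None
        self.has_value = False
        graph.append(self)

    def execute(self) -> None:
        if isinstance(self.target, str):
            # Protocols get the Node itself, and through it the graph.
            PROTOCOLS[self.target].execute(self)
        else:
            self.set_value(self.target(*self.args))

    def set_value(self, value: Any) -> None:
        self.value, self.has_value = value, True


class ToyProtocol:
    @classmethod
    def add(cls, graph: List[ToyNode], *args: Any) -> ToyNode:
        return ToyNode(graph, cls.__name__, *args)

    @classmethod
    def execute(cls, node: ToyNode) -> None:
        pass  # by default, does nothing


class CountNodes(ToyProtocol):
    """Uses graph access, which a plain function target would not have."""

    @classmethod
    def execute(cls, node: ToyNode) -> None:
        node.set_value(len(node.graph))


class Lock(ToyProtocol):
    """Like LockProtocol: the inherited execute() never calls set_value()."""


PROTOCOLS: Dict[str, Type[ToyProtocol]] = {p.__name__: p for p in (CountNodes, Lock)}

graph: List[ToyNode] = []
plain = ToyNode(graph, lambda a, b: a + b, 1, 2)
count = CountNodes.add(graph)
lock = Lock.add(graph)
for node in graph:
    node.execute()
print(plain.value, count.value, lock.has_value)  # 3 3 False
```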
- class nnsight.tracing.protocols.SwapProtocol
Protocol which adds an attachment to the Graph which can store some value. Used to replace (‘swap’) a value with another value.
- classmethod add(node: Node, value: Any) → InterventionProxy
Class method to be implemented in order to add a Node of this Protocol to a Graph.
- classmethod execute(node: Node) → None
Class method to be implemented which contains the actual execution logic of the Protocol. By default, does nothing.
- Parameters:
node (Node) – Node to execute using this Protocol's execution logic.
- class nnsight.tracing.protocols.UpdateProtocol
Protocol to update the value of an InterventionProxy node.
- classmethod add(node: Node, new_value: Node | Any) → InterventionProxy
Creates an UpdateProtocol node.
- Parameters:
node (Node) – Node whose value is updated.
new_value (Node | Any) – New value.
- Returns:
Proxy.
- Return type:
InterventionProxy
- class nnsight.tracing.protocols.ValueProtocol
- classmethod add(graph: Graph, default: Any = None) → InterventionProxy
Class method to be implemented in order to add a Node of this Protocol to a Graph.
- class nnsight.tracing.Bridge.Bridge
A Bridge object collects and tracks multiple Graphs in order to facilitate interaction between them. The order in which Graphs are added matters, as Graphs can only get values from previous Graphs.
- bridged_nodes
Mapping of bridged Nodes to the BridgeProtocol nodes representing them on different graphs.
- Type:
defaultdict[Node, defaultdict[int, Optional[InterventionProxy]]]
- locks
Count of how many entities depend on ties between Graphs not being released.
- Type:
int
- add_bridge_proxy(node: Node, bridge_proxy: Node) → None
Adds a BridgeProtocol Proxy to the bridged_nodes attribute.
- Parameters:
node (Node) – Bridged Node.
bridge_proxy (Node) – BridgeProtocol node proxy corresponding to the bridged Node.
- get_bridge_proxy(node: Node, graph_id: int) → InterventionProxy | None
Checks whether the argument Node is bridged within the specified Graph and returns its corresponding BridgeProtocol node proxy.
- Parameters:
node (Node) – Node.
graph_id (int) – Graph id.
- Returns:
BridgeProtocol node proxy if it exists.
- Return type:
Optional[InterventionProxy]
- get_graph(id: int) → Graph
Returns the Graph from the Bridge given the Graph's id.
- Parameters:
id (int) – Id of Graph to get.
- Returns:
Graph.
- Return type:
Graph
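The ordering rule can be sketched with a toy registry. ToyBridge is hypothetical: graphs are reduced to dictionaries of precomputed values and looked up by integer id, whereas the real Bridge tracks Graph objects and uses BridgeProtocol Nodes to move values between them.

```python
from typing import Any, Dict, List


class ToyBridge:
    """Hypothetical Bridge: graphs are kept in insertion order, and a graph
    may only read values from graphs that were added before it."""

    def __init__(self) -> None:
        self.graphs: Dict[int, Dict[str, Any]] = {}  # graph id -> {node: value}
        self.order: List[int] = []

    def add(self, graph_id: int, values: Dict[str, Any]) -> None:
        self.graphs[graph_id] = values
        self.order.append(graph_id)

    def get_graph(self, graph_id: int) -> Dict[str, Any]:
        return self.graphs[graph_id]

    def get_value(self, reader_id: int, source_id: int, name: str) -> Any:
        # Enforce ordering: the source graph must come before the reader.
        if self.order.index(source_id) >= self.order.index(reader_id):
            raise ValueError(f"graph {reader_id} cannot read from graph {source_id}")
        return self.graphs[source_id][name]


bridge = ToyBridge()
bridge.add(0, {"hidden": [1, 2, 3]})
bridge.add(1, {})
print(bridge.get_value(1, 0, "hidden"))  # [1, 2, 3]
try:
    bridge.get_value(0, 1, "anything")
except ValueError as err:
    print(err)  # graph 0 cannot read from graph 1
```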