from __future__ import annotations
from typing import (Callable, Dict, Generic, Iterator, List, Optional,
Tuple, Type, TypeVar, Union, overload)
from typing_extensions import Self
from ... import util
from ...util import NNsightError
from .. import protocols
from . import Node, NodeType, Proxy, ProxyType
class Graph(Generic[NodeType, ProxyType]):
    """The `Graph` class represents a computation graph composed of individual `Node`s (operations).

    It contains logic both to trace/build the computation graph and to execute it.
    Sections of the graph can be divided into `SubGraph`s, but there will always be one root `Graph`.
    The final `Node` of the graph (``graph[-1]``) should be the root `Node` which, when executed,
    executes the entire `Graph` downstream.

    Attributes:
        node_class (Type[NodeType]): Class used to create `Node`s. Can be changed to add additional
            functionality to `Node`s. Defaults to `Node`.
        proxy_class (Type[ProxyType]): Class used to create `Proxy`s for `Node`s. Can be changed to add
            additional functionality to `Proxy`s. Defaults to `Proxy`.
        debug (bool): Debug flag stored on the graph and carried over to graphs created by `copy`.
            Defaults to False.
        nodes (List[Node]): Ordered list of all `Node`s. Used to access `Node`s via their index.
        stack (List[Graph]): List of `Graph`s as a stack. Used to move `Node`s onto the most recent graph,
            as opposed to the `Graph` used to create the `Node`. Managed outside the `Graph` class by the
            `Context` objects.
        defer_stack (List[int]): List of `Node` indexes as a stack. Used to prevent destruction/memory
            cleanup of `Node`s whose index is less than the most recent index on the stack. This happens
            when you have `Node`s that will be executed more than once. In a loop, for example, you only
            want to destroy a `Node`'s dependencies on the final iteration. Also managed outside the
            `Graph` object.
        alive (bool): Whether the `Graph` is "alive", meaning it is still open for tracing (adding new
            `Node`s). Set to False before executing the `Graph`.
    """

    def __init__(
        self,
        node_class: Type[NodeType] = Node,
        proxy_class: Type[ProxyType] = Proxy,
        debug: bool = False,
    ) -> None:
        self.node_class = node_class
        self.proxy_class = proxy_class
        self.debug = debug

        # A one-element list instead of a plain bool: it is a shared mutable cell.
        # `SubGraph.__init__` copies this object's __dict__, so every SubGraph made
        # from this graph holds the same list and observes the same `alive` value.
        self._alive = [True]

        self.nodes: List[Node] = []
        self.stack: List[Graph] = []
        self.defer_stack: List[int] = []

    @property
    def alive(self) -> bool:
        """Whether this graph (and any graph sharing its `_alive` cell) is still open for tracing."""
        return self._alive[0]

    @alive.setter
    def alive(self, value: bool):
        # Mutate the shared cell in place rather than rebinding `_alive`.
        self._alive[0] = value

    def reset(self) -> None:
        """Resets the `Graph` to prepare for execution.

        Simply calls `reset()` on every `Node` yielded by iterating this `Graph`.
        """
        for node in self:
            node.reset()

    def execute(self) -> None:
        """Executes all `Node`s (operations) in this `Graph`, in iteration order.

        Raises:
            NNsightError: Re-raised if executing a `Node` fails. Before re-raising, the dependencies of
                the failed `Node` and of the `Node`s after it are cleaned up.
        """
        err: Optional[Tuple[int, NNsightError]] = None

        for node in self:
            try:
                node.execute()
            except NNsightError as e:
                err = (node.index, e)
                break

        if err is not None:
            index, exception = err

            # Clear the defer stack temporarily so it cannot hold back cleanup
            # (see `defer_stack` above). Restore it afterwards, even if cleaning
            # itself fails, so whoever manages the stack sees it unchanged.
            defer_stack = self.defer_stack.copy()
            self.defer_stack.clear()
            try:
                self.clean(index)
            finally:
                self.defer_stack.extend(defer_stack)

            raise exception

    def clean(self, start: Optional[int] = None):
        """Cleans up dependencies of `Node`s so their values are appropriately memory managed.

        Cleans all `Node`s from `start` up to and including this graph's last `Node`, regardless of
        whether they belong to this `Graph` (the shared `nodes` list is walked by index).

        Args:
            start (Optional[int], optional): `Node` index to start cleaning up from. Defaults to None,
                meaning the index of this graph's first `Node`.
        """
        if len(self) == 0:
            return

        if start is None:
            start = self[0].index

        end = self[-1].index + 1

        # Loop over ALL nodes within the span of this graph.
        for index in range(start, end):
            node = self.nodes[index]
            node.update_dependencies()

    def create(
        self,
        target: Union[Callable, protocols.Protocol],
        *args,
        redirect: bool = True,
        **kwargs,
    ) -> ProxyType:
        """Creates a new `Node` using this `Graph`'s node_class and returns a `Proxy` for it with this `Graph`'s proxy_class.

        Extra positional and keyword arguments are forwarded to `node_class` along with `graph=`.

        Args:
            target (Union[Callable, protocols.Protocol]): Target for the new `Node`.
            redirect (bool, optional): Whether to create the `Node` on the most recent `Graph` on
                `self.stack` (when the stack is non-empty) instead of on this `Graph`. Defaults to True.

        Returns:
            ProxyType: `Proxy` for the newly created `Node`.
        """
        # Redirection.
        graph = self.stack[-1] if redirect and self.stack else self

        return self.proxy_class(self.node_class(target, *args, graph=graph, **kwargs))

    def add(self, node: NodeType) -> None:
        """Adds a `Node` to this `Graph`.

        Sets the `Node`'s `.index` attribute so it knows its own index within the entire computation graph.

        Args:
            node (NodeType): `Node` to add.
        """
        # Tag the Node with its own index.
        node.index = len(self.nodes)

        # Add Node.
        self.nodes.append(node)

    def copy(self, new_graph: Optional[Graph[NodeType, ProxyType]] = None) -> Graph:
        """Creates a shallow copy of the root `Graph` object.

        Only the final (root) `Node` is re-created on `new_graph`. Its `SubGraph` arguments are copied
        onto `new_graph` via `SubGraph.copy`, and arguments that are already-executed `Node`s are
        replaced by their values.

        Args:
            new_graph (Optional[Graph[NodeType, ProxyType]], optional): `Graph` to copy into. Defaults to
                None, which creates a new `Graph` with the same node_class, proxy_class and debug flag.

        Returns:
            Graph: New `Graph`.
        """
        if new_graph is None:
            new_graph = Graph(node_class=self.node_class, proxy_class=self.proxy_class, debug=self.debug)

        node = self[-1]

        def process(arg: Union[Node, SubGraph]):
            if isinstance(arg, SubGraph):
                return arg.copy(parent=new_graph)

            if arg.done:
                return arg.value

            # NOTE(review): a not-yet-executed plain `Node` argument of the root node
            # is replaced by None here; presumably the root node only takes SubGraphs
            # or finished values — confirm.
            return None

        new_graph.create(
            node.target,
            *util.apply(node.args, process, (Node, SubGraph)),
            **util.apply(node.kwargs, process, (Node, SubGraph)),
        )

        return new_graph

    ### Magic Methods ######################################

    def __str__(self) -> str:
        header = f"{self.__class__.__name__}:\n"
        return header + "".join(f" {str(node)}\n" for node in self)

    @overload
    def __getitem__(self, key: int) -> Node: ...

    @overload
    def __getitem__(self, key: Union[slice, List[int]]) -> List[Node]: ...

    def __getitem__(self, key: Union[int, Union[slice, List[int]]]) -> Union[Node, List[Node]]:
        # A Python list cannot be indexed by another list, so resolve each index
        # individually to honour the List[int] overload above.
        if isinstance(key, list):
            return [self.nodes[idx] for idx in key]

        return self.nodes[key]

    def __iter__(self) -> Iterator[Node]:
        return iter(self.nodes)

    def __len__(self) -> int:
        return len(self.nodes)
class SubGraph(Graph[NodeType, ProxyType]):
    """Represents a slice of the greater computation graph.

    It holds a reference to the same underlying list of `Node`s as its parent and simply maintains
    a subset of node indexes.

    Attributes:
        subset (List[int]): Node indexes for `Node`s contained within this subgraph.
    """

    def __init__(
        self,
        parent: GraphType,
        subset: Optional[List[int]] = None,
    ):
        """Init

        Args:
            parent (GraphType): Graph to inherit attributes from. Its attribute dict is copied shallowly,
                so `nodes`, `stack`, `defer_stack` and the `alive` cell are the same objects as the
                parent's rather than duplicates.
            subset (Optional[List[int]], optional): Subset to start from when loading a pre-defined `SubGraph`.
        """
        self.__dict__.update(parent.__dict__)

        self.subset: List[int] = [] if subset is None else subset

    def __getstate__(self):
        # Only these three attributes are pickled. node_class, proxy_class, debug, stack
        # and _alive are NOT restored by __setstate__ — presumably re-attached by the
        # caller after unpickling; verify before relying on them.
        return {
            "nodes": self.nodes,
            "subset": self.subset,
            "defer_stack": self.defer_stack,
        }

    def __setstate__(self, state: Dict) -> None:
        self.__dict__.update(state)

    def add(self, node: NodeType) -> None:
        """Adds a `Node` to the shared node list and records its index in this SubGraph's subset."""
        super().add(node)

        # Also add the index to this SubGraph's subset upon adding.
        self.subset.append(self.nodes[-1].index)

    @overload
    def __getitem__(self, key: int) -> Node: ...

    @overload
    def __getitem__(self, key: Union[slice, List[int]]) -> List[Node]: ...

    def __getitem__(self, key: Union[int, Union[slice, List[int]]]) -> Union[Node, List[Node]]:
        # `key` addresses positions within `subset`; those positions map to global
        # indexes into the shared `nodes` list.
        if isinstance(key, slice):
            return [self.nodes[idx] for idx in self.subset[key]]

        if isinstance(key, list):
            # A list can't index `subset` directly, so map each position individually.
            return [self.nodes[self.subset[pos]] for pos in key]

        return self.nodes[self.subset[key]]

    def __iter__(self) -> Iterator[Node]:
        return self.Iterator(self)

    def __len__(self) -> int:
        return len(self.subset)

    class Iterator(Iterator):
        """Iterates over the `Node`s of a `SubGraph` in subset order.

        The length is captured when the iterator is created, so indexes appended to the subset during
        iteration are not visited.
        """

        def __init__(self, subgraph: SubGraph[NodeType, ProxyType]) -> None:
            self.subgraph = subgraph
            self.start = 0
            self.end = len(self.subgraph)

        def __next__(self) -> NodeType:
            if self.start < self.end:
                value = self.subgraph[self.start]
                self.start += 1
                return value

            raise StopIteration

    def copy(
        self,
        new_graph: Optional[SubGraph[NodeType, ProxyType]] = None,
        parent: Optional[Graph[NodeType, ProxyType]] = None,
        memo: Optional[Dict[int, int]] = None,
    ) -> Self:
        """Creates a shallow copy of this SubGraph.

        Each `Node` of this SubGraph is re-created on `new_graph` in order. Nested `SubGraph` arguments
        are copied recursively onto `new_graph`. Already-executed `Node` arguments are replaced by their
        values. Other `Node` arguments are remapped to their copies through `memo`.

        Args:
            new_graph (Optional[SubGraph[NodeType, ProxyType]], optional): SubGraph to copy into. Defaults
                to None and creates a new SubGraph of the same type.
            parent (Optional[Graph[NodeType, ProxyType]], optional): Parent graph. Defaults to None and
                creates a root `Graph` (same node_class, proxy_class and debug flag) as the parent.
            memo (Optional[Dict[int, int]], optional): Maps original `Node` indexes to the indexes of their
                copies. Shared with nested copies. Defaults to None (a fresh dict).

        Returns:
            Self: New graph.

        Raises:
            KeyError: If an unexecuted `Node` argument was not copied earlier (i.e. is absent from `memo`).
        """
        if parent is None:
            # Carry the debug flag over, consistent with `Graph.copy`.
            parent = Graph(
                node_class=self.node_class,
                proxy_class=self.proxy_class,
                debug=self.debug,
            )

        if new_graph is None:
            new_graph = type(self)(parent)

        if memo is None:
            memo = {}

        def process(arg: Union[Node, SubGraph]):
            if isinstance(arg, SubGraph):
                return arg.copy(parent=new_graph, memo=memo)

            if arg.done:
                return arg.value

            return new_graph.nodes[memo[arg.index]]

        for node in self:
            new_node = new_graph.create(
                node.target,
                *util.apply(node.args, process, (Node, SubGraph)),
                **util.apply(node.kwargs, process, (Node, SubGraph)),
            ).node

            memo[node.index] = new_node.index

        return new_graph
# Type variable for "some graph": a root `Graph` or any `SubGraph` of it.
# Bound to `Graph` rather than `SubGraph` because `SubGraph.__init__(parent: GraphType)`
# is given a root `Graph` as its parent (e.g. by `SubGraph.copy` when no parent is passed).
GraphType = TypeVar("GraphType", bound=Graph)