Source code for nnsight.tracing.Bridge

from collections import OrderedDict, defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional

from . import protocols

if TYPE_CHECKING:
    from ..intervention import InterventionProxy
    from .Graph import Graph
    from .Node import Node


class Bridge:
    """Collects and tracks multiple Graphs so that they can interact.

    The order in which Graphs are added matters, as a Graph can only get
    values from Graphs that were added before it.

    Attributes:
        id_to_graph (Dict[int, Graph]): Mapping of graph id to Graph, kept in
            insertion order.
        graph_stack (List[Graph]): Stack of visited Intervention Graphs; the
            last element is the innermost (current) graph.
        bridged_nodes (defaultdict[Node, defaultdict[int, Optional[InterventionProxy]]]):
            Mapping of bridged Nodes to the BridgeProtocol node proxies
            representing them on different graphs, keyed by graph id.
        locks (int): Count of how many entities are depending on ties between
            graphs not to be released.
    """

    def __init__(self) -> None:
        # Mapping from Graph id to Graph.
        self.id_to_graph: Dict[int, "Graph"] = OrderedDict()

        # Stack to keep track of the innermost current graph.
        self.graph_stack: List["Graph"] = list()

        # Bridged node -> {graph id -> proxy}. The inner default is None so
        # that direct indexing of an unknown graph id yields "not bridged".
        self.bridged_nodes: defaultdict[
            "Node", defaultdict[int, Optional["InterventionProxy"]]
        ] = defaultdict(lambda: defaultdict(lambda: None))

        self.locks = 0

    @property
    def release(self) -> bool:
        """Whether no locks are currently held (``locks`` is zero)."""
        return not self.locks

    def add(self, graph: "Graph") -> None:
        """Adds Graph to Bridge.

        Registers this Bridge on the graph through ``BridgeProtocol``, records
        the graph under its id and pushes it onto the graph stack.

        Args:
            graph (Graph): Graph to add.
        """
        protocols.BridgeProtocol.set_bridge(graph, self)

        self.id_to_graph[graph.id] = graph
        self.graph_stack.append(graph)

    def peek_graph(self) -> "Graph":
        """Gets the current hierarchical Graph in the Bridge.

        Returns:
            Graph: Graph of current context (top of the graph stack).

        Raises:
            IndexError: If the graph stack is empty.
        """
        return self.graph_stack[-1]

    def pop_graph(self) -> None:
        """Pops the last Graph in the graph stack.

        Raises:
            IndexError: If the graph stack is empty.
        """
        self.graph_stack.pop()

    def get_graph(self, id: int) -> "Graph":
        """Returns graph from Bridge given the Graph's id.

        Args:
            id (int): Id of Graph to get.

        Returns:
            Graph: Graph.

        Raises:
            KeyError: If no Graph with that id was added.
        """
        return self.id_to_graph[id]

    def add_bridge_proxy(self, node: "Node", bridge_proxy: "InterventionProxy") -> None:
        """Adds a BridgeProtocol proxy to the bridged nodes attribute.

        The proxy is stored under the id of the graph that its own node
        belongs to (``bridge_proxy.node.graph.id``).

        Args:
            node (Node): Bridged Node.
            bridge_proxy (InterventionProxy): BridgeProtocol node proxy
                corresponding to the bridged node.
        """
        graph_id = bridge_proxy.node.graph.id
        self.bridged_nodes[node][graph_id] = bridge_proxy

    def get_bridge_proxy(self, node: "Node", graph_id: int) -> Optional["InterventionProxy"]:
        """Returns the BridgeProtocol proxy of a Node on a given graph, if any.

        Args:
            node (Node): Node.
            graph_id (int): Graph id.

        Returns:
            Optional[InterventionProxy]: BridgeProtocol node proxy if it
            exists, otherwise None.
        """
        # Use .get() rather than indexing: indexing the nested defaultdicts
        # would insert an empty/None entry on every miss, so a pure query
        # would mutate (and steadily grow) the mapping.
        proxies_by_graph = self.bridged_nodes.get(node)
        if proxies_by_graph is None:
            return None
        return proxies_by_graph.get(graph_id)