Source code for nnsight.contexts.Conditional
from __future__ import annotations
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Any, Union
from ..tracing import protocols
if TYPE_CHECKING:
from ..tracing.Node import Node
from ..tracing.Graph import Graph
from ..intervention import InterventionProxy
[docs]
class ConditionalManager:
    """A Graph attachment that manages the Conditional contexts defined within an Intervention Graph.

    Attributes:
        _conditional_nodes_dict (Dict[str, Node]): Mapping of ConditionalProtocol node name
            to the ConditionalProtocol node itself.
        _conditioned_nodes_dict (Dict[str, Set[Node]]): Mapping of ConditionalProtocol node name
            to the set of Nodes conditioned by it.
        _conditional_nodes_stack (List[Node]): Stack of the ConditionalProtocol nodes that were
            pushed and not yet popped; the last element belongs to the current context.
    """

    def __init__(self):
        self._conditional_nodes_dict: Dict[str, "Node"] = dict()
        self._conditioned_nodes_dict: Dict[str, Set["Node"]] = dict()
        self._conditional_nodes_stack: List["Node"] = list()

    def push(self, conditional_node: "Node") -> None:
        """Adds the Conditional to the stack of Conditional contexts.

        The node is registered under its ``name``. Pushing a node whose name is
        already registered replaces the stored node and resets its set of
        conditioned nodes to empty.

        Args:
            conditional_node (Node): ConditionalProtocol node.
        """
        self._conditional_nodes_dict[conditional_node.name] = conditional_node
        self._conditioned_nodes_dict[conditional_node.name] = set()
        self._conditional_nodes_stack.append(conditional_node)

    def get(self, key: str) -> "Node":
        """Returns a ConditionalProtocol node.

        Args:
            key (str): ConditionalProtocol node name.

        Returns:
            Node: ConditionalProtocol node.

        Raises:
            KeyError: If no node was ever pushed under ``key``.
        """
        return self._conditional_nodes_dict[key]

    def pop(self) -> None:
        """Pops the ConditionalProtocol node of the current Conditional context from the stack.

        Only the stack is modified: the node stays retrievable through ``get`` and
        its set of conditioned nodes is kept.

        Raises:
            IndexError: If the stack is empty.
        """
        self._conditional_nodes_stack.pop()

    def peek(self) -> Optional["Node"]:
        """Gets the current Conditional context's ConditionalProtocol node.

        Returns:
            Optional[Node]: Latest ConditionalProtocol node if the stack is non-empty,
                otherwise None.
        """
        if not self._conditional_nodes_stack:
            return None
        return self._conditional_nodes_stack[-1]

    def add_conditioned_node(self, node: "Node") -> None:
        """Adds a Node to the set of nodes conditioned by the current Conditional context.

        Args:
            node (Node): A node conditioned by the latest Conditional context.

        Raises:
            AttributeError: If no Conditional context is active (the stack is empty),
                since there is no current context to attach the node to.
        """
        self._conditioned_nodes_dict[self.peek().name].add(node)

    def is_node_conditioned(self, node: "Node") -> bool:
        """Returns True if the Node argument is conditioned by the current Conditional context.

        Only the context on top of the stack is consulted; nodes recorded for
        other contexts are not considered.

        Args:
            node (Node): Node.

        Returns:
            bool: Whether the Node is conditioned. False when no Conditional
                context is active.
        """
        current = self.peek()
        # With an empty stack there is no current context, so nothing can be
        # conditioned by it; answer False instead of failing on `None.name`.
        if current is None:
            return False
        return node in self._conditioned_nodes_dict[current.name]
[docs]
class Conditional(AbstractContextManager):
    """A context defined by a boolean condition, upon which the execution of all nodes defined from within is contingent.

    Attributes:
        _graph (Graph): Conditional Context graph.
        _condition (Union[InterventionProxy, Any]): Condition.
    """

    def __init__(self, graph: "Graph", condition: Union["InterventionProxy", Any]):
        self._condition: Union["InterventionProxy", Any] = condition
        self._graph = graph

    def __enter__(self) -> Conditional:
        """Opens the context.

        Calls ``ConditionalProtocol.add`` with the stored graph and condition, then
        hands the ``node`` of the returned value to
        ``ConditionalProtocol.push_conditional``.

        Returns:
            Conditional: This context object.
        """
        protocol = protocols.ConditionalProtocol
        added = protocol.add(self._graph, self._condition)
        protocol.push_conditional(added.node)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Closes the context by calling ``ConditionalProtocol.pop_conditional`` on the graph.

        The exception arguments are not inspected and None is returned, so an
        exception raised inside the ``with`` body is not suppressed.
        """
        protocols.ConditionalProtocol.pop_conditional(self._graph)