Source code for nnsight.tracing.contexts.conditional
from __future__ import annotations
from typing import Any, Dict, Optional
from ...tracing.graph import NodeType, SubGraph
from ..contexts import Context
[docs]
class Condition(Context[SubGraph]):
def __init__(
self, condition: Optional[NodeType], branch: Optional[NodeType] = None, *args, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.args = [condition, branch]
self.index = None
def else_(self, condition: Optional[Any] = None):
return Condition(
condition,
branch=self.graph.nodes[self.index],
parent=self.graph.stack[-1],
)
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
super().__exit__(exc_type, exc_val, exc_tb)
self.index = self.graph.nodes[-1].index
@classmethod
def execute(cls, node: NodeType):
graph, condition, branch = node.args
graph: SubGraph
condition: Any
condition, branch = node.prepare_inputs((condition, branch))
# else case has a True condition
if condition is None and not branch:
condition = True
if not branch and condition:
graph.reset()
graph.execute()
node.set_value(True)
else:
graph.clean()
node.set_value(branch)
[docs]
@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.
Returns:
- Dict: dictionary style.
"""
default_style = super().style()
default_style["node"] = {"color": "#FF8C00", "shape": "polygon", "sides": 6}
default_style["edge"][2] = {"style": "solid", "label": "branch", "color": "#FF8C00", "fontsize": 10}
return default_style