# Source code for nnsight.tracing.contexts.iterator
import copy
from typing import Collection, Dict, Any
from ...tracing.graph import SubGraph
from ...tracing.graph import Node
from ...tracing.graph import Proxy
from . import Context
from ..protocols import VariableProtocol, StopProtocol
[docs]
class Iterator(Context[SubGraph]):
    """Tracing context whose body graph is executed once per collection element.

    Entering the context registers a loop variable via ``VariableProtocol.add``
    and returns its proxy. At execution time that variable is set to each
    element of ``collection`` in turn and the body graph is re-run.
    """

    def __init__(self, collection: Collection, *args, **kwargs) -> None:
        """
        Args:
            collection (Collection): What to iterate over. NOTE(review): it is
                passed through ``node.prepare_inputs`` in ``execute``, so it
                presumably may also contain proxies/nodes resolved at run time.
            *args, **kwargs: Forwarded unchanged to ``Context.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Extra node arguments for this context. ``execute`` unpacks
        # ``node.args`` as ``(graph, collection)``, so presumably ``Context``
        # places the graph ahead of these -- TODO confirm.
        self.args = [collection]

    def __enter__(self) -> Proxy:
        """Enter the loop body and return the loop-variable proxy.

        Returns:
            Proxy: The proxy returned by ``VariableProtocol.add`` for this
                context's graph; inside the body it stands for the current
                element.
        """
        super().__enter__()
        return VariableProtocol.add(self.graph)

    @classmethod
    def execute(cls, node: Node) -> None:
        """Run the body graph once for each element of the collection.

        For each element:
        - The graph's first node (the loop variable, presumably the one added
          in ``__enter__``) is set to that element.
        - The graph is reset and executed.

        A ``StopProtocol.StopException`` raised by the body ends the loop early.
        Afterwards the iterator node's own value is set to ``None``.

        Args:
            node (Node): The iterator node; ``node.args`` is
                ``(graph, collection)``.

        Raises:
            Exception: Any exception other than ``StopProtocol.StopException``
                raised while running the body propagates unchanged.
        """
        graph, collection = node.args
        graph: SubGraph
        collection: Collection

        collection = node.prepare_inputs(collection)
        variable_node = next(iter(graph))

        # Fix the final index before the body can run. This way, a body that
        # mutates ``collection`` cannot change which step is treated as last.
        # Iteration itself runs over a shallow-copy snapshot.
        last_idx = len(collection) - 1

        # Keep the loop variable's index on the defer stack for every pass
        # except the last. NOTE(review): presumably this stops the
        # reset/execute cycle from discarding the variable between
        # iterations -- confirm against SubGraph.
        graph.defer_stack.append(variable_node.index)
        deferred = True

        try:
            for idx, value in enumerate(copy.copy(collection)):
                VariableProtocol.set(variable_node, value)

                # Release the deferral before the final pass.
                if idx == last_idx:
                    graph.defer_stack.pop()
                    deferred = False

                graph.reset()

                try:
                    graph.execute()
                except StopProtocol.StopException:
                    # The body requested an early exit from the loop.
                    break
        finally:
            # Remove the entry pushed above exactly once. This also covers an
            # empty collection, an early stop, and an error part-way through.
            if deferred:
                graph.defer_stack.pop()

        node.set_value(None)

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            - Dict: dictionary style.
        """
        default_style = super().style()
        # Render iterator nodes as blue six-sided polygons (hexagons).
        default_style["node"] = {"color": "blue", "shape": "polygon", "sides": 6}
        return default_style