Source code for nnsight.contexts.session.Iterator

from __future__ import annotations

import weakref
from collections.abc import Iterable
from typing import TYPE_CHECKING, Iterable, Tuple, Union

from ... import util
from ...tracing import protocols
from ...tracing.Node import Node
from .. import check_for_dependencies, resolve_dependencies
from ..GraphBasedContext import GraphBasedContext

if TYPE_CHECKING:
    from ...intervention import InterventionProxy
    from ...tracing.Bridge import Bridge


class Iterator(GraphBasedContext):
    """Intervention loop context for iterative execution of an intervention graph.

    The context's graph is executed once per element of ``data``; before each
    execution the current element is written into the graph's first
    ``ValueProtocol`` node.

    Attributes:
        - data (Iterable): Data to iterate over.
        - _return_context (bool): If True, ``__enter__`` returns the Iterator
          object alongside the item proxy.
    """

    def __init__(
        self, data: Iterable, *args, return_context: bool = False, **kwargs
    ) -> None:
        """Store the iteration data and forward everything else to the parent.

        Args:
            data (Iterable): Data to iterate over.
            *args: Positional arguments passed to ``GraphBasedContext.__init__``.
            return_context (bool): If True, entering the context also returns
                this Iterator. Defaults to False.
            **kwargs: Keyword arguments passed to ``GraphBasedContext.__init__``.
        """
        self.data: Iterable = data
        self._return_context: bool = return_context

        super().__init__(*args, **kwargs)

    def __enter__(
        self,
    ) -> Union["InterventionProxy", Tuple["InterventionProxy", Iterator]]:
        """Enter the context and create the proxy standing in for the current item.

        Returns:
            The item proxy created by ``ValueProtocol.add``, or the tuple
            ``(item proxy, self)`` when ``return_context`` was requested.
        """
        super().__enter__()

        # Only the processed data is kept; the dependency flag is not used here.
        self.data, _ = check_for_dependencies(self.data)

        proxy_value = None

        if self.graph.validate:
            # Swap every Node found in the data for the proxy value of its
            # first argument, so a placeholder item can be derived below.
            proxy_value = util.apply(
                self.data, lambda node: node.args[0].proxy_value, Node
            )

            # len() is kept (rather than a truthiness test) so array-like
            # containers with an ambiguous truth value still work. An empty
            # container is passed through unchanged.
            if len(proxy_value) != 0:
                proxy_value = (
                    next(proxy_value)
                    if hasattr(proxy_value, "__next__")
                    else proxy_value[0]
                )

        iter_item_proxy: "InterventionProxy" = protocols.ValueProtocol.add(
            self.graph, proxy_value
        )

        if self._return_context:
            return iter_item_proxy, self

        return iter_item_proxy

    ### BACKENDS ########

    def local_backend_execute(self) -> None:
        """Execute the graph once for every element of the resolved data.

        The bridge's lock counter is incremented for the duration of the loop
        and decremented right before the final iteration executes.
        NOTE(review): presumably the lock keeps values owned by other graphs
        available between iterations -- confirm against ``Bridge``.

        An ``EarlyStopException`` raised by the graph ends the loop quietly;
        any other exception propagates to the caller. In every case the lock
        counter is restored before this method returns.
        """
        self.graph.reset()

        bridge: "Bridge" = protocols.BridgeProtocol.get_bridge(self.graph)

        bridge.locks += 1
        # Tracks whether the increment above still has to be undone. Without
        # this, empty data, an early stop, or an error would leave the
        # counter permanently incremented.
        lock_held = True

        try:
            data = resolve_dependencies(self.data)

            last_idx: int = len(data) - 1

            for idx, item in enumerate(data):
                # The graph was already reset once before the loop.
                if idx != 0:
                    self.graph.reset()

                # Release the lock before the last execution, not after it.
                if idx == last_idx:
                    bridge.locks -= 1
                    lock_held = False

                # Feed the current item into the graph's first value node.
                protocols.ValueProtocol.set(
                    self.graph.nodes[f"{protocols.ValueProtocol.__name__}_0"], item
                )

                try:
                    self.graph.execute()
                except protocols.EarlyStopProtocol.EarlyStopException:
                    break
                finally:
                    # After each execution, mark the graph as not alive and
                    # keep only a weak proxy to it on this context.
                    graph = self.graph
                    graph.alive = False

                    if not isinstance(graph, weakref.ProxyType):
                        self.graph = weakref.proxy(graph)
        finally:
            if lock_held:
                bridge.locks -= 1

    def __repr__(self) -> str:
        """Return ``<ClassName at 0x...>`` using this object's ``id``."""
        return f"<{self.__class__.__name__} at {hex(id(self))}>"