Source code for nnsight.contexts.session.Session

from __future__ import annotations

import weakref
from typing import TYPE_CHECKING, Any, Dict, Iterable

from ...tracing.Bridge import Bridge
from ...tracing.Graph import Graph
from ...tracing.protocols import EarlyStopProtocol
from ..backends import Backend, BridgeBackend, RemoteMixin
from ..GraphBasedContext import GraphBasedContext
from .Iterator import Iterator

if TYPE_CHECKING:
    from ...models.mixins import RemoteableMixin
    from ...models.NNsightModel import NNsight


class Session(GraphBasedContext, RemoteMixin):
    """A Session is a root Collection that handles adding new Graphs and new Collections while in the session.

    Attributes:
        bridge (Bridge): Bridge object which stores all Graphs added during the
            session and handles interaction between them.
        graph (Graph): Root Graph where operations and values meant for access
            by all subsequent Graphs should be stored and referenced.
        model (NNsight): NNsight model.
        backend (Backend): Backend for this context object.
    """

    def __init__(
        self,
        backend: Backend,
        model: "NNsight",
        *args,
        bridge: Bridge = None,
        **kwargs,
    ) -> None:
        """Create a session bound to ``model``.

        Args:
            backend (Backend): Backend for this context object.
            model (NNsight): Model whose ``proxy_class`` is used for this
                session's graph.
            *args: Extra positional arguments forwarded to
                ``GraphBasedContext.__init__``.
            bridge (Bridge): Existing Bridge to reuse. A new ``Bridge`` is
                created when this is ``None``.
            **kwargs: Extra keyword arguments forwarded to
                ``GraphBasedContext.__init__``.
        """
        self.bridge = Bridge() if bridge is None else bridge
        self.model = model

        # Positional arguments are spelled before the keywords so the call
        # reads in the order the arguments are actually bound.
        GraphBasedContext.__init__(
            self,
            backend,
            *args,
            bridge=self.bridge,
            proxy_class=self.model.proxy_class,
            **kwargs,
        )

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Detach this session from the model, then run the parent exit logic."""
        # Cleared first, so the model no longer references this session even
        # if the parent ``__exit__`` raises.
        self.model._session = None

        super().__exit__(exc_type, exc_val, exc_tb)

    @staticmethod
    def _as_weak_proxy(obj: Any) -> Any:
        """Return a weak proxy to ``obj``, or ``obj`` itself if it already is one.

        ``weakref.proxy`` raises ``TypeError`` when handed an existing proxy,
        so proxies are passed through unchanged. ``weakref.ProxyTypes`` covers
        both the plain and the callable proxy type.
        """
        if isinstance(obj, weakref.ProxyTypes):
            return obj
        return weakref.proxy(obj)

    def iter(self, iterable: Iterable, **kwargs) -> Iterator:
        """Creates an Iterator context to iteratively execute an intervention graph, with an update item at each iteration.

        Args:
            iterable (Iterable): Data to iterate over.
            **kwargs: Forwarded unchanged to the ``Iterator`` constructor.
                NOTE(review): earlier documentation listed
                ``return_context (bool, default False)`` here; it is not a
                parameter of this method, so it is presumably consumed by
                ``Iterator`` - confirm against ``Iterator.__init__``.

        Returns:
            Iterator: Iterator context.

        Example:

            Setup:

            .. code-block:: python

                import torch
                from collections import OrderedDict

                input_size = 5
                hidden_dims = 10
                output_size = 2

                # NOTE(review): ``model`` must expose ``.session()`` (i.e. be
                # an NNsight model); wrap the module accordingly.
                model = torch.nn.Sequential(OrderedDict([
                    ('layer1', torch.nn.Linear(input_size, hidden_dims)),
                    ('layer2', torch.nn.Linear(hidden_dims, output_size)),
                ]))

                input = torch.rand((1, input_size))

            Ex:

            .. code-block:: python

                with model.session() as session:
                    l = session.apply(list).save()
                    with session.iter([0, 1, 2]) as item:
                        l.append(item)
        """
        # The Iterator and its backend only get a weak handle on the bridge;
        # the session keeps the attribute it was given.
        bridge = self._as_weak_proxy(self.bridge)
        backend = BridgeBackend(bridge)

        return Iterator(
            iterable,
            backend,
            bridge=bridge,
            proxy_class=self.model.proxy_class,
            **kwargs,
        )

    ### BACKENDS ########

    def local_backend_execute(self) -> Dict[int, Graph]:
        """Run the parent local execution and return the bridge's graphs.

        An ``EarlyStopProtocol.EarlyStopException`` raised by the parent
        implementation is swallowed on purpose.

        Returns:
            Dict[int, Graph]: The bridge's ``id_to_graph`` mapping.
        """
        try:
            super().local_backend_execute()
        except EarlyStopProtocol.EarlyStopException:
            # Deliberately ignored: an early stop still yields a result.
            pass

        # Read the mapping before this session's reference to the bridge is
        # swapped for a weak proxy.
        local_result = self.bridge.id_to_graph

        self.bridge = self._as_weak_proxy(self.bridge)

        return local_result

    def remote_backend_get_model_key(self) -> str:
        """Return the model key produced by ``self.model.to_model_key()``."""
        # Annotation only (no assignment): tells type checkers that the model
        # is expected to provide ``to_model_key``.
        self.model: "RemoteableMixin"

        return self.model.to_model_key()

    def remote_backend_postprocess_result(
        self, local_result: Dict[int, Graph]
    ) -> Dict[int, Any]:
        """Convert each graph of a local result into a ``ResultModel``.

        Args:
            local_result (Dict[int, Graph]): Mapping of graph id to Graph.

        Returns:
            Dict[int, Any]: Mapping of graph id to
            ``ResultModel.from_graph(graph)``.
        """
        # Imported at call time, as in the original code.
        from ...schema.Response import ResultModel

        return {
            graph_id: ResultModel.from_graph(graph)
            for graph_id, graph in local_result.items()
        }

    def remote_backend_handle_result_value(
        self, value: Dict[int, Dict[str, Any]]
    ) -> None:
        """Write received node values into this session's graphs.

        Args:
            value (Dict[int, Dict[str, Any]]): Mapping of graph id to a
                mapping of node name to the value to store on that node.

        Raises:
            KeyError: If a graph id or node name is not known locally.
        """
        for graph_id, saves in value.items():
            graph = self.bridge.id_to_graph[graph_id]

            for node_name, node_value in saves.items():
                graph.nodes[node_name]._value = node_value

            # Mark the graph as finished once its values are injected.
            graph.alive = False

    def remote_backend_cleanup(self) -> None:
        """Swap this session's bridge and graph for weak proxies and mark the graph dead."""
        self.bridge = self._as_weak_proxy(self.bridge)

        graph = self.graph
        graph.alive = False

        # Only rebind when the graph is not already a proxy (of either proxy
        # type); wrapping a proxy again would raise ``TypeError``.
        if not isinstance(graph, weakref.ProxyTypes):
            self.graph = weakref.proxy(graph)

    def __repr__(self) -> str:
        """Return ``<ClassName at 0x...>`` using this object's ``id``."""
        return f"<{self.__class__.__name__} at {hex(id(self))}>"