# Source code for nnsight.contexts.session.Session

from __future__ import annotations

import weakref
from typing import TYPE_CHECKING, Any, Dict, Iterable

from ...tracing.Bridge import Bridge
from ...tracing.Graph import Graph
from ...tracing.protocols import EarlyStopProtocol
from ..backends import Backend, BridgeBackend, RemoteMixin
from ..GraphBasedContext import GraphBasedContext
from .Iterator import Iterator

if TYPE_CHECKING:
    from ...models.mixins import RemoteableMixin
    from ...models.NNsightModel import NNsight

class Session(GraphBasedContext, RemoteMixin):
    """A Session is a root Collection that handles adding new Graphs and new Collections while in the session.

    Attributes:
        bridge (Bridge): Bridge object which stores all Graphs added during the session and handles interaction between them.
        graph (Graph): Root Graph where operations and values meant for access by all subsequent Graphs should be stored and referenced.
        model (NNsight): NNsight model.
        backend (Backend): Backend for this context object.
    """

    def __init__(
        self,
        backend: Backend,
        model: "NNsight",
        *args,
        bridge: Bridge = None,
        **kwargs,
    ) -> None:
        """Create a Session bound to ``model``.

        Args:
            backend (Backend): Backend used to execute this context.
            model (NNsight): Model whose ``proxy_class`` is used for proxies
                created in this session.
            *args: Extra positional arguments forwarded to
                ``GraphBasedContext.__init__`` after ``backend``.
            bridge (Bridge, optional): Existing Bridge to reuse. A new
                ``Bridge()`` is created when omitted.
            **kwargs: Extra keyword arguments forwarded to
                ``GraphBasedContext.__init__``.
        """
        self.bridge = Bridge() if bridge is None else bridge
        self.model = model

        GraphBasedContext.__init__(
            self,
            backend,
            *args,
            bridge=self.bridge,
            proxy_class=self.model.proxy_class,
            **kwargs,
        )

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Detach this session from its model, then exit the base context."""
        self.model._session = None

        super().__exit__(exc_type, exc_val, exc_tb)

    def iter(self, iterable: Iterable, **kwargs) -> Iterator:
        """Creates an Iterator context to iteratively execute an intervention graph, with an update item at each iteration.

        Args:
            - iterable (Iterable): Data to iterate over.
            - return_context (bool): If True, returns the Iterator context. Default: False.
              (Passed through ``**kwargs`` to ``Iterator``.)

        Returns:
            Iterator: Iterator context.

        Example:
            Setup:
                .. code-block:: python

                    import torch
                    from collections import OrderedDict

                    input_size = 5
                    hidden_dims = 10
                    output_size = 2

                    model = torch.nn.Sequential(OrderedDict([
                        ('layer1', torch.nn.Linear(input_size, hidden_dims)),
                        ('layer2', torch.nn.Linear(hidden_dims, output_size)),
                    ]))

                    input = torch.rand((1, input_size))

            Ex:
                .. code-block:: python

                    with model.session() as session:
                        l = session.apply(list).save()
                        with session.iter([0, 1, 2]) as item:
                            l.append(item)
        """
        # The Iterator receives only a weak proxy to the Bridge, presumably so
        # it does not keep the Bridge alive on its own. Re-wrapping an existing
        # proxy would raise TypeError, so reuse it if one is already stored.
        bridge = (
            self.bridge
            if isinstance(self.bridge, weakref.ProxyTypes)
            else weakref.proxy(self.bridge)
        )
        backend = BridgeBackend(bridge)

        return Iterator(
            iterable,
            backend,
            bridge=bridge,
            proxy_class=self.model.proxy_class,
            **kwargs,
        )

    ### BACKENDS ########

    def local_backend_execute(self) -> Dict[int, Graph]:
        """Execute the session locally and return every Graph held by the bridge.

        Returns:
            Dict[int, Graph]: The bridge's ``id_to_graph`` mapping.
        """
        try:
            super().local_backend_execute()
        except EarlyStopProtocol.EarlyStopException:
            # An early stop ends execution without being treated as an error;
            # the graphs collected so far are still returned below.
            pass

        local_result = self.bridge.id_to_graph

        # Downgrade the Session's reference to the Bridge to a weak proxy.
        # Guarded because weakref.proxy() raises TypeError on an existing proxy.
        if not isinstance(self.bridge, weakref.ProxyTypes):
            self.bridge = weakref.proxy(self.bridge)

        return local_result

    def remote_backend_get_model_key(self) -> str:
        """Return the model key reported by ``self.model.to_model_key()``."""
        # Annotation only: tells type checkers the model is remoteable here.
        self.model: "RemoteableMixin"

        return self.model.to_model_key()

    def remote_backend_postprocess_result(self, local_result: Dict[int, Graph]):
        """Convert each Graph into a ``ResultModel``, keyed by its graph id.

        Args:
            local_result (Dict[int, Graph]): Mapping of graph id to Graph.

        Returns:
            Dict[int, ResultModel]: Mapping of graph id to ``ResultModel.from_graph(graph)``.
        """
        from ...schema.Response import ResultModel

        return {
            graph_id: ResultModel.from_graph(graph)
            for graph_id, graph in local_result.items()
        }

    def remote_backend_handle_result_value(
        self, value: Dict[int, Dict[str, Any]]
    ) -> None:
        """Write values returned by the remote backend into the local graphs.

        Args:
            value (Dict[int, Dict[str, Any]]): Mapping of graph id to a mapping
                of node name to that node's value.
        """
        for graph_id, saves in value.items():
            graph = self.bridge.id_to_graph[graph_id]

            for node_name, node_value in saves.items():
                graph.nodes[node_name]._value = node_value

            # Each graph is marked not alive once its saved values are written.
            graph.alive = False

    def remote_backend_cleanup(self):
        """Mark the root graph not alive and downgrade bridge/graph references to weak proxies."""
        # Guard both re-wraps: weakref.proxy() raises TypeError when given an
        # existing proxy. ProxyTypes covers both plain and callable proxies.
        if not isinstance(self.bridge, weakref.ProxyTypes):
            self.bridge = weakref.proxy(self.bridge)

        graph = self.graph
        graph.alive = False

        if not isinstance(graph, weakref.ProxyTypes):
            self.graph = weakref.proxy(graph)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} at {hex(id(self))}>"