Source code for nnsight.intervention.contexts.local

from typing import Callable, List, Optional

from nnsight.tracing.graph.node import Node

from ...tracing.contexts import Tracer
from ...tracing.graph import GraphType, NodeType
from ..protocols import EntryPoint, NoopProtocol


[docs] class LocalContext(Tracer): send: Optional[Callable] = None @classmethod def set(cls, fn: Callable): cls.send = fn @classmethod def execute(cls, node: NodeType): super().execute(node) uploads = node.kwargs.get("uploads", []) if uploads: values = {index: node.graph.nodes[index].value for index in uploads} cls.send(values) for index in uploads: node = node.graph.nodes[index] node.remaining_listeners -= 1 if node.redundant: node.destroy()
[docs] class RemoteContext(Tracer): send: Optional[Callable] = None receive: Optional[Callable] = None @classmethod def set(cls, send: Callable, receive: Callable): cls.send = send cls.receive = receive @classmethod def from_local(cls, local_node: NodeType): local_node.target = RemoteContext graph: GraphType = local_node.args[0] start = graph[0].index end = graph[-1].index uploads = [] # TODO check for swap and error for node in graph.nodes[start : end + 1]: for dependency in node.dependencies: if ( isinstance(dependency.target, type) and issubclass(dependency.target, EntryPoint) ) or dependency.index < start: local_node.args.append(dependency) if isinstance(node.target, type) and issubclass(node.target, EntryPoint): continue node.args.clear() node.kwargs.clear() node.target = NoopProtocol for listener in node.listeners: if listener.index > end: uploads.append(node.index) if len(uploads) > 0: local_node.kwargs["upload"] = True return uploads @classmethod def execute(cls, node: NodeType): graph, *dependencies = node.args dependencies = { dependency.index: dependency.value for dependency in dependencies } cls.send((node.index, dependencies)) super().execute(node) if node.kwargs.get("upload", False): values = cls.receive() for index, value in values.items(): graph.nodes[index]._value = value