Source code for nnsight.modeling.mixins.remoteable

import ctypes
import inspect
from types import FrameType
from typing import Any, Callable, Dict, Union

from typing_extensions import Self

from ...intervention.backends import Backend
from ...intervention.backends.remote import RemoteBackend
from ...intervention.serialization import load, save
from ...intervention.tracing.tracer import InterleavingTracer, Tracer
from ...util import from_import_path, to_import_path
from .meta import MetaMixin


[docs] class RemoteableMixin(MetaMixin):
[docs] def trace( self, *inputs: Any, backend: Union[Backend, str, None] = None, remote: bool = False, blocking: bool = True, **kwargs: Dict[str, Any], ): if backend is not None: pass elif remote: backend = RemoteBackend(self.to_model_key(), blocking=blocking) # If backend is a string, assume RemoteBackend url. elif isinstance(backend, str): backend = RemoteBackend( self.to_model_key(), host=backend, blocking=blocking ) return super().trace( *inputs, backend=backend, tracer_cls=RemoteInterleavingTracer, **kwargs, )
def session( self, *inputs: Any, backend: Union[Backend, str, None] = None, remote: bool = False, blocking: bool = True, **kwargs: Dict[str, Any], ): if backend is not None: pass elif remote: backend = RemoteBackend(self.to_model_key(), blocking=blocking) # If backend is a string, assume RemoteBackend url. elif isinstance(backend, str): backend = RemoteBackend( self.to_model_key(), host=backend, blocking=blocking ) return super().session( *inputs, backend=backend, tracer_cls=RemoteTracer, **kwargs, ) def _remoteable_model_key(self) -> str: raise NotImplementedError() @classmethod def _remoteable_from_model_key(cls, model_key: str) -> Self: raise NotImplementedError() def to_model_key(self) -> str: return f"{to_import_path(type(self))}:{self._remoteable_model_key()}" @classmethod def from_model_key(cls, model_key: str, **kwargs) -> Self: import_path, model_key = model_key.split(":", 1) type: RemoteableMixin = from_import_path(import_path) return type._remoteable_from_model_key(model_key, **kwargs)
[docs] class StreamTracer(Tracer): _send: Callable = None _recv: Callable = None def __init__(self, frame: FrameType, *args, **kwargs): super().__init__(*args, **kwargs) self.frame = frame @classmethod def register(cls, send_fn: Callable, recv_fn: Callable): cls._send = send_fn cls._recv = recv_fn @classmethod def deregister(cls): cls._send = None cls._recv = None
[docs] def execute(self, fn: Callable): data = save(fn) if self._send is None: raise ValueError("No send function provided") if self._recv is None: raise ValueError("No recv function provided") self._send(data) data: Dict[str, Any] = load(self._recv(), None) if self.frame.f_code.co_filename.startswith("<nnsight"): # For dynamically generated code, update both globals and locals self.frame.f_globals.update(data) self.frame.f_locals.update(data) # Ensure locals are properly synchronized with the frame ctypes.pythonapi.PyFrame_LocalsToFast( ctypes.py_object(self.frame), ctypes.c_int(0) ) else: # For regular files, just update locals for key, value in data.items(): self.frame.f_locals[key] = value ctypes.pythonapi.PyFrame_LocalsToFast( ctypes.py_object(self.frame), ctypes.c_int(0) )
[docs] class RemoteTracer(Tracer): def local(self): frame = inspect.currentframe().f_back return StreamTracer(frame)
[docs] class RemoteInterleavingTracer(InterleavingTracer, RemoteTracer): pass