# Source code for nnsight.intervention.backends.remote

from __future__ import annotations

import inspect
import io
import os
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import requests
import socketio
import torch
from tqdm.auto import tqdm

from ... import __IPYTHON__, CONFIG, __version__
from ..._c.py_mount import mount, unmount
from ...intervention.serialization import load, save
from ...log import remote_logger
from ...schema.request import RequestModel
from ...schema.response import RESULT, ResponseModel
from ..tracing.tracer import Tracer
from .base import Backend


class RemoteException(Exception):
    """Raised when the remote NDIF service reports a job-level error."""
class RemoteBackend(Backend):
    """Backend to execute a context object via a remote service.

    Context object must inherit from RemoteMixin and implement its methods.

    Attributes:
        url (str): Remote host url. Defaults to that set in CONFIG.API.HOST.
    """

    def __init__(
        self,
        model_key: str,
        host: str = None,
        blocking: bool = True,
        job_id: str = None,
        ssl: bool = None,
        api_key: str = "",
        callback: str = "",
    ) -> None:
        # Remote model identifier sent with every request.
        self.model_key = model_key

        # Resolution order for connection settings: explicit arg > env var > user config.
        self.host = host or os.environ.get("NDIF_HOST", None) or CONFIG.API.HOST
        self.api_key = (
            api_key or os.environ.get("NDIF_API_KEY", None) or CONFIG.API.APIKEY
        )

        self.job_id = job_id
        self.ssl = CONFIG.API.SSL if ssl is None else ssl
        self.zlib = CONFIG.API.ZLIB
        self.blocking = blocking
        self.callback = callback

        self.address = f"http{'s' if self.ssl else ''}://{self.host}"
        # BUGFIX: previously keyed off CONFIG.API.SSL, silently ignoring an
        # explicit `ssl` argument. Use the resolved self.ssl so the http(s)
        # and ws(s) schemes always agree.
        self.ws_address = f"ws{'s' if self.ssl else ''}://{self.host}"

        self.job_status = None

    def request(self, tracer: Tracer) -> Tuple[bytes, Dict[str, str]]:
        """Serialize the tracer into a request payload plus HTTP headers.

        Args:
            tracer (Tracer): Tracer whose interventions should be executed remotely.

        Returns:
            Tuple[bytes, Dict[str, str]]: Serialized request body and headers.
        """
        interventions = super().__call__(tracer)

        data = RequestModel(interventions=interventions, tracer=tracer).serialize(
            self.zlib
        )

        headers = {
            "nnsight-model-key": self.model_key,
            "nnsight-zlib": str(self.zlib),
            "nnsight-version": __version__,
            "ndif-api-key": self.api_key,
            "ndif-timestamp": str(time.time()),
            "ndif-callback": self.callback,
        }

        return data, headers

    def __call__(self, tracer=None):
        """Execute remotely (blocking) or poll an existing job (non-blocking).

        Pushes any returned result into the tracer so downstream consumers
        can read it.
        """
        if self.blocking:
            # Do blocking request.
            result = self.blocking_request(tracer)
        else:
            # Otherwise we are getting the status / result of the existing job.
            result = self.non_blocking_request(tracer)

        if tracer is not None and result is not None:
            tracer.push(result)

        return result
[docs] def handle_response( self, response: ResponseModel, tracer: Optional[Tracer] = None ) -> Optional[RESULT]: """Handles incoming response data. Logs the response object. If the job is completed, retrieve and stream the result from the remote endpoint. Use torch.load to decode and load the `ResultModel` into memory. Use the backend object's .handle_result method to handle the decoded result. Args: response (Any): Json data to concert to `ResponseModel` Raises: Exception: If the job's status is `ResponseModel.JobStatus.ERROR` Returns: ResponseModel: ResponseModel. """ self.job_status = response.status if response.status == ResponseModel.JobStatus.ERROR: raise RemoteException(f"{response.description}\nRemote exception.") # Log response for user response.log(remote_logger) self.job_status = response.status # If job is completed: if response.status == ResponseModel.JobStatus.COMPLETED: # If the response has no result data, it was too big and we need to stream it from the server. if response.data is None: result = self.get_result(response.id) else: result = response.data return result elif response.status == ResponseModel.JobStatus.STREAM: model = getattr(tracer, "model", None) fn = load(response.data, model) local_tracer = LocalTracer(_info=tracer.info) local_tracer.execute(fn)
[docs] def submit_request( self, data: bytes, headers: Dict[str, Any] ) -> Optional[ResponseModel]: """Sends request to the remote endpoint and handles the response object. Raises: Exception: If there was a status code other than 200 for the response. Returns: (ResponseModel): Response. """ from ...schema.response import ResponseModel headers["Content-Type"] = "application/octet-stream" response = requests.post( f"{self.address}/request", data=data, headers=headers, ) if response.status_code == 200: response = ResponseModel(**response.json()) self.job_id = response.id self.handle_response(response) return response else: try: msg = response.json()["detail"] except: msg = response.reason raise ConnectionError(msg)
[docs] def get_response(self) -> Optional[RESULT]: """Retrieves and handles the response object from the remote endpoint. Raises: Exception: If there was a status code other than 200 for the response. Returns: (ResponseModel): Response. """ from ...schema.response import ResponseModel response = requests.get( f"{self.address}/response/{self.job_id}", headers={"ndif-api-key": self.api_key}, ) if response.status_code == 200: response = ResponseModel(**response.json()) return self.handle_response(response) else: raise Exception(response.reason)
def get_result(self, id: str) -> RESULT: result_bytes = io.BytesIO() result_bytes.seek(0) # Get result from result url using job id. with requests.get( url=f"{self.address}/result/{id}", stream=True, ) as stream: # Total size of incoming data. total_size = float(stream.headers["Content-length"]) with tqdm( total=total_size, unit="B", unit_scale=True, desc="Downloading result", ) as progress_bar: # chunk_size=None so server determines chunk size. for data in stream.iter_content(chunk_size=None): progress_bar.update(len(data)) result_bytes.write(data) # Move cursor to beginning of bytes. result_bytes.seek(0) # Decode bytes with pickle and then into pydantic object. result = torch.load(result_bytes, map_location="cpu", weights_only=False) # Close bytes result_bytes.close() return result
[docs] def blocking_request(self, tracer: Tracer) -> Optional[RESULT]: """Send intervention request to the remote service while waiting for updates via websocket. Args: request (RequestModel):Request. """ # We need to do some processing / optimizations on both the graph were sending remotely # and our local intervention graph. In order handle the more complex Protocols for streaming. # Create a socketio connection to the server. with socketio.SimpleClient(reconnection_attempts=10) as sio: # Connect sio.connect( self.ws_address, socketio_path="/ws/socket.io", transports=["websocket"], wait_timeout=10, ) data, headers = self.request(tracer) headers["ndif-session_id"] = sio.sid # Submit request via response = self.submit_request(data, headers) try: LocalTracer.register(lambda data: self.stream_send(data, sio)) # Loop until while True: # Get pickled bytes value from the websocket. response = sio.receive()[1] # Convert to pydantic object. response = ResponseModel.unpickle(response) # Handle the response. result = self.handle_response(response, tracer=tracer) # Break when completed. if result is not None: return result except Exception as e: raise e finally: LocalTracer.deregister()
[docs] def stream_send( self, values: Dict[int, Any], sio: socketio.SimpleClient ): """Upload some value to the remote service for some job id. Args: value (Any): Value to upload job_id (str): Job id. sio (socketio.SimpleClient): Connected websocket client. """ data = save(values) sio.emit( "stream_upload", data=(data, self.job_id), )
[docs] def non_blocking_request(self, tracer: Tracer): """Send intervention request to the remote service if request provided. Otherwise get job status. Sets CONFIG.API.JOB_ID on initial request as to later get the status of said job. When job is completed, clear CONFIG.API.JOB_ID to request a new job. Args: request (RequestModel): Request if submitting a new request. Defaults to None """ if self.job_id is None: data, headers = self.request(tracer) # Submit request via response = self.submit_request(data, headers) self.job_id = response.id else: result = self.get_response() return result
class LocalTracer(Tracer):
    """Tracer that runs remotely-streamed functions locally and ships values
    flagged as remote back up to the service.
    """

    # Upload callback installed via register() while a blocking request is
    # active; None when no request is in flight.
    _send: Callable = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Identities (id()) of objects flagged for upload back to the job.
        self.remotes = set()

    @classmethod
    def register(cls, send_fn: Callable):
        """Install the callback used to upload streamed values."""
        cls._send = send_fn

    @classmethod
    def deregister(cls):
        """Remove the upload callback."""
        cls._send = None

    def _save_remote(self, obj: Any):
        # Track by identity; push() later filters frame locals against this set.
        self.remotes.add(id(obj))
[docs] def execute(self, fn: Callable): mount(self._save_remote, "remote") fn(self, self.info) unmount("remote") return
[docs] def push(self): # Find the frame where the traced code is executing state_frame = inspect.currentframe().f_back state = state_frame.f_locals super().push(state) state = {k:v for k,v in state.items() if id(v) in self.remotes} LocalTracer._send(state)