Source code for nnsight.intervention.backends.remote

from __future__ import annotations

import io
import sys
import time
from datetime import datetime
from typing import Any, Dict, Optional, Tuple

import msgspec
import requests
import socketio
import torch
from tqdm import tqdm

from ... import __IPYTHON__, CONFIG, remote_logger
from ...schema.request import RequestModel, StreamValueModel
from ...schema.response import ResponseModel
from ...schema.result import RESULT, ResultModel
from ...tracing.backends import Backend
from ...tracing.graph import Graph
from ...util import NNsightError
from ..contexts.local import LocalContext, RemoteContext


class RemoteBackend(Backend):
    """Backend to execute a context object via a remote service.

    Context object must inherit from RemoteMixin and implement its methods.

    Attributes:
        host (str): Remote host. Defaults to the value set in CONFIG.API.HOST.
    """

    def __init__(
        self,
        model_key: str,
        host: str = None,
        blocking: bool = True,
        job_id: str = None,
        ssl: bool = None,
        api_key: str = "",
    ) -> None:

        self.model_key = model_key

        self.job_id = job_id or CONFIG.API.JOB_ID
        self.ssl = CONFIG.API.SSL if ssl is None else ssl
        self.zlib = CONFIG.API.ZLIB
        self.format = CONFIG.API.FORMAT
        self.api_key = api_key or CONFIG.API.APIKEY
        self.blocking = blocking

        self.host = host or CONFIG.API.HOST
        self.address = f"http{'s' if self.ssl else ''}://{self.host}"
        self.ws_address = f"ws{'s' if self.ssl else ''}://{self.host}"

    def request(self, graph: Graph) -> Tuple[bytes, Dict[str, str]]:
        """Serializes the graph and builds the request headers."""

        data = RequestModel.serialize(graph, self.format, self.zlib)

        headers = {
            "model_key": self.model_key,
            "format": self.format,
            "zlib": str(self.zlib),
            "ndif-api-key": self.api_key,
            "sent-timestamp": str(time.time()),
        }

        return data, headers

    def __call__(self, graph: Graph):

        if self.blocking:
            # Do a blocking request.
            result = self.blocking_request(graph)
        else:
            # Otherwise we are getting the status / result of an existing job.
            result = self.non_blocking_request(graph)

        # Inject the result values back into the local graph.
        if result is not None:
            ResultModel.inject(graph, result)
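
    # A minimal usage sketch (hypothetical model key; in typical nnsight usage
    # the backend is constructed and invoked for you when tracing with
    # `remote=True`, rather than called directly):
    #
    #     backend = RemoteBackend("openai-community/gpt2")
    #     backend(graph)  # serializes `graph`, submits it, and injects the result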

    def handle_response(
        self, response: ResponseModel, graph: Optional[Graph] = None
    ) -> Optional[RESULT]:
        """Handles incoming response data.

        Logs the response object. If the job is completed, retrieves the result
        (streaming it from the remote endpoint if it was too large to inline)
        and uses `torch.load` to decode the `ResultModel` into memory.

        Args:
            response (ResponseModel): Response to handle.
            graph (Optional[Graph]): Local intervention graph, needed to execute
                streamed values and to raise remote errors locally.

        Raises:
            NNsightError: If the job's status is `ResponseModel.JobStatus.NNSIGHT_ERROR`
                and the graph was built with debugging enabled.
            SystemExit: If the job's status is `ResponseModel.JobStatus.NNSIGHT_ERROR`
                and debugging is disabled.

        Returns:
            Optional[RESULT]: The result if the job is completed, otherwise None.
        """

        # Log the response for the user.
        response.log(remote_logger)

        # If the job is completed:
        if response.status == ResponseModel.JobStatus.COMPLETED:

            # If the response has no result data, it was too big and we need to
            # stream it from the server.
            if response.data is None:
                result = self.get_result(response.id)
            else:
                result = response.data

            return result

        # If we're receiving a streamed value:
        elif response.status == ResponseModel.JobStatus.STREAM:

            # First item is the index of the LocalContext node to execute.
            # Second item is its streamed dependencies from the remote service.
            index, dependencies = response.data

            # Inject the streamed values into the local graph.
            ResultModel.inject(graph, dependencies)

            # Execute the local node.
            node = graph.nodes[index]
            node.execute()

        elif response.status == ResponseModel.JobStatus.NNSIGHT_ERROR:

            if graph.debug:
                error_node = graph.nodes[response.data["node_id"]]
                try:
                    raise NNsightError(
                        response.data["err_message"],
                        error_node.index,
                        response.data["traceback"],
                    )
                except NNsightError as nns_err:
                    if __IPYTHON__:
                        # In IPython the traceback content is rendered by the
                        # error itself, so append the error node's traceback to
                        # the error's traceback content.
                        nns_err.traceback_content += "\nDuring handling of the above exception, another exception occurred:\n\n"
                        nns_err.traceback_content += error_node.meta_data["traceback"]
                    else:
                        # Otherwise print the tracebacks manually.
                        print(f"\n{response.data['traceback']}")
                        print(
                            "During handling of the above exception, another exception occurred:\n"
                        )
                        print(f"{error_node.meta_data['traceback']}")

                        sys.tracebacklimit = 0
                    raise nns_err from None
                finally:
                    if __IPYTHON__:
                        sys.tracebacklimit = None
            else:
                print(f"\n{response.data['traceback']}")
                raise SystemExit("Remote exception.")

    def submit_request(
        self, data: bytes, headers: Dict[str, Any]
    ) -> Optional[ResponseModel]:
        """Sends a request to the remote endpoint and handles the response object.

        Raises:
            ConnectionError: If the response has a status code other than 200.

        Returns:
            (ResponseModel): Response.
        """

        from ...schema.response import ResponseModel

        headers["Content-Type"] = "application/octet-stream"

        response = requests.post(
            f"{self.address}/request",
            data=data,
            headers=headers,
        )

        if response.status_code == 200:

            response = ResponseModel(**response.json())

            self.handle_response(response)

            return response

        else:
            msg = response.reason

            raise ConnectionError(msg)
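
    # For reference, the raw HTTP request this performs (path and headers taken
    # from the code above; the body is the serialized graph from `request()`):
    #
    #     POST {self.address}/request
    #     Content-Type: application/octet-stream
    #     model_key: <model key>
    #     format: <CONFIG.API.FORMAT>
    #     zlib: <"True" | "False">
    #     ndif-api-key: <api key>
    #     sent-timestamp: <unix timestamp>
    #
    # `blocking_request` additionally sets a `session_id` header so the server
    # can push updates over the websocket.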

    def get_response(self) -> Optional[RESULT]:
        """Retrieves and handles the response object from the remote endpoint.

        Raises:
            Exception: If the response has a status code other than 200.

        Returns:
            (Optional[RESULT]): Result if the job is completed, otherwise None.
        """

        from ...schema.response import ResponseModel

        response = requests.get(
            f"{self.address}/response/{self.job_id}",
            headers={"ndif-api-key": self.api_key},
        )

        if response.status_code == 200:

            response = ResponseModel(**response.json())

            return self.handle_response(response)

        else:
            raise Exception(response.reason)

    def get_result(self, id: str) -> RESULT:
        """Downloads and decodes the result for the given job id."""

        result_bytes = io.BytesIO()
        result_bytes.seek(0)

        # Get the result from the result url using the job id.
        with requests.get(
            url=f"{self.address}/result/{id}",
            stream=True,
        ) as stream:
            # Total size of the incoming data.
            total_size = float(stream.headers["Content-length"])

            with tqdm(
                total=total_size,
                unit="B",
                unit_scale=True,
                desc="Downloading result",
            ) as progress_bar:
                # chunk_size=None so the server determines the chunk size.
                for data in stream.iter_content(chunk_size=None):
                    progress_bar.update(len(data))
                    result_bytes.write(data)

        # Move the cursor to the beginning of the bytes.
        result_bytes.seek(0)

        # Decode the bytes with pickle and then into a pydantic object.
        result = torch.load(result_bytes, map_location="cpu", weights_only=False)
        result = ResultModel(**result).result

        # Close the bytes.
        result_bytes.close()

        return result
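
    # A hedged sketch of the payload shape implied by the decoding above (an
    # assumption about the server, not a documented contract): the endpoint is
    # presumed to serve bytes produced roughly like
    #
    #     buffer = io.BytesIO()
    #     torch.save(ResultModel(id=job_id, result=value).model_dump(), buffer)
    #
    # so that `torch.load` yields a dict that `ResultModel(**result)` can
    # validate and reconstruct.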

    def blocking_request(self, graph: Graph) -> Optional[RESULT]:
        """Sends an intervention request to the remote service and waits for
        updates via websocket.

        Args:
            graph (Graph): Intervention graph to execute remotely.

        Returns:
            Optional[RESULT]: The job result.
        """

        # We need to do some processing / optimizations on both the graph we're
        # sending remotely and on our local intervention graph, in order to
        # handle the more complex protocols for streaming.

        # Create a socketio connection to the server.
        with socketio.SimpleClient(reconnection_attempts=10) as sio:
            # Connect to the websocket endpoint.
            sio.connect(
                self.ws_address,
                socketio_path="/ws/socket.io",
                transports=["websocket"],
                wait_timeout=10,
            )

            remote_graph = preprocess(graph)

            data, headers = self.request(remote_graph)

            headers["session_id"] = sio.sid

            # Submit the request over HTTP.
            response = self.submit_request(data, headers)

            # Route values produced by LocalContexts up to the remote job.
            LocalContext.set(
                lambda *args: self.stream_send(*args, job_id=response.id, sio=sio)
            )

            try:
                # Loop until the job is completed.
                while True:

                    # Get the pickled bytes value from the websocket.
                    response = sio.receive()[1]
                    # Convert to a pydantic object.
                    response = ResponseModel.unpickle(response)
                    # Handle the response.
                    result = self.handle_response(response, graph=graph)

                    # Break when completed.
                    if result is not None:
                        return result
            except Exception as e:
                raise e
            finally:
                LocalContext.set(None)
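
    # The blocking protocol above, summarized:
    #
    #   1. Open a socket.io websocket; its session id (`sio.sid`) is sent as
    #      the `session_id` header so the server can push updates here.
    #   2. POST the preprocessed, serialized graph via `submit_request`.
    #   3. Receive pickled `ResponseModel`s in a loop: STREAM responses execute
    #      a LocalContext node locally, NNSIGHT_ERROR raises locally, and
    #      COMPLETED yields the result (inlined, or downloaded via `get_result`).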

    def stream_send(
        self, values: Dict[int, Any], job_id: str, sio: socketio.SimpleClient
    ):
        """Uploads values to the remote service for a given job id.

        Args:
            values (Dict[int, Any]): Values to upload, keyed by node index.
            job_id (str): Job id.
            sio (socketio.SimpleClient): Connected websocket client.
        """

        sio.emit(
            "stream_upload",
            data=(StreamValueModel.serialize(values, self.format, self.zlib), job_id),
        )

    def non_blocking_request(self, graph: Graph):
        """Sends an intervention request to the remote service if no job id is
        set; otherwise gets the status / result of the existing job.

        Sets CONFIG.API.JOB_ID on the initial request so the status of that job
        can be requested later. When the job is completed, clears
        CONFIG.API.JOB_ID so a new job can be requested.

        Args:
            graph (Graph): Intervention graph to execute remotely.
        """

        if self.job_id is None:

            data, headers = self.request(graph)

            # Submit the request over HTTP.
            response = self.submit_request(data, headers)

            # Remember the job id so later calls poll for this job's result.
            CONFIG.API.JOB_ID = response.id
            CONFIG.save()

        else:
            try:
                result = self.get_response()

                if result is not None:
                    CONFIG.API.JOB_ID = None
                    CONFIG.save()

                    return result
            except Exception as e:
                CONFIG.API.JOB_ID = None
                CONFIG.save()

                raise e
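
    # A hedged usage sketch of the non-blocking flow (hypothetical model key;
    # `graph` is a previously built intervention graph):
    #
    #     backend = RemoteBackend("openai-community/gpt2", blocking=False)
    #     backend(graph)  # submits the job and saves its id to CONFIG.API.JOB_ID
    #
    #     # Later, a fresh backend picks the job id up from CONFIG.API.JOB_ID:
    #     backend = RemoteBackend("openai-community/gpt2", blocking=False)
    #     backend(graph)  # polls; on completion, injects the result into
    #                     # `graph` and clears CONFIG.API.JOB_ID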

def preprocess(graph: Graph):
    """Copies the graph for remote submission and, for each LocalContext node,
    applies `RemoteContext.from_local` to the copy, storing the result in the
    original node's `uploads` kwarg."""

    new_graph = graph.copy()

    for node in new_graph.nodes:

        if node.target is LocalContext:
            graph.nodes[node.index].kwargs["uploads"] = RemoteContext.from_local(node)

    return new_graph