# Source code for nnsight.intervention.serialization

import io
import pickle
from builtins import open
from types import FrameType
from typing import TYPE_CHECKING, Any, Optional, Union

import cloudpickle

from .envoy import Envoy


class CustomCloudPickler(cloudpickle.Pickler):
    """cloudpickle ``Pickler`` that stores Envoys and frames by reference.

    Rather than pickling an :class:`Envoy` or a Python frame object by value,
    :meth:`persistent_id` swaps it for a short string token:

    * ``"ENVOY:<path>"`` for an Envoy, where ``<path>`` is ``obj.path``;
    * ``"FRAME"`` for any :class:`types.FrameType` instance.

    ``CustomCloudUnpickler`` in this module turns those tokens back into live
    objects. Every other object is pickled normally by cloudpickle.
    """

    def persistent_id(self, obj):
        """Return the reference token for *obj*, or ``None`` to pickle it by value."""
        token = None
        if isinstance(obj, Envoy):
            token = "ENVOY:{}".format(obj.path)
        elif isinstance(obj, FrameType):
            token = "FRAME"
        return token
class CustomCloudUnpickler(pickle.Unpickler):
    """Unpickler that resolves the persistent ids written by ``CustomCloudPickler``.

    Args:
        file: Binary file-like object holding the pickle stream.
        root: Root Envoy. ``"ENVOY:<path>"`` tokens are resolved by calling
            ``root.get`` with the path made relative to the root.
        frame: Object returned for every ``"FRAME"`` token. It may be ``None``,
            in which case serialized frames come back as ``None``.
    """

    # Token prefix written for Envoys, and the leading path component that
    # serialized Envoy paths are rooted at (e.g. "model.<submodule path>").
    _ENVOY_PREFIX = "ENVOY:"
    _ROOT_NAME = "model"

    def __init__(self, file, root: "Envoy", frame: Optional[FrameType]):
        super().__init__(file)
        self.root = root
        self.frame = frame

    def persistent_load(self, pid):
        """Map a persistent id from the stream back to a live object.

        Raises:
            pickle.UnpicklingError: If *pid* is not a token this class knows.
        """
        # Protocol 1+ streams may carry non-string ids. Report them as an
        # unknown id instead of crashing with AttributeError on ``startswith``.
        if not isinstance(pid, str):
            raise pickle.UnpicklingError(f"Unknown persistent id: {pid}")

        if pid.startswith(self._ENVOY_PREFIX):
            path = pid[len(self._ENVOY_PREFIX):]
            # Make the path relative to ``root`` ("model.a.b" -> ".a.b",
            # "model" -> ""). Only strip a whole leading component, never a
            # partial match such as "model_2".
            if path == self._ROOT_NAME or path.startswith(self._ROOT_NAME + "."):
                path = path[len(self._ROOT_NAME):]
            return self.root.get(path)

        if pid == "FRAME":
            return self.frame

        raise pickle.UnpicklingError(f"Unknown persistent id: {pid}")
# Pickle protocol used by ``save`` for both in-memory and on-disk output, so
# the serialized format does not depend on the destination. Protocol 4 is
# readable by every Python 3.4+ interpreter.
_PICKLE_PROTOCOL = 4


def save(obj: Any, path: Optional[str] = None, *, protocol: int = _PICKLE_PROTOCOL):
    """Serialize *obj* with ``CustomCloudPickler``.

    Envoys and frame objects reachable from *obj* are written as reference
    tokens (see ``CustomCloudPickler.persistent_id``) rather than by value.

    Args:
        obj: Object to serialize.
        path: Destination file. If ``None``, the pickle is returned as bytes.
        protocol: Pickle protocol used for either destination (default 4).

    Returns:
        The pickled bytes when *path* is ``None``; otherwise ``None``.
    """
    if path is None:
        buffer = io.BytesIO()
        CustomCloudPickler(buffer, protocol=protocol).dump(obj)
        return buffer.getvalue()

    with open(path, "wb") as file:
        CustomCloudPickler(file, protocol=protocol).dump(obj)


def load(
    data: Union[str, bytes, bytearray, memoryview],
    model: Envoy,
    frame: Optional[FrameType] = None,
):
    """Deserialize data produced by ``save``.

    Args:
        data: Pickled data as any bytes-like object, or a path to a file
            written by ``save``.
        model: Root Envoy used to resolve ``"ENVOY:<path>"`` references.
        frame: Object substituted for serialized frame references.

    Returns:
        The reconstructed object.
    """
    # SECURITY NOTE(review): this unpickles arbitrary data, which can execute
    # code. Only pass trusted input.
    if isinstance(data, (bytes, bytearray, memoryview)):
        return CustomCloudUnpickler(io.BytesIO(data), model, frame).load()

    with open(data, "rb") as file:
        return CustomCloudUnpickler(file, model, frame).load()