# Source code for nnsight.patching

"""The patching module handles patching of classes and functions in modules."""
from __future__ import annotations

import importlib
import types
from contextlib import AbstractContextManager
from typing import Any, List, Optional

from . import util


class Patch:
    """A single, reversible replacement of one attribute on a module or class.

    The attribute's current value is read when the Patch is constructed, so
    it can be put back later with :meth:`restore`.

    Attributes:
        parent (Any): Module or class that owns the attribute being replaced.
        replacement (Any): Object installed in place of the original.
        key (str): Name of the attribute on ``parent``.
        orig (Any): Value ``parent`` held under ``key`` at construction time.
    """

    def __init__(self, parent: Any, replacement: Any, key: str) -> None:
        # Snapshot the current value up front; `getattr` raises
        # AttributeError if `parent` has no attribute named `key`.
        original = getattr(parent, key)

        self.parent, self.replacement, self.key = parent, replacement, key
        self.orig = original

    def patch(self) -> None:
        """Installs the replacement object on the parent module/class."""
        owner, name = self.parent, self.key
        setattr(owner, name, self.replacement)

    def restore(self) -> None:
        """Puts the value captured at construction back on the parent module/class."""
        owner, name = self.parent, self.key
        setattr(owner, name, self.orig)
class Patcher(AbstractContextManager):
    """Context manager that applies a list of Patches on ``__enter__`` and restores them on ``__exit__``.

    Attributes:
        patches (List[Patch]): Patches managed by this Patcher, kept in the
            order they were registered.
    """

    def __init__(self, patches: Optional[List[Patch]] = None) -> None:
        # A falsy argument (``None`` or an empty list) yields a fresh list.
        self.patches = patches or []

    def add(self, patch: Patch) -> None:
        """Adds a Patch to the patches. Also calls `.patch()` on the Patch.

        The patch is applied immediately, whether or not the Patcher's
        context has been entered.

        Args:
            patch (Patch): Patch to add.
        """
        self.patches.append(patch)
        patch.patch()

    def __enter__(self) -> Patcher:
        """Enters the patching context. Calls `.patch()` on all patches.

        Returns:
            Patcher: This Patcher.
        """
        for patch in self.patches:
            patch.patch()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Calls `.restore()` on all patches, most recently registered first.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        # Undo in LIFO order. A Patch snapshots its original value when it is
        # constructed, so a Patch created after an earlier one already
        # replaced the same attribute holds that earlier replacement as its
        # "original". Restoring oldest-first would leave the attribute
        # patched; restoring newest-first unwinds back to the true original.
        for patch in reversed(self.patches):
            patch.restore()