Source code for nnsight.util

"""Module for utility functions and classes used throughout the package."""

import importlib
import types
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Collection,
    Dict,
    Optional,
    Tuple,
    Type,
    Union,
)

import torch

if TYPE_CHECKING:
    from .tracing.Node import Node

# TODO Have an Exception you can raise to stop apply early

def apply(
    data: Any, fn: Callable, cls: Type, inplace: bool = False
) -> Collection:
    """Recursively applies a function to every member of a given type inside a collection.

    Only plain ``list``, ``tuple``, ``dict`` and ``slice`` containers are
    traversed (exact type match, subclasses are not). Anything else that is
    not an instance of ``cls`` is returned untouched.

    Args:
        data (Any): Value or (possibly nested) collection to walk.
        fn (Callable): Function called on each member that is an instance of ``cls``.
        cls (type): Type, or tuple of types, selecting which members ``fn`` is applied to.
        inplace (bool): When True, lists and dicts are updated in place and the
            same objects are returned. Tuples and slices are always rebuilt.

    Returns:
        Any: ``fn(data)`` if ``data`` itself is an instance of ``cls``; otherwise
        the same kind of collection as ``data`` with ``fn`` applied to the
        matching members.
    """
    # A matching value is transformed directly and never descended into.
    if isinstance(data, cls):
        return fn(data)

    def recurse(item: Any) -> Any:
        return apply(item, fn, cls, inplace=inplace)

    kind = type(data)

    if kind == list:
        if not inplace:
            return [recurse(item) for item in data]
        for position, item in enumerate(data):
            data[position] = recurse(item)
        return data

    if kind == tuple:
        converted = [recurse(item) for item in data]
        return tuple(converted)

    if kind == dict:
        if not inplace:
            return {key: recurse(value) for key, value in data.items()}
        # Only existing keys are reassigned, so the dict does not change size
        # while it is being iterated.
        for key, value in data.items():
            data[key] = recurse(value)
        return data

    if kind == slice:
        bounds = [recurse(part) for part in (data.start, data.stop, data.step)]
        return slice(*bounds)

    return data
def fetch_attr(object: object, target: str) -> Any:
    """Walks a '.'-separated attribute path starting from a root object.

    Example path: ``transformer.h.1``. Empty path segments (including an
    entirely empty ``target``) are skipped, so an empty path returns the
    root object itself.

    Args:
        object (object): Root object the lookup starts from.
        target (str): Attribute path, with levels separated by '.'.

    Returns:
        Any: The attribute found at the end of the path.
    """
    current = object
    for name in target.split("."):
        # Skip blanks produced by an empty path or by leading, trailing
        # or doubled dots.
        if name:
            current = getattr(current, name)
    return current
def wrap(object: object, wrapper: Type, *args, **kwargs) -> object:
    """Turns an existing object into an instance of a wrapper type, in place.

    A new class is created that carries the name of the object's current
    class and inherits from both that class and ``wrapper``. The object's
    ``__class__`` is reassigned to it and ``wrapper.__init__`` is then called
    on the object. An object that is already an instance of ``wrapper`` is
    returned unchanged.

    Args:
        object (object): Object to wrap.
        wrapper (Type): Type to wrap the object in.
        *args: Positional arguments forwarded to ``wrapper.__init__``.
        **kwargs: Keyword arguments forwarded to ``wrapper.__init__``.

    Returns:
        object: The same object, now also an instance of ``wrapper``.
    """
    if not isinstance(object, wrapper):
        original_cls = object.__class__
        # Original class comes first in the bases so its attributes keep
        # precedence over the wrapper's in the MRO.
        object.__class__ = types.new_class(
            original_cls.__name__, (original_cls, wrapper)
        )
        wrapper.__init__(object, *args, **kwargs)

    return object
def to_import_path(type: type) -> str:
    """Builds the dotted ``<module>.<name>`` path of a type.

    Args:
        type (type): Type to describe.

    Returns:
        str: The type's ``__module__`` and ``__name__`` joined by '.'.
    """
    return ".".join((type.__module__, type.__name__))


def from_import_path(import_path: str) -> type:
    """Resolves a dotted ``<module>.<name>`` path back to the named attribute.

    Everything before the last '.' is imported as a module and the final
    component is looked up on that module.

    Args:
        import_path (str): Dotted path, e.g. as produced by ``to_import_path``.

    Returns:
        type: The attribute named by the last path component.
    """
    module_name, _, classname = import_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, classname)
class WrapperModule(torch.nn.Module):
    """Simple torch module which passes its input through unchanged.

    Useful for hooking. When called with exactly one positional argument,
    that argument is returned directly rather than inside a tuple.
    """

    def forward(self, *args, **kwargs):
        """Returns the positional inputs as they were given.

        Args:
            *args: Positional inputs to pass through.
            **kwargs: Accepted but not used and not returned.

        Returns:
            The sole positional argument when exactly one was given,
            otherwise the tuple of all positional arguments.
        """
        return args[0] if len(args) == 1 else args