# Source code for nnsight.modeling.mixins.meta

from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import torch
from accelerate import init_empty_weights

from ...intervention.tracing.tracer import ScanningTracer
from .. import NNsight
from .loadable import LoadableMixin

if TYPE_CHECKING:
    from ...intervention.interleaver import Interleaver
else:
    Interleaver = Any


class MetaMixin(LoadableMixin):
    """Mixin that builds the wrapped model with empty weights and loads it lazily.

    Unless a ``torch.nn.Module`` instance is passed as the first positional
    argument or ``dispatch=True`` is given, construction calls
    :meth:`_load_meta` inside accelerate's ``init_empty_weights`` context, so
    the wrapped module is created on the meta device without materializing
    its parameters. The real model is loaded later by :meth:`dispatch`. That
    call happens either explicitly or automatically from :meth:`interleave`
    when the active tracer is not a ``ScanningTracer``.
    """

    def __init__(
        self,
        *args,
        dispatch: bool = False,
        meta_buffers: bool = True,
        rename: Optional[Dict[str, str]] = None,
        **kwargs,
    ) -> None:
        """Wrap an existing module, initialize eagerly, or build an empty model.

        Args:
            *args: Either a ``torch.nn.Module`` to wrap directly, or the
                positional arguments forwarded to the loader (``_load_meta``
                now, ``_load`` when :meth:`dispatch` runs).
            dispatch: If True, skip the empty-weights path and call the parent
                initializer directly (presumably a full load — confirm
                against ``LoadableMixin``).
            meta_buffers: Passed to ``init_empty_weights(include_buffers=...)``.
                Ignored when a module is passed or ``dispatch`` is True.
            rename: Optional name mapping forwarded to the ``NNsight`` /
                parent initializer.
            **kwargs: Forwarded to the loader.
        """
        self.dispatched = False

        # Guard keyword-only construction: indexing ``args[0]`` without this
        # check raises IndexError when no positional arguments are given.
        wraps_module = bool(args) and isinstance(args[0], torch.nn.Module)

        if wraps_module or dispatch:
            self.dispatched = True
            super().__init__(*args, rename=rename, **kwargs)
        else:
            with init_empty_weights(include_buffers=meta_buffers):
                model = self._load_meta(*args, **kwargs)
            # Call NNsight.__init__ directly rather than super().__init__,
            # skipping LoadableMixin's initializer (which presumably loads
            # real weights) since we only want the empty skeleton here.
            NNsight.__init__(self, model, rename=rename)

        # Remember the loader arguments so dispatch() can reload real weights.
        self.args = args
        self.kwargs = kwargs

    def _load_meta(self, *args, **kwargs) -> torch.nn.Module:
        """Build the model skeleton; subclasses must override.

        Called from ``__init__`` inside ``init_empty_weights`` with the
        constructor's positional and keyword arguments (excluding
        ``dispatch``, ``meta_buffers`` and ``rename``).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    def dispatch(self) -> None:
        """Load the real model using the arguments captured in ``__init__``.

        Calls ``self._load`` with the stored ``args``/``kwargs`` and passes the
        result to ``self._update`` (presumably replacing the empty module).
        It then marks the instance as dispatched.
        """
        model = self._load(*self.args, **self.kwargs)
        self._update(model)

        # TODO legacy: keep the old ``_model`` attribute pointing at the
        # current module. Written through __dict__, presumably to bypass a
        # custom __setattr__ on the base class — confirm.
        self.__dict__["_model"] = self._module

        self.dispatched = True

    def interleave(self, fn: Callable, *args, **kwargs):
        """Dispatch if needed, re-target ``fn``, then delegate to the parent.

        If the model has not been dispatched and the current tracer is not a
        ``ScanningTracer``, the real model is loaded first. ``fn`` is then
        re-targeted, presumably so it does not run against a module that
        dispatch replaced:

        * a ``torch.nn.Module`` is replaced by ``self._module``;
        * a bound method whose ``__self__`` is a ``torch.nn.Module`` is
          rebound to ``self._module``;
        * any other bound method is rebound to ``self``.
        """
        if not self.dispatched and not isinstance(
            self._interleaver.tracer, ScanningTracer
        ):
            self.dispatch()

        if isinstance(fn, torch.nn.Module):
            # NOTE(review): assumes fn is the root wrapped module, not a
            # submodule — a submodule would also be swapped for the root.
            fn = self._module
        elif isinstance(fn, MethodType) and fn.__self__ is not None:
            # Unbind via __func__, then bind to the current target via the
            # function's descriptor __get__.
            new_self = (
                self._module if isinstance(fn.__self__, torch.nn.Module) else self
            )
            fn = fn.__func__.__get__(new_self, type(new_self))

        return super().interleave(fn, *args, **kwargs)