Source code for nnsight.modeling.mixins.loadable
from typing import Dict, Optional
import torch
from ..base import NNsight
[docs]
class LoadableMixin(NNsight):
    """Mixin that builds an NNsight wrapper from a module *or* from load arguments.

    If the first positional argument is already a ``torch.nn.Module`` it is
    wrapped directly. Otherwise all positional and keyword arguments are
    forwarded to :meth:`_load`, which subclasses implement to construct the
    underlying module.
    """

    def __init__(
        self,
        *args,
        rename: Optional[Dict[str, str]] = None,
        **kwargs,
    ) -> None:
        """Resolve the underlying module and initialize the ``NNsight`` base.

        Args:
            *args: Either a ``torch.nn.Module`` as the first element, or the
                positional arguments to pass to :meth:`_load`.
            rename: Optional mapping forwarded unchanged to the base-class
                ``__init__``.
            **kwargs: Keyword arguments forwarded to :meth:`_load` when no
                module is given directly. They are not used when ``args[0]``
                is a ``torch.nn.Module``.
        """
        if args and isinstance(args[0], torch.nn.Module):
            # An already-instantiated module: wrap it as-is. Any remaining
            # positional/keyword arguments are not used on this path.
            model = args[0]
        else:
            # No module was given (possibly no positional args at all, e.g.
            # when everything is passed by keyword): defer construction to the
            # subclass-specific loader instead of failing on ``args[0]``.
            model = self._load(*args, **kwargs)

        super().__init__(model, rename=rename)

    def _load(self, *args, **kwargs) -> torch.nn.Module:
        """Construct and return the underlying ``torch.nn.Module``.

        Subclasses must override this. It receives the same ``*args`` and
        ``**kwargs`` that were given to ``__init__`` (``rename`` is consumed
        by ``__init__`` and is not forwarded).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _load() to build a model "
            "from non-module arguments."
        )