Source code for nnsight.intervention.contexts.session

from typing import TYPE_CHECKING, Optional

from typing_extensions import Self

from ..graph import (InterventionNode, InterventionProxy,
                     ValidatingInterventionNode)
from . import InterventionTracer

if TYPE_CHECKING:
    from .. import NNsight


[docs] class Session(InterventionTracer[InterventionNode, InterventionProxy]): """A Session simply allows grouping multiple Tracers in one computation graph. """ def __init__(self, model: "NNsight", validate: bool = False, debug:Optional[bool] = None, **kwargs) -> None: super().__init__( node_class=ValidatingInterventionNode if validate else InterventionNode, proxy_class=model.proxy_class, debug=debug, **kwargs, ) self.model = model def __enter__(self) -> Self: self.model._session = self return super().__enter__() def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.model._session = None return super().__exit__(exc_type, exc_val, exc_tb)