Source code for nnsight.tracing.contexts.tracer
from typing import Callable, TypeVar, Union
from typing_extensions import Self
from ..graph import ProxyType, SubGraph, NodeType, Proxy
from ..protocols import StopProtocol
from . import Condition, Context, Iterator
[docs]
class Tracer(Context[SubGraph[NodeType, ProxyType]]):
def __enter__(self) -> Self:
from .globals import GlobalTracingContext
GlobalTracingContext.try_register(self)
return super().__enter__()
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
from .globals import GlobalTracingContext
GlobalTracingContext.try_deregister(self)
return super().__exit__(exc_type, exc_val, exc_tb)
def iter(self, collection):
return Iterator(collection, parent=self.graph)
def cond(self, condition):
return Condition(condition, parent=self.graph)
def stop(self):
StopProtocol.add(self.graph)
def log(self, *args):
self.apply(print, *args)
R = TypeVar('R')
def apply(self, target: Callable[..., R], *args, **kwargs) -> Union[Proxy, R]:
return self.graph.create(
target,
*args,
**kwargs,
)