Source code for nnsight.intervention.contexts.tracer
import inspect
from functools import wraps
from typing import Any, Callable, Dict, Optional, TypeVar, Union
from ...tracing.contexts import Tracer
from ..graph import (InterventionNodeType, InterventionProxy,
InterventionProxyType)
from . import LocalContext
from ... import CONFIG
[docs]
class InterventionTracer(Tracer[InterventionNodeType, InterventionProxyType]):
    """Extension of base Tracer to add additional intervention functionality and type hinting for intervention proxies.

    Overrides relative to the base ``Tracer``:
      * ``debug`` defaults to ``CONFIG.APP.DEBUG`` when not supplied.
      * ``apply`` is re-annotated to return ``InterventionProxy``.
      * ``local`` creates (or wraps a function in) a ``LocalContext``.
      * ``style`` customizes the visualization of this node type.
    """

    R = TypeVar("R")

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the tracer, filling in ``debug`` from the app config when unset.

        If ``debug`` is absent from ``kwargs`` or explicitly ``None``, it is set to
        ``CONFIG.APP.DEBUG`` before delegating to the base ``Tracer.__init__``.

        Args:
            *args: Positional arguments forwarded to ``Tracer.__init__``.
            **kwargs: Keyword arguments forwarded to ``Tracer.__init__``.
        """
        # ``.get`` avoids a KeyError when the caller omitted ``debug`` entirely;
        # ``is None`` is the correct identity check against the None sentinel.
        if kwargs.get('debug') is None:
            kwargs['debug'] = CONFIG.APP.DEBUG
        super().__init__(*args, **kwargs)

    def apply(
        self, target: Callable[..., R], *args, **kwargs
    ) -> Union[InterventionProxy, R]:
        """Delegate to ``Tracer.apply`` with the same arguments.

        The override only changes the return-type annotation, so type checkers
        see ``InterventionProxy`` (or ``target``'s own return type ``R``).
        """
        return super().apply(target, *args, **kwargs)

    def local(self, fn: Optional[Callable] = None) -> Union[LocalContext, Callable]:
        """Create a ``LocalContext`` on this tracer's graph, or wrap ``fn`` in one.

        Args:
            fn: Optional routine. When omitted, a new ``LocalContext`` whose parent
                is ``self.graph`` is returned (usable as ``with tracer.local(): ...``).
                When given, a wrapper is returned that, on every call, opens a fresh
                ``LocalContext`` and runs ``fn`` through ``context.apply``.

        Returns:
            A ``LocalContext`` when ``fn`` is None, otherwise the wrapping function
            (so ``@tracer.local`` works as a decorator).

        Raises:
            TypeError: If ``fn`` is neither None nor a routine
                (as determined by ``inspect.isroutine``).
        """
        if fn is None:
            return LocalContext(parent=self.graph)

        if not inspect.isroutine(fn):
            raise TypeError(
                f"local() expects a function or method, got {type(fn).__name__!r}"
            )

        @wraps(fn)
        def inner(*args, **kwargs):
            # A new LocalContext per call, so separate invocations do not share state.
            with LocalContext(parent=self.graph) as context:
                return context.apply(fn, *args, **kwargs)

        # Return the wrapper; without this, decorating with ``@tracer.local``
        # would replace the decorated function with ``None``.
        return inner

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this tracer node.

        Starts from the base class's ``style()`` and overrides the node
        appearance (purple polygon with 6 sides, i.e. a hexagon) and sets
        ``arg_kname[1]`` to ``"method"`` (presumably the display name for the
        argument at index 1 — confirm against the base ``style``).

        NOTE(review): this mutates the dict returned by ``super().style()``;
        assumes the base returns a fresh dict on each call — verify.

        Returns:
            Dict[str, Any]: the style dictionary.
        """
        default_style = super().style()
        default_style["node"] = {"color": "purple", "shape": "polygon", "sides": 6}
        default_style["arg_kname"][1] = "method"
        return default_style