# Source code for nnsight.intervention.tracing.invoker

from typing import Callable, TYPE_CHECKING, Any

from ..interleaver import Mediator
from .base import Tracer
from .util import try_catch


if TYPE_CHECKING:
    from .tracer import InterleavingTracer
else:
    InterleavingTracer = Any


class Invoker(Tracer):
    """
    Extends the Tracer class to invoke intervention functions.

    This class captures code blocks and compiles them into intervention
    functions that can be executed by the Interleaver.
    """

    def __init__(self, tracer: InterleavingTracer, *args, **kwargs):
        """
        Initialize an Invoker with a reference to the parent tracer.

        Args:
            tracer: The parent InterleavingTracer instance. ``None`` is
                accepted here (the interleaving check is skipped), but
                ``execute`` dereferences ``self.tracer`` and therefore needs
                a real tracer.
            *args: Additional arguments to pass to the traced function
            **kwargs: Additional keyword arguments to pass to the traced function

        Raises:
            ValueError: If the parent tracer's model is currently interleaving
                (``tracer.model.interleaving`` is truthy).
        """
        # Refuse to create a new invoke while the parent model is already
        # executing / interleaving.
        if tracer is not None and tracer.model.interleaving:
            raise ValueError("Cannot invoke during an active model execution / interleaving.")

        # Set before calling the base initializer so the parent tracer is
        # already available if Tracer.__init__ relies on it.
        self.tracer = tracer

        super().__init__(*args, **kwargs)

    def compile(self):
        """
        Rewrite the captured code block into the source of an intervention function.

        The captured lines in ``self.info.source`` are replaced, in place, by
        a function definition named ``__nnsight_tracer_<id(self)>__`` taking
        ``(__nnsight_mediator__, __nnsight_tracing_info__)``. The generated
        body first calls ``__nnsight_mediator__.pull()``. It then runs the
        original lines wrapped by ``try_catch``:

        - on an exception, it calls ``__nnsight_mediator__.exception(exception)``;
        - otherwise, it calls ``__nnsight_mediator__.end()``.

        Returns:
            None. This method only mutates ``self.info`` (``source`` and
            ``start_line``). Turning the rewritten source into a callable
            happens elsewhere, presumably in the base ``Tracer``; that
            callable is presumably what ``execute`` receives as ``fn``
            (TODO confirm).
        """
        self.info.source = [
            f"def __nnsight_tracer_{id(self)}__(__nnsight_mediator__, __nnsight_tracing_info__):\n",
            # Must sit at the same indentation depth as the lines try_catch
            # emits (the function body). The scraped source showed a single
            # leading space, which looks like collapsed whitespace; 4 spaces
            # assumed here. NOTE(review): confirm against try_catch.
            "    __nnsight_mediator__.pull()\n",
            *try_catch(
                self.info.source,
                exception_source=["__nnsight_mediator__.exception(exception)\n"],
                else_source=["__nnsight_mediator__.end()\n"],
            ),
        ]

        # Two lines (the def header and the pull() call) were prepended above.
        # Shift start_line to compensate, presumably so reported line numbers
        # still map onto the user's original code.
        # NOTE(review): assumes any lines added by try_catch itself are
        # accounted for elsewhere.
        self.info.start_line -= 2

    def execute(self, fn: Callable):
        """
        Register the compiled intervention function with the parent tracer.

        Batches this invoker's ``self.args`` / ``self.kwargs`` through the
        parent tracer's batcher and stores the batched inputs on
        ``self.inputs``. It then wraps ``fn`` in a new Mediator carrying this
        invoker's tracing info and batch group, and appends that mediator to
        ``self.tracer.mediators``. This method does not call ``fn`` directly.

        Args:
            fn: The compiled intervention function

        Note:
            Requires ``self.tracer`` to be non-None. ``self.args`` and
            ``self.kwargs`` are presumably set by ``Tracer.__init__``
            (TODO confirm).
        """
        inputs, batch_group = self.tracer.batcher.batch(self.tracer.model, *self.args, **self.kwargs)

        self.inputs = inputs

        mediator = Mediator(fn, self.info, batch_group=batch_group)

        self.tracer.mediators.append(mediator)