# Source code for nnsight.intervention.tracing.iterator

from typing import Callable, TYPE_CHECKING, Any, Union
from .base import Tracer
from ..interleaver import Interleaver, Mediator
from .util import try_catch

class IteratorProxy:
    """Subscriptable handle that hands out an ``IteratorTracer`` per index.

    ``proxy[i]`` or ``proxy[a:b]`` returns a fresh ``IteratorTracer`` for
    that int/slice, bound to the same interleaver this proxy was built with.
    """

    def __init__(self, interleaver: Interleaver):
        # Shared by every tracer that __getitem__ creates.
        self.interleaver = interleaver

    def __getitem__(self, iteration: Union[int, slice]):
        """Return a new ``IteratorTracer`` for ``iteration`` on this proxy's interleaver."""
        tracer = IteratorTracer(iteration=iteration, interleaver=self.interleaver)
        return tracer
    
class IteratorTracer(Tracer):
    """Tracer for a captured ``with`` block bound to particular iteration(s).

    Instances are created by ``IteratorProxy.__getitem__`` with an ``int`` or
    a ``slice``. ``compile`` turns the captured block into a function that
    receives the iteration value as a parameter. ``execute`` wraps that
    function in a ``Mediator`` and passes it, together with
    ``self.iteration``, to the interleaver's current mediator via ``iter``.
    """

    def __init__(self, iteration: Union[int, slice], interleaver: Interleaver):
        """
        Args:
            iteration: The iteration index or slice the captured block targets.
            interleaver: Interleaver whose ``current`` mediator the compiled
                block is attached to in ``execute``.
        """
        super().__init__()

        self.interleaver = interleaver
        self.iteration = iteration

    def compile(self):
        """
        Rewrite the captured source, in place, into a function definition.

        The generated function takes ``__nnsight_mediator__``,
        ``__nnsight_tracing_info__`` and a third parameter named after the
        first ``with`` item's ``as`` target, or ``__nnsight_iteration__``
        when there is no target. The body calls
        ``__nnsight_mediator__.pull()``, then the captured code wrapped by
        ``try_catch``. The wrapper uses
        ``__nnsight_mediator__.exception(exception)`` as the exception-handler
        source and ``__nnsight_mediator__.end()`` as the else-branch source.

        Returns:
            None. The rewritten lines are stored in ``self.info.source`` and
            ``self.info.start_line`` is decremented by 2.
        """
        # NOTE(review): ``self.info.node`` is presumably an ``ast.With``. Only a
        # plain-name ``as`` target (e.g. ``as i``) has ``.id``; other target
        # shapes (tuples, attributes) would raise AttributeError here.
        target = self.info.node.items[0].optional_vars
        iteration_var_name = (
            target.id if target is not None else "__nnsight_iteration__"
        )

        # Wrap the captured code in a function definition with appropriate parameters
        self.info.source = [
            f"def __nnsight_tracer_{id(self)}__(__nnsight_mediator__, __nnsight_tracing_info__, {iteration_var_name}):\n",
            # NOTE(review): this indent must match the body indentation that
            # try_catch emits (four spaces assumed) or the generated def will
            # not parse.
            "    __nnsight_mediator__.pull()\n",
            *try_catch(
                self.info.source,
                exception_source=["__nnsight_mediator__.exception(exception)\n"],
                else_source=["__nnsight_mediator__.end()\n"],
            ),
        ]

        # Presumably offsets the recorded start line for the lines prepended
        # ahead of the user's code, so positions map back to the original
        # source. TODO: confirm against try_catch's output.
        self.info.start_line -= 2

    def execute(self, fn: Callable):
        """
        Attach ``fn`` to the interleaver's current mediator for ``self.iteration``.

        A new ``Mediator`` is built for ``fn`` and ``self.info``. It reuses the
        current mediator's ``batch_group`` and passes its ``all_stop`` as
        ``stop``. Its name is prefixed with ``"Iterator"``, and it is handed,
        together with ``self.iteration``, to
        ``self.interleaver.current.iter``.

        Args:
            fn: Presumably the function obtained from the source produced by
                ``compile``. TODO: confirm with the base ``Tracer``.
        """
        mediator = Mediator(
            fn,
            self.info,
            batch_group=self.interleaver.current.batch_group,
            stop=self.interleaver.current.all_stop,
        )
        mediator.name = "Iterator" + mediator.name

        self.interleaver.current.iter(mediator, self.iteration)