"""Source code for nnsight.intervention.contexts.interleaving."""

import weakref
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Tuple,
)


from ...tracing.backends import Backend
from ...tracing.graph import GraphType
from ..graph import (
    InterventionGraph,
    InterventionNode,
    InterventionNodeType,
    ValidatingInterventionNode,
)
from ..interleaver import Interleaver

from . import Invoker
from . import InterventionTracer
if TYPE_CHECKING:
    from .. import NNsight


class InterleavingTracer(InterventionTracer):
    """Tracer that interleaves an `InterventionGraph` with a PyTorch model when executed.

    Attributes:
        _model (NNsight): The NNsight model being traced.
        invoker (Invoker): The currently open invoker, if any. It is tracked so that
            a second invoker cannot be opened while one is still active.
        args (List[Any]): Element 0 is the name of the model method to interleave
            with (may be None); every following element is an invoker input.
        kwargs (Dict[str, Any]): "Global" keyword arguments forwarded to the method
            being interleaved. Per-invoker kwargs are used separately, to
            preprocess that invoker's input.
    """

    def __init__(
        self,
        model: "NNsight",
        method: Optional[str] = None,
        backend: Optional[Backend] = None,
        parent: Optional[GraphType] = None,
        validate: bool = False,
        debug: Optional[bool] = None,
        **kwargs,
    ) -> None:
        # Validating nodes are only used when validation was requested.
        chosen_node_class = ValidatingInterventionNode if validate else InterventionNode

        super().__init__(
            graph_class=InterventionGraph,
            model=model,
            node_class=chosen_node_class,
            proxy_class=model.proxy_class,
            backend=backend,
            parent=parent,
            graph=model._default_graph,
            debug=debug,
        )

        self._model = model

        # Register this tracer with the model's Envoys so they can add
        # InterventionProtocol nodes to it. A weak proxy is handed over rather
        # than a strong reference to the tracer itself.
        self._model._envoy._set_tracer(weakref.proxy(self))

        self.invoker: Optional[Invoker] = None

        self.args = [method]
        self.kwargs = kwargs

    def invoke(self, *inputs: Any, **kwargs) -> Invoker:
        """Open an Invoker context for the given input.

        Args:
            *inputs: Input(s) handed to the new Invoker.
            **kwargs: Keyword arguments handed to the new Invoker.

        Returns:
            Invoker: The newly created invoker context.

        Raises:
            Exception: If another Invoker context is still open.
        """
        if self.invoker is None:
            return Invoker(self, *inputs, **kwargs)

        raise Exception("Can't create an invoker context with one already open!")

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close any still-open invoker, reset the Envoys, then defer to the parent tracer."""
        open_invoker = self.invoker

        # An invoker left open is closed as if it exited cleanly.
        if open_invoker is not None:
            open_invoker.__exit__(None, None, None)

        self._model._envoy._reset()

        super().__exit__(exc_type, exc_val, exc_tb)

    @classmethod
    def _batch(
        cls, model: "NNsight", invoker_inputs: Tuple[Tuple[Tuple[Any], Dict[str, Any]]]
    ) -> Tuple[Tuple[Tuple[Any], Dict[str, Any]], List[Tuple[int, int]]]:
        """Combine every invoker's input into one batched input.

        Each ``(args, kwargs)`` pair is first passed through
        ``model._prepare_input``. The prepared input is then folded into the
        running batch with ``model._batch``.

        Args:
            model (NNsight): Model that defines how to prepare and batch its inputs.
            invoker_inputs (Tuple[Tuple[Tuple[Any], Dict[str, Any]]]): One
                ``(args, kwargs)`` pair per invoker.

        Returns:
            Tuple[Tuple[Tuple[Any], Dict[str, Any]], List[Tuple[int, int]]]:
                The single batched ``(args, kwargs)`` input, together with one
                ``(start_index, batch_size)`` pair per invoker.
        """
        groups = []
        offset = 0
        combined = None

        for inv_args, inv_kwargs in invoker_inputs:
            (inv_args, inv_kwargs), size = model._prepare_input(*inv_args, **inv_kwargs)
            groups.append((offset, size))
            combined = model._batch(combined, *inv_args, **inv_kwargs)
            offset += size

        # Nothing was batched (e.g. no invoker inputs were given), so substitute
        # a fixed placeholder input.
        # NOTE(review): the meaning of (0, -1) is defined by model.interleave; confirm there.
        return (combined if combined is not None else (((0, -1),), dict())), groups

    @property
    def _invoker_group(self):
        """Zero-based index of the most recently added invoker input."""
        # self.args[0] holds the method name, so the input count is len - 1.
        input_count = len(self.args) - 1
        return input_count - 1

    @classmethod
    def execute(cls, node: InterventionNodeType):
        """Run the intervention graph interleaved with the model for ``node``.

        ``node.args`` is expected to be ``(graph, method, *invoker_inputs)``.
        ``node.kwargs`` holds the global keyword arguments for ``method``.
        """
        graph, method, *raw_inputs = node.args

        model = graph.model

        # Inputs may contain Nodes; resolve them to concrete values first.
        resolved_inputs, global_kwargs = node.prepare_inputs((raw_inputs, node.kwargs))

        # Merge all invoker inputs into a single batch.
        batched, batch_groups = cls._batch(model, resolved_inputs)
        batch_args, batch_kwargs = batched

        # Prepare the intervention graph before running the model.
        graph.compile()
        graph.reset()
        graph.execute()

        interleaver = Interleaver(graph, batch_groups=batch_groups)

        graph.model.interleave(
            interleaver,
            *batch_args,
            fn=method,
            **global_kwargs,
            **batch_kwargs,
        )

        graph.cleanup()