Source code for nnsight.contexts.Invoker
from __future__ import annotations
import copy
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
import torch
from torch._subclasses.fake_tensor import FakeCopyMode, FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from .. import util
from ..patching import Patch, Patcher
from ..tracing.Node import Node
from ..tracing.Proxy import Proxy
from . import check_for_dependencies
from .GraphBasedContext import GlobalTracingContext
if TYPE_CHECKING:
from .Tracer import Tracer
[docs]
class Invoker(AbstractContextManager):
    """An Invoker is meant to work in tandem with a :class:`nnsight.contexts.Tracer.Tracer` to enter input and manage intervention tracing.

    Attributes:
        tracer (nnsight.contexts.Tracer.Tracer): Tracer object to enter input and manage context.
        inputs (tuple[Any]): Initially entered inputs, then post-processed inputs from model's ._prepare_inputs(...) method.
        scan (bool): If to execute the model using `FakeTensor` in order to update the potential sizes/dtypes of all modules' Envoys' inputs/outputs as well as validate things work correctly.
            Scanning is not free computation wise so you may want to turn this to false when running in a loop.
            When making interventions, you may get shape errors if scan is false as it validates operations based on shapes so
            for looped calls where shapes are consistent, you may want to have scan=True for the first loop. Defaults to False.
        kwargs (Dict[str,Any]): Keyword arguments passed to the model's _prepare_inputs method.
        scanning (bool): If currently scanning.
    """

    def __init__(
        self,
        tracer: "Tracer",
        *inputs: Any,
        scan: bool = False,
        **kwargs,
    ) -> None:
        self.tracer = tracer
        self.inputs = inputs
        self.scan = scan
        self.kwargs = kwargs
        self.scanning = False

        # Register this invoker on the tracer; cleared again in __exit__.
        self.tracer.invoker = self

    def __enter__(self) -> "Invoker":
        """Enters a new invocation context with the given input.

        Calls the model's ._prepare_inputs method on the entered inputs (unless they
        contain Proxies from an active session, in which case the raw inputs are kept).
        If scan is True, runs the model's ._execute method under a FakeTensorMode to
        update and validate the module Envoys' inputs/outputs; otherwise the Envoys are
        reset. Finally, appends the (possibly prepared) inputs to the Tracer's
        ``_invoker_inputs``.

        Returns:
            Invoker: Invoker.
        """
        has_proxies_in_inputs = False

        # If were accumulating, we might have Proxies in the input.
        # Therefore we first: Check to see if there are any Proxies.
        # If there are, preserve the raw inputs with Proxies converted to a Locked Bridge protocol.
        # Set self.inputs to be the proxy_value so we can prepare_inputs, get the batch size, and scan.
        if self.tracer.model._session is not None:
            self.inputs, has_proxies_in_inputs = check_for_dependencies(self.inputs)

        with GlobalTracingContext.exit_global_tracing_context():
            if not has_proxies_in_inputs:
                # The batch size returned by _prepare_inputs is not used here.
                self.inputs, _ = self.tracer.model._prepare_inputs(
                    *self.inputs, **self.kwargs
                )

            if self.scan:
                inputs = self.inputs

                if has_proxies_in_inputs:
                    # self.inputs were left unprepared above; substitute each Node's
                    # proxy_value and prepare that copy for the scan only.
                    inputs = util.apply(inputs, lambda x: x.proxy_value, Node)
                    inputs, _ = self.tracer.model._prepare_inputs(
                        *inputs, **self.kwargs
                    )

                self._scan(inputs)
            else:
                self.tracer.model._envoy._reset()

            self.tracer._invoker_inputs.append(self.inputs)

        return self

    def _scan(self, inputs: Tuple[Any, ...]) -> None:
        """Execute the model on ``inputs`` under FakeTensorMode to populate Envoy inputs/outputs.

        ``self.scanning`` is True only while the fake execution runs and is reset in a
        ``finally`` so a failure during scanning cannot leave the invoker marked as scanning.
        """
        self.tracer.model._envoy._clear()

        self.scanning = True
        try:
            with Patcher() as patcher:
                # Some logic (like gpt-j rotary embeddings) gets "poisoned" by FakeTensors.
                # This does not happen when `torch._jit_internal.is_scripting() returns True.`
                patcher.add(Patch(torch._jit_internal, lambda: True, "is_scripting"))

                with FakeTensorMode(
                    allow_non_fake_inputs=True,
                    shape_env=ShapeEnv(assume_static_by_default=True),
                ) as fake_mode:
                    with FakeCopyMode(fake_mode):
                        # Deep copies so the scan cannot mutate the real inputs/kwargs.
                        self.tracer.model._execute(
                            *copy.deepcopy(inputs),
                            **copy.deepcopy(self.tracer._kwargs),
                        )
        finally:
            self.scanning = False

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Unregister this invoker from its tracer.

        Returns None (falsy), so any exception raised inside the ``with`` body
        propagates unchanged to the caller.
        """
        self.tracer.invoker = None