# Source code for nnsight.intervention.protocols.module

from typing import TYPE_CHECKING, Any, Dict

import torch
from typing_extensions import Self

from ... import util
from ...tracing.protocols import Protocol

if TYPE_CHECKING:
    from ..graph import InterventionGraph, InterventionNode

class ApplyModuleProtocol(Protocol):
    """Protocol that applies a sub-module of the root model to some input.

    The target module is resolved by attribute path from
    ``graph.model._model``. By default the module's ``.forward()`` method is
    called directly; using ``.forward()`` instead of ``.__call__()`` means the
    module's hooks are not triggered. Passing ``hook=True`` to :meth:`add`
    makes execution call the module itself instead.
    """

    @classmethod
    def add(
        cls, graph: "InterventionGraph", module_path: str, *args, hook=False, **kwargs
    ) -> Self:
        """Create an ApplyModuleProtocol node and add it to the Graph.

        Args:
            graph (InterventionGraph): Graph to add the Protocol to.
            module_path (str): Module path (``model.module1.module2`` etc.) of
                the module to apply, resolved from the root module.
            *args: Positional inputs passed on to the module.
            hook (bool): If True, execution calls ``module(...)``; otherwise
                it calls ``module.forward(...)``. Defaults to False.
            **kwargs: Keyword inputs passed on to the module.

        Returns:
            Self: The result of ``graph.create`` for the new node.
        """
        from ..graph.node import ValidatingInterventionNode, validate

        # If the Graph is validating, we need to compute the fake value for
        # this node up front.
        if graph.node_class is ValidatingInterventionNode:
            module: torch.nn.Module = util.fetch_attr(
                graph.model._model, module_path
            )

            # `validate` is handed the module's forward plus the inputs and its
            # result is stored as this node's `fake_value`.
            kwargs["fake_value"] = validate(module.forward, *args, **kwargs)

        # Always record the flag: `execute` pops "hook" from the node's kwargs
        # unconditionally, so it must be present whether or not we validated.
        kwargs["hook"] = hook

        # Create and attach Node.
        return graph.create(
            cls,
            module_path,
            *args,
            **kwargs,
        )

    @classmethod
    def execute(cls, node: "InterventionNode") -> None:
        """Execute the ApplyModuleProtocol on ``node`` and set its value.

        Resolves the module named by ``node.args[0]``, prepares the node's
        inputs (passing along the device of the module's first parameter, if
        any), runs the module and stores the output on the node.

        Args:
            node (InterventionNode): ApplyModule Node.
        """
        graph: "InterventionGraph" = node.graph

        module: torch.nn.Module = util.fetch_attr(
            graph.model._model, node.args[0]
        )

        # Best-effort device lookup: a module without parameters raises
        # StopIteration, and any other ordinary failure likewise just means
        # "no device". `Exception` (rather than a bare `except`) keeps that
        # fallback without swallowing KeyboardInterrupt / SystemExit.
        try:
            device = next(module.parameters()).device
        except Exception:
            device = None

        args, kwargs = node.prepare_inputs((node.args, node.kwargs), device=device)

        # The first positional argument is the module path, which was already
        # used above to resolve `module`; the rest are the module's inputs.
        _module_path, *args = args

        # "hook" is bookkeeping added by `add`, not an input to the module.
        hook = kwargs.pop("hook")

        if hook:
            # Go through __call__ so the module's hooks run.
            output = module(*args, **kwargs)
        else:
            # Call forward directly, bypassing hooks.
            output = module.forward(*args, **kwargs)

        node.set_value(output)

    @classmethod
    def style(cls) -> Dict[str, Any]:
        """Visualization style for this protocol node.

        Returns:
            Dict[str, Any]: The parent class's style dictionary with the
            ``"node"`` entry replaced by a green hexagon style.
        """
        default_style = super().style()

        default_style["node"] = {"color": "green4", "shape": "polygon", "sides": 6}

        return default_style