Source code for nnsight.intervention.graph.graph

import copy
import sys
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union

from typing_extensions import Self

from ...tracing.contexts import Context
from ...tracing.graph import SubGraph
from ...util import NNsightError
from ..protocols import ApplyModuleProtocol, GradProtocol, InterventionProtocol
from . import InterventionNode, InterventionNodeType, InterventionProxyType

if TYPE_CHECKING:
    from .. import NNsight
    from ...tracing.graph.graph import GraphType, NodeType


class InterventionGraph(SubGraph[InterventionNode, InterventionProxyType]):
    """The `InterventionGraph` is the specialized `SubGraph` type that handles the
    complex intervention operations a user wants to make during interleaving.

    We need to `.compile()` it before execution to determine how to execute
    interventions appropriately.

    Attributes:
        model (NNsight): NNsight model.
        interventions (Dict[str, List[InterventionNode]]): Mapping of module path to the
            `InterventionProtocol` nodes that target it.
        grad_subgraph (Set[int]): Indices of nodes that belong to gradient subgraphs.
        compiled (bool): Whether `.compile()` has already been run.
        call_counter (Dict[int, int]): Number of times each intervention node has been
            asked to execute.
        deferred (Dict[int, List[int]]): Mapping of section start index to the nodes whose
            dependency updates were deferred.
    """

    def __init__(
        self,
        *args,
        model: Optional["NNsight"] = None,
        **kwargs,
    ) -> None:

        super().__init__(*args, **kwargs)

        self.model = model

        self.interventions: Dict[str, List[InterventionNode]] = defaultdict(list)
        self.grad_subgraph: Set[int] = set()

        self.compiled = False
        self.call_counter: Dict[int, int] = defaultdict(int)
        self.deferred: Dict[int, List[int]] = defaultdict(list)

    def __getstate__(self) -> Dict:

        return {
            "subset": self.subset,
            "nodes": self.nodes,
            "interventions": self.interventions,
            "compiled": self.compiled,
            "call_counter": self.call_counter,
            "deferred": self.deferred,
            "grad_subgraph": self.grad_subgraph,
            "defer_stack": self.defer_stack,
        }

    def __setstate__(self, state: Dict) -> None:

        self.__dict__.update(state)
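    # Illustrative sketch (not part of the library): a rough outline of how this graph is
    # assumed to be used. `parent_graph`, `subset` and `model` are hypothetical placeholders.
    #
    #     graph = InterventionGraph(parent_graph, subset=subset, model=model)
    #     interventions = graph.compile()  # group InterventionProtocol nodes by module path
    #     for module_path, nodes in interventions.items():
    #         ...  # during interleaving, the section owning each node is executed when the
    #              # module at `module_path` runs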
    def reset(self) -> None:

        self.call_counter = defaultdict(int)

        return super().reset()
    def set(self, model: "NNsight"):

        self.model = model

    def context_dependency(
        self,
        context_node: InterventionNode,
        intervention_subgraphs: List[SubGraph],
    ) -> None:

        context_graph: SubGraph = context_node.args[0]

        start = context_graph.subset[0]
        end = context_graph.subset[-1]

        for intervention_subgraph in intervention_subgraphs:

            # continue if the subgraph does not overlap with the context's graph
            if intervention_subgraph.subset[-1] < start or end < intervention_subgraph.subset[0]:
                continue

            for intervention_index in intervention_subgraph.subset:

                # if there's an overlapping node, make the context depend on the intervention node in the subgraph
                if start <= intervention_index and intervention_index <= end:

                    # the first node in the subgraph is an InterventionProtocol node
                    intervention_node = intervention_subgraph[0]

                    context_node._dependencies.add(intervention_node.index)
                    intervention_node._listeners.add(context_node.index)

                    # TODO: maybe we don't need this
                    intervention_subgraph.subset.append(context_node.index)

                    break

    def compile(self) -> Optional[Dict[str, List[InterventionNode]]]:

        if self.compiled:
            return self.interventions

        if len(self) == 0:
            self.compiled = True
            return

        intervention_subgraphs: List[SubGraph] = []

        start = self[0].index

        # does the first node correspond to an executable graph?
        # this occurs when a Conditional or Iterator context is explicitly entered by a user
        if isinstance(self[0].target, type) and issubclass(self[0].target, Context):

            graph = self[0].args[0]

            # handle empty if statements or for loops
            if len(graph) > 0:
                start = graph[0].index

        end = self[-1].index + 1

        context_start: int = None
        defer_start: int = None
        context_node: InterventionNode = None

        # loop over all the nodes created within this graph's context
        for index in range(start, end):

            node: InterventionNodeType = self.nodes[index]

            # is this node part of an inner context's subgraph?
            if context_node is None and node.graph is not self:

                context_node = self.nodes[node.graph[-1].index + 1]
                context_start = self.subset.index(context_node.index)
                defer_start = node.index

                self.context_dependency(context_node, intervention_subgraphs)

            if node.target is InterventionProtocol:

                # build intervention subgraph
                subgraph = SubGraph(self, subset=sorted(list(node.subgraph())))

                module_path, *_ = node.args

                self.interventions[module_path].append(node)

                intervention_subgraphs.append(subgraph)

                # if the InterventionProtocol is defined within a sub-context
                if context_node is not None:

                    # make the current context node dependent on this intervention node
                    context_node._dependencies.add(node.index)
                    node._listeners.add(context_node.index)

                    # TODO: maybe we don't need this
                    self.subset.append(node.index)

                    graph: SubGraph = node.graph
                    graph.subset.remove(node.index)

                    node.kwargs["start"] = context_start
                    node.kwargs["defer_start"] = defer_start

                    node.graph = self

                else:
                    node.kwargs["start"] = self.subset.index(subgraph.subset[0])
                    node.kwargs["defer_start"] = node.kwargs["start"]

            elif node.target is GradProtocol:

                subgraph = SubGraph(self, subset=sorted(list(node.subgraph())))

                intervention_subgraphs.append(subgraph)

                self.grad_subgraph.update(subgraph.subset[1:])

                if context_node is not None:

                    context_node._dependencies.add(node.index)
                    node._listeners.add(context_node.index)
                    subgraph.subset.append(context_node.index)

                    graph: SubGraph = node.graph
                    graph.subset.remove(node.index)

                    node.kwargs["start"] = context_start

                    node.graph = self

                else:
                    node.kwargs["start"] = self.subset.index(subgraph.subset[1])

            elif node.target is ApplyModuleProtocol:
                node.graph = self

            elif context_node is not None and context_node is node:
                context_node = None

        self.compiled = True
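    # Illustrative sketch (assumption, not from the source): after `.compile()`,
    # `self.interventions` maps module paths to the InterventionProtocol nodes that hook
    # them. The paths below are hypothetical:
    #
    #     {
    #         "model.layers.0.mlp.output": [<InterventionNode 3>, ...],
    #         "model.layers.5.attn.input": [<InterventionNode 9>],
    #     }
    #
    # Each such node also carries `kwargs["start"]` / `kwargs["defer_start"]` marking the
    # position in `self.subset` from which `execute()` should run for it.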
    def execute(
        self,
        start: int = 0,
        grad: bool = False,
        defer: bool = False,
        defer_start: int = 0,
    ) -> None:

        err: Tuple[int, NNsightError] = None

        # reset and clear any nodes previously deferred from this section
        if defer_start in self.deferred:

            for index in self.deferred[defer_start]:
                self.nodes[index].reset()

            del self.deferred[defer_start]

        if defer:
            self.defer_stack.append(defer_start)

        for node in self[start:]:

            if node.executed:
                continue
            # stop before the next InterventionProtocol node; it executes on its own call
            elif (
                node.index != self[start].index
                and node.target is InterventionProtocol
            ):
                break
            elif node.fulfilled:
                try:
                    node.execute()
                    if defer and node.target is not InterventionProtocol:
                        self.deferred[defer_start].append(node.index)
                except NNsightError as e:
                    err = (node.index, e)
                    break
            elif not grad and node.index in self.grad_subgraph:
                continue
            else:
                break

        if defer:
            self.defer_stack.pop()

        if err is not None:
            # clean up the remaining nodes without deferral, then re-raise
            defer_stack = self.defer_stack
            self.defer_stack = []
            self.clean(err[0])
            self.defer_stack = defer_stack

            raise err[1]
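    # Illustrative sketch (assumption): how an interleaving hook might drive `execute()`
    # for one InterventionProtocol node. `graph`, `node` and `iteration` are hypothetical.
    #
    #     ready, defer = graph.count(node.index, iteration)
    #     if ready:
    #         graph.execute(
    #             start=node.kwargs["start"],
    #             defer=defer,
    #             defer_start=node.kwargs["defer_start"],
    #         )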
    def count(self, index: int, iteration: Union[int, List[int], slice]) -> Tuple[bool, bool]:
        """Increments the count of times a given intervention node has attempted to execute
        and returns whether the node is ready and whether its cleanup needs to be deferred.

        Args:
            index (int): Index of the intervention node to return the count for.
            iteration (Union[int, List[int], slice]): Which iteration(s) this node should be executed for.

        Returns:
            bool: Whether this node should be executed on this iteration.
            bool: Whether this node and its recursive listeners should have the updating of their
                remaining listeners (and therefore their destruction) deferred.
        """

        ready = False
        defer = False

        count = self.call_counter[index]

        if isinstance(iteration, int):
            ready = count == iteration
        elif isinstance(iteration, list):
            iteration.sort()
            ready = count in iteration
            defer = count != iteration[-1]
        elif isinstance(iteration, slice):
            start = iteration.start or 0
            stop = iteration.stop

            ready = count >= start and (stop is None or count < stop)
            defer = stop is None or count < stop - 1

        # if defer:
        #     self.deferred.add(index)
        # else:
        #     self.deferred.discard(index)

        self.call_counter[index] += 1

        return ready, defer
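    # Worked example (illustrative): the (ready, defer) pairs returned by successive calls
    # to `count(i, iteration)` as `call_counter[i]` goes 0, 1, 2:
    #
    #     iteration = 1           -> (False, False), (True, False), (False, False)
    #     iteration = [0, 2]      -> (True, True),   (False, True), (True, False)
    #     iteration = slice(1, 3) -> (False, True),  (True, True),  (True, False)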
    def clean(self, start: Optional[int] = None):

        if start is None:
            start = self[0].index

        end = self[-1].index + 1

        # Loop over ALL nodes within the span of this graph.
        for index in range(start, end):

            node = self.nodes[index]

            if node.executed:
                break

            node.update_dependencies()
    def cleanup(self) -> None:
        """Because some modules may be executed more than once, intervention graph sections
        defer updating the remaining listeners of their nodes (just as a loop does, for
        memory management) whenever a section is not being executed for the last time.

        If we never learn that a pass was in fact the last one, deferred sections may remain
        after execution. They are left over in `graph.deferred`, so we update their
        dependencies here.
        """

        # For every intervention graph section (indicated by where it started)
        for start in self.deferred:

            # Loop through all nodes that got their dependencies deferred.
            for index in range(start, self.deferred[start][-1] + 1):

                node = self.nodes[index]

                # Update each of its dependencies
                for dependency in node.dependencies:

                    # Only if it was before start
                    # (not within this section, but before)
                    if dependency.index < start:

                        dependency.remaining_listeners -= 1

                        if dependency.redundant:
                            dependency.destroy()
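    # Illustrative example (assumption): if a section starting at node index 7 was always
    # executed with defer=True and never got a final, non-deferred pass, `self.deferred`
    # could still hold {7: [8, 9, 11]} here; `cleanup()` then releases the listener counts
    # that nodes 8, 9 and 11 hold on dependencies whose index is < 7.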
    def copy(
        self,
        new_graph: Self = None,
        parent: Optional["GraphType"] = None,
        memo: Optional[Dict[int, "NodeType"]] = None,
    ) -> Self:

        if memo is None:
            memo = {}

        new_graph = super().copy(new_graph, parent=parent, memo=memo)

        new_graph.compiled = self.compiled

        for key, value in self.call_counter.items():
            new_graph.call_counter[memo[key]] = value

        if new_graph.compiled:

            for module_path, list_of_nodes in self.interventions.items():

                new_graph.interventions[module_path] = [
                    new_graph.nodes[memo[node.index]] for node in list_of_nodes
                ]

            for key, values in self.deferred.items():

                new_graph.deferred[memo[key]] = [
                    memo[index] for index in values
                ]

            new_graph.grad_subgraph = [memo[index] for index in self.grad_subgraph]

        return new_graph
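    # Illustrative sketch (assumption): copying a compiled graph. `memo` maps each original
    # node index to its counterpart in the copy; `new_parent` is a hypothetical parent graph.
    #
    #     memo = {}
    #     new_graph = graph.copy(parent=new_parent, memo=memo)
    #     assert new_graph.compiled == graph.compiled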
    # @classmethod
    # def shift(cls, mgraph: MultiGraph) -> MultiGraph:

    #     InterventionProtocol.compile(mgraph)

    #     intervention_subgraphs = InterventionProtocol.get_interventions(mgraph).values()

    #     graph_id_to_invoker_groups = defaultdict(set)
    #     graph_id_to_intervention_node = defaultdict(list)

    #     for subgraph in intervention_subgraphs:
    #         for (start, end) in subgraph:

    #             node = mgraph[start]

    #             invoker_group = node.args[1]

    #             offset = 0

    #             for graph in mgraph.id_to_graphs.values():
    #                 offset += len(graph)
    #                 if start < offset:
    #                     graph_id_to_invoker_groups[graph.id].add(invoker_group)
    #                     graph_id_to_intervention_node[graph.id].append(node)
    #                     break

    #     global_offset = 0

    #     for graph_id, invoker_groups in graph_id_to_invoker_groups.items():

    #         min_group = min(invoker_groups)
    #         max_group = max(invoker_groups)

    #         offset = global_offset - min_group

    #         for node in graph_id_to_intervention_node[graph_id]:
    #             node.args[1] += offset

    #         global_offset += max_group + 1

    #     return mgraph