Source code for nnsight.intervention.inject

import ast
import inspect
import textwrap
from builtins import compile, exec
from collections import defaultdict
from typing import Callable

import astor


class FunctionCallWrapper(ast.NodeTransformer):
    """AST transformer that routes every call expression through ``wrap``.

    Each ``callee(args)`` found in the visited tree is rewritten to
    ``wrap(callee, name="<prefix>.<label>")(args)``. ``<label>`` is built
    from the callee expression plus a running counter, so repeated calls to
    the same callee receive distinct labels (``foo_0``, ``foo_1``, ...).

    Attributes:
        name: Prefix joined with a dot in front of every generated label.
        name_index: Number of calls seen so far for each base label.
        line_numbers: Maps each generated label to ``node.lineno - 2`` of
            the call it was generated for.
    """

    def __init__(self, name: str):
        self.name_index = defaultdict(int)
        self.line_numbers = {}
        self.name = name

    def get_name(self, node: ast.Call) -> str:
        """Return a fresh label for the call ``node`` and bump its counter.

        * ``foo()`` yields ``foo_<n>``.
        * ``a.b.c()`` yields ``a_b_c_<n>``.
        * An attribute chain that is not rooted at a plain name (for
          example ``f().g()``) contributes only its attribute names.
        * A callee that is neither a name nor an attribute (for example
          ``fns[0]()``) falls back to the base label ``None``.
        """
        callee = node.func
        func_name = None

        if isinstance(callee, ast.Name):
            # Simple function call like foo()
            func_name = callee.id
        elif isinstance(callee, ast.Attribute):
            # Method call like obj.method() or module.submodule.func():
            # walk inwards collecting attribute names, innermost last.
            parts = []
            current = callee
            while isinstance(current, ast.Attribute):
                parts.append(current.attr)
                current = current.value
            if isinstance(current, ast.Name):
                parts.append(current.id)
            # Reverse to get the written order (e.g. torch_nn_functional).
            func_name = "_".join(reversed(parts))

        label = f"{func_name}_{self.name_index[func_name]}"
        self.name_index[func_name] += 1
        return label

    def visit_Call(self, node: ast.Call) -> ast.Call:
        """Replace ``node`` with ``wrap(<callee>, name=...)(<args>)``."""
        # Rewrite nested calls first, so inner calls are labelled before
        # the call that contains them.
        self.generic_visit(node)

        label = self.get_name(node)
        # NOTE(review): the fixed offset of 2 presumably lines the value up
        # with how callers index the source text - confirm before changing.
        self.line_numbers[label] = node.lineno - 2

        wrapped_callee = ast.Call(
            func=ast.Name(id="wrap", ctx=ast.Load()),
            args=[node.func],
            keywords=[
                ast.keyword(
                    arg="name",
                    value=ast.Constant(value=f"{self.name}.{label}"),
                )
            ],
        )
        replacement = ast.Call(
            func=wrapped_callee,
            args=node.args,
            keywords=node.keywords,
        )
        # Give the replacement the position of the call it stands in for and
        # fill in the freshly created child nodes, so the result can be
        # compiled directly and keeps accurate line information.
        ast.copy_location(replacement, node)
        return ast.fix_missing_locations(replacement)
def convert(fn: Callable, wrap: Callable, name: str):
    """Recompile ``fn`` so that every call inside it goes through ``wrap``.

    The source of ``fn`` is parsed, rewritten by ``FunctionCallWrapper`` and
    executed again in a namespace where the identifier ``wrap`` is bound to
    the given callable.

    Args:
        fn: Function whose source can be retrieved with ``inspect.getsource``.
        wrap: Callable invoked as ``wrap(callee, name=...)`` at each call
            site of the recompiled function; its result is what gets called.
        name: Prefix used for the labels passed to ``wrap``.

    Returns:
        A tuple ``(source, line_numbers, fn)``: the dedented original source
        text, the label-to-line mapping collected by the transformer, and the
        recompiled function.

    Raises:
        KeyError: If executing the rewritten source does not bind anything
            under ``fn.__name__``.

    Errors raised by ``inspect.getsource`` propagate unchanged.
    """
    # TODO what about exceptions?
    source = textwrap.dedent(inspect.getsource(fn))

    # Globals of the module where ``fn`` is defined. ``inspect.getmodule``
    # returns None when it cannot determine the module; fall back to the
    # function's own globals instead of failing on ``None.__dict__``.
    module = inspect.getmodule(fn)
    if module is not None:
        module_globals = module.__dict__
    else:
        module_globals = getattr(fn, "__globals__", {})

    transformer = FunctionCallWrapper(name)
    tree = transformer.visit(ast.parse(source))
    ast.fix_missing_locations(tree)

    local_namespace = {'wrap': wrap}
    # Include both the globals of this module and those of the module where
    # ``fn`` is defined; the latter win on conflicts, and ``wrap`` wins over
    # both.
    global_namespace = {**globals(), **module_globals, 'wrap': wrap}

    filename = "<nnsight>"
    code_obj = compile(astor.to_source(tree), filename, 'exec')
    exec(code_obj, global_namespace, local_namespace)

    # The executed ``def`` statement binds the new function in the locals.
    fn = local_namespace[fn.__name__]

    return source, transformer.line_numbers, fn