Source code for nnsight.intervention.tracing.base

import ast
import ctypes
import inspect
import re
import sys
from types import FrameType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

from ..backends.base import Backend
from ..backends.execution import ExecutionBackend
from .globals import Globals
from .util import (TracingDeffermentException, get_non_nnsight_frame,
                   push_variables, suppress_all_output)


[docs] class ExitTracingException(Exception): """Exception raised to exit the tracing process. This exception is used as a control flow mechanism to cleanly exit a with block without executing the code inside it. """ pass
[docs] class WithBlockNotFoundError(Exception): """Exception raised when a with block is not found in the source code. This exception is used to indicate that a with block was not found at the specified line number. """ pass
[docs] class Tracer: """ Captures and executes code within a tracing context. This class allows for capturing code blocks within a 'with' statement, compiling them into callable functions, and executing them with access to the model and local variables. It provides a mechanism for intercepting and manipulating the execution flow of Python code. The tracing process works by: 1. Capturing the code inside a 'with' block 2. Compiling it into a callable function 3. Executing it with the appropriate context """
[docs] class Info: """ Container for information about the traced code. This class stores metadata about the code being traced, including the source code itself and frame information from the call stack. Attributes: source: List of source code lines from the traced block frame: Frame information from the call stack where tracing occurred indent: Number of spaces/tabs used for indentation in the original code """ def __init__( self, source: List[str], frame: FrameType, start_line: int, node: ast.With, filename: str = None, ): """ Initialize Info with source code and frame information. Args: source: List of source code lines from the traced block frame: Frame information from the call stack indent: Number of spaces/tabs used for indentation in the original code node: AST node of the with block """ self.source = source self.frame = frame self.start_line = start_line self.node = node self.filename = ( filename if filename is not None else f"<nnsight {id(self)}>" ) def copy(self): return Tracer.Info( self.source, self.frame, self.start_line, self.node, self.filename ) def __getstate__(self): """Get the state of the info for serialization.""" return { "source": self.source, "start_line": self.start_line, "filename": self.filename, "frame": self.frame, } def __setstate__(self, state): """Set the state of the info for deserialization.""" self.source = state["source"] self.start_line = state["start_line"] self.filename = state["filename"] self.frame = state["frame"] self.node = None
def __init__(self, *args, backend: Backend = None, _info: Info = None, **kwargs): """ Initialize a Tracer instance. Args: *args: Additional arguments to pass to the traced function backend: Backend implementation for executing the traced code (defaults to ExecutionBackend if None) **kwargs: Additional keyword arguments to pass to the traced function """ self.args = args self.kwargs = kwargs self.backend = ExecutionBackend() if backend is None else backend self.info = _info if _info is not None else None if self.info is None: try: self.capture() except TracingDeffermentException: pass
[docs] def capture(self): """ Capture the code block within the 'with' statement. This method walks up the call stack to find the frame outside of nnsight, extracts the source code of the 'with' block, and prepares it for later execution. It identifies the exact code block to be traced by analyzing the source code structure. """ # Find the frame outside of nnsight by walking up the call stack frame = get_non_nnsight_frame() # Get source code lines from the appropriate location start_line = frame.f_lineno # CASE 1: Were already inside of another nnsight trace. if "__nnsight_tracing_info__" in frame.f_locals: # For dynamically generated code, get source from tracing info source_lines = frame.f_locals["__nnsight_tracing_info__"].source # CASE 2: We're in an IPython console. elif "_ih" in frame.f_locals: import IPython ipython = IPython.get_ipython() source_lines = ipython.user_global_ns["_ih"][-1].splitlines(keepends=True) if not source_lines[-1].endswith("\n"): source_lines[-1] += "\n" # CASE 3: We're in a regular Python file. elif not frame.f_code.co_filename.startswith("<nnsight"): # For regular files, get source lines using inspect source_lines, offset = inspect.getsourcelines(frame) start_line = start_line if offset == 0 else start_line - offset + 1 # CASE 4: We're in a regular Python interactive console. elif frame.f_code.co_filename == "<nnsight-console>": from ... import __INTERACTIVE_CONSOLE__ source_lines = __INTERACTIVE_CONSOLE__.buffer # Add newline to each source line if it doesn't end with one source_lines = [ line if line.endswith("\n") else line + "\n" for line in source_lines ] else: raise ValueError("No source code found") # Calculate indentation level of the source code itself stripped = source_lines[0].lstrip( "\t " ) # indent for removing leading tabs/spaces indent = len(source_lines[0]) - len(stripped) # If theres an indent, we need to remove it. This handles the case of say a trace in an indented function. E.x. a trace inside a method on a class. if indent > 0: source_lines = [ line[indent:] if line.strip() else line for line in source_lines ] # Extract the code using AST parsing start_line, source_lines, node = self.parse(source_lines, start_line) # Calculate indentation level of the Tracer creation line. stripped = source_lines[0].lstrip("\t ") # removes leading tabs/spaces indent = len(source_lines[0]) - len(stripped) - 4 # If theres an indent (more than just the indentation of the with block), we need to remove it. This handles the case of say a trace in an indented block. E.x. a trace inside a for loop or another with block. if indent > 0: source_lines = [ line[indent:] if line.strip() else line for line in source_lines ] # Store the captured information for later use self.info = Tracer.Info(source_lines, frame, start_line, node)
[docs] def parse(self, source_lines:List[str], start_line:int): """ Parse the source code to extract the source code. Uses the Abstract Syntax Tree (AST) to identify the exact boundaries of the code from the specified line number. Args: source_lines: List of source code lines start_line: Line number where the tracer creation statement begins Returns: List of source code lines. """ # Parse the entire source into an AST tree = ast.parse("".join(source_lines)) class Visitor(ast.NodeVisitor): """AST visitor to find the 'with' node at the specified line.""" def __init__(self, line_no): self.target = None self.line_no = line_no def visit_With(self, node): if node.lineno == self.line_no: self.target = node else: self.generic_visit(node) visitor = Visitor(start_line) visitor.visit(tree) if visitor.target is None: # Gather 5 lines before and after start_line for context context_start = max(0, start_line - 5) context_end = min(len(source_lines), start_line + 6) context_lines = source_lines[context_start:context_end] context_lines[start_line - context_start - 1] = context_lines[start_line - context_start - 1].rstrip('\n') + " <--- HERE\n" context_str = "".join(context_lines) message = f"With block not found at line {start_line}\n" message += f"We looked here:\n\n{context_str}" raise WithBlockNotFoundError(message) end_line = visitor.target.end_lineno start_line = visitor.target.body[0].lineno - 1 return start_line, source_lines[start_line:end_line], visitor.target
[docs] def compile(self) -> Callable: """ Compile the captured source code as a callable function. Wraps the captured code in a function definition that accepts the necessary context parameters for execution. Returns: A callable function that executes the captured code block """ # Wrap the captured code in a function definition with appropriate parameters self.info.source = [ f"def __nnsight_tracer_{id(self)}__(__nnsight_tracer__, __nnsight_tracing_info__):\n", " __nnsight_tracer__.pull()\n", *self.info.source, " __nnsight_tracer__.push()\n", ] self.info.start_line -= 1
[docs] def execute(self, fn: Callable): """ Execute the compiled function. Runs the compiled function with the necessary context to execute the traced code block. Args: fn: The compiled function to execute """ fn(self, self.info)
[docs] def push(self, state: Dict = None): """ Push local variables back to the original execution frame. This allows changes made during tracing to affect the original scope. Args: state: Dictionary of variable names and values to push to the frame. If None, automatically collects variables from the current frame. """ frame = self.info.frame if state is None: # Find the frame where the traced code is executing state_frame = inspect.currentframe() while state_frame: state_frame = state_frame.f_back if state_frame and state_frame.f_code.co_filename.startswith( "<nnsight" ): break state = state_frame.f_locals # Collect all non-nnsight variables from the frame state = {k: v for k, v in state.items() if not k.startswith("__nnsight")} if Globals.stack == 1 : state = {k: v for k, v in state.items() if id(v) in Globals.saves} push_variables(frame, state) state.clear()
def pull(self): frame = inspect.currentframe() while frame: frame = frame.f_back if frame and frame.f_code.co_filename.startswith( "<nnsight" ): break state = self.info.frame.f_locals state = {k: v for k, v in state.items() if not k.startswith("__nnsight")} push_variables(frame, state) def __enter__(self): """ Enter the tracing context. Captures the code block and sets up a trace function to exit normal execution flow once the block is captured. Returns: The Tracer instance for use in the 'with' statement """ if self.info is None: self.capture() if isinstance(self.info.node.body[0], ast.Pass): return self def skip(new_frame, event, arg): """ Trace function that raises ExitTracingException when the traced code is reached. This prevents the actual execution of the traced code in its original context, allowing us to execute it later with our custom handling. """ new_lineno = new_frame.f_lineno - new_frame.f_code.co_firstlineno if ( new_frame.f_code.co_filename == self.info.frame.f_code.co_filename and new_lineno >= self.info.start_line ): # To remove colab warning with suppress_all_output(): sys.settrace(None) self.info.frame.f_trace = None raise ExitTracingException() # Set the trace function at both global and frame level # To remove colab warning with suppress_all_output(): sys.settrace(skip) self.info.frame.f_trace = skip return self def __exit__(self, exc_type, exc_val, exc_tb): """ Exit the tracing context. Args: exc_type: Exception type if an exception was raised exc_val: Exception value if an exception was raised exc_tb: Exception traceback if an exception was raised Returns: True if an ExitTracingException was caught (to suppress it), None otherwise (to propagate other exceptions) """ # Suppress the ExitTracingException but let other exceptions propagate if exc_type is ExitTracingException: # Execute the traced code using the configured backend self.backend(self) return True self.backend(self) ### Serialization ### def __getstate__(self): """Get the state of the tracer for serialization.""" return { "args": self.args, "kwargs": self.kwargs, "info": self.info, } def __setstate__(self, state): """Set the state of the tracer for deserialization.""" self.args = state["args"] self.kwargs = state["kwargs"] self.info = state["info"] self.info.start_line = 0 self.backend = ExecutionBackend()