# Source code for nnsight.intervention.tracing.util


import contextlib
import ctypes
import inspect
import os
import re
import sys
from builtins import open
from types import FrameType
from typing import TYPE_CHECKING, Callable, Dict, List

if TYPE_CHECKING:
    from .base import Tracer


def indent(source: List[str], indent: int = 1):
    """
    Shift every line of ``source`` to the right.

    Each line is prefixed with ``indent`` copies of the indentation unit.

    Args:
        source: Lines of code to shift.
        indent: How many indentation units to prepend (default: 1).

    Returns:
        A new list of prefixed lines; ``source`` itself is not modified.
    """
    prefix = " " * indent
    return [prefix + line for line in source]
def try_catch(
    source: List[str],
    exception_source: "List[str] | None" = None,
    else_source: "List[str] | None" = None,
    finally_source: "List[str] | None" = None,
):
    """
    Wrap source lines in a try/except/else/finally statement.

    Args:
        source: Lines for the ``try`` body. They are inserted without any
            additional indentation, so they must already be indented relative
            to ``try:``.
        exception_source: Lines for the ``except Exception as exception:``
            body (default: a single ``raise`` line).
        else_source: Lines for the ``else`` body (default: a single ``pass`` line).
        finally_source: Lines for the ``finally`` body (default: a single
            ``pass`` line).

    Returns:
        A new list of strings holding the complete statement, with the whole
        statement indented by one level.
    """
    # None sentinels instead of list literals: a list default is created once
    # at definition time and shared by every call.
    if exception_source is None:
        exception_source = ["raise\n"]
    if else_source is None:
        else_source = ["pass\n"]
    if finally_source is None:
        finally_source = ["pass\n"]

    source = [
        "try:\n",
        *source,
        "except Exception as exception:\n",
        *indent(exception_source),
        "else:\n",
        *indent(else_source),
        "finally:\n",
        *indent(finally_source),
    ]

    return indent(source)
@contextlib.contextmanager
def suppress_all_output():
    """
    Silence everything written to stdout and stderr inside the ``with`` block.

    Both ``sys.stdout`` and ``sys.stderr`` point at one shared handle on
    ``os.devnull`` while the block runs. The previous streams are restored on
    exit, including when the block raises, and the devnull handle is closed.
    """
    with open(os.devnull, 'w') as sink, contextlib.redirect_stdout(
        sink
    ), contextlib.redirect_stderr(sink):
        yield
def get_dependencies(fn: Callable):
    """
    Extract the global objects a function refers to.

    Names are collected from ``fn.__code__.co_names`` and, recursively, from
    every code object nested in the constants of that code: lambdas, inner
    functions, and comprehensions on Python < 3.12. Those nested bodies look
    up global names in the same ``fn.__globals__``, so their references are
    dependencies of ``fn`` too.

    ``co_names`` also lists attribute names. An attribute access like
    ``obj.foo`` therefore contributes ``foo`` when a global of that name exists.

    Args:
        fn: The function to analyze for dependencies.

    Returns:
        Dictionary mapping names to their corresponding global objects used by
        the function (or by code nested inside it).
    """
    namespace = fn.__globals__
    dependencies = {}
    pending = [fn.__code__]
    while pending:
        code = pending.pop()
        for name in code.co_names:
            if name in namespace:
                dependencies[name] = namespace[name]
        # Descend into nested code objects (lambda / def / comprehension bodies).
        pending.extend(const for const in code.co_consts if inspect.iscode(const))
    return dependencies
from ... import CONFIG
class ExceptionWrapper(Exception):
    """
    Wrapper for exceptions that provides additional details for tracer created code.

    This class helps provide better error messages by including source code context
    and proper line numbers from the original code being traced.
    """

    def __init__(self, info: "Tracer.Info", original: Exception, *args, **kwargs):
        """
        Initialize the exception wrapper.

        Args:
            info: Tracer information containing context about where the exception occurred
            original: The original exception being wrapped
            *args, **kwargs: Additional arguments passed to the parent Exception class
        """
        super().__init__(*args, **kwargs)
        self.original = original
        self.infos = []
        self.set_info(info)

    def set_info(self, info: "Tracer.Info"):
        """
        Record one more layer of tracer information.

        Infos accumulate in the order they are reported. ``__str__`` consumes
        them in reverse to compute the line offset of each generated-code layer.

        Args:
            info: New tracer information to use
        """
        self.infos.append(info)

    def __str__(self):
        """
        Generate a formatted traceback string with proper context.

        Each frame of ``self.original.__traceback__`` is rendered as follows:

        * Frames from nnsight-generated code (filename starting with
          ``"<nnsight"``) that have a recorded info are mapped back to the
          user's file and line.
        * Frames from nnsight's own modules are shown only when
          ``CONFIG.APP.DEBUG`` is set.
        * All other frames, including generated frames with no recorded info,
          are shown as-is.

        Returns:
            A string containing the formatted traceback with source code context
        """
        import linecache

        def plain_frame(filename, lineno, name):
            # Header line, plus the source text when linecache can find it.
            entry = [f' File "{filename}", line {lineno}, in {name}']
            try:
                line = linecache.getline(filename, lineno).strip()
            except Exception:
                # Best effort: a missing source line must not break formatting.
                line = ""
            if line:
                entry.append(f' {line}')
            return entry

        # Keyed by generated filename: cumulative line offset, user-facing
        # filename, code name to display, and the generated source lines.
        start_lines = {}
        filename_mapping = {}
        co_names = {}
        source_lines = {}

        accumulator = 0
        filename = ""
        co_name = ""
        for info in reversed(self.infos):
            code = info.frame.f_code
            if not code.co_filename.startswith("<nnsight"):
                # A non-generated frame re-anchors the offset at its first line.
                accumulator = code.co_firstlineno - 1
                filename = code.co_filename
                co_name = code.co_name
            # Add this layer's 1-based start_line as a 0-based offset.
            accumulator += info.start_line - 1
            start_lines[info.filename] = accumulator
            filename_mapping[info.filename] = filename
            co_names[info.filename] = co_name
            source_lines[info.filename] = info.source

        tb_frames = []
        current_tb = self.original.__traceback__
        while current_tb is not None:
            frame = current_tb.tb_frame
            filename = frame.f_code.co_filename
            lineno = current_tb.tb_lineno
            name = frame.f_code.co_name

            if filename.startswith("<nnsight") and filename in filename_mapping:
                # Case 1: our traced code - translate to the user's file and line.
                if '__nnsight_tracing_info__' in frame.f_locals:
                    name = co_names[filename]
                source = source_lines[filename]
                line_number = lineno - 1 + start_lines[filename]
                tb_frames.append(
                    f' File "{filename_mapping[filename]}", line {line_number + 1}, in {name}'
                )
                if 0 < lineno <= len(source):
                    tb_frames.append(f' {source[lineno - 1].strip()}')
            elif re.search(r"nnsight[\\/]", filename):
                # Case 2: internal nnsight code (either path separator) - debug only.
                if CONFIG.APP.DEBUG:
                    tb_frames.extend(plain_frame(filename, lineno, name))
            else:
                # Case 3: regular code, or generated code without a recorded info.
                tb_frames.extend(plain_frame(filename, lineno, name))

            current_tb = current_tb.tb_next

        lines = ["\n\nTraceback (most recent call last):"]
        lines.extend(tb_frames)
        lines.append(f'\n{type(self.original).__name__}: {self.original}')
        return "\n".join(lines)
def wrap_exception(exception: Exception, info: "Tracer.Info"):
    """
    Wrap an exception with additional context from the tracer.

    If ``exception`` is already an ExceptionWrapper, its tracer info is
    extended and it is returned. Otherwise a new dynamically created exception
    type is built that inherits from both ExceptionWrapper and the original
    exception type. An instance of it is returned that carries the original
    ``args`` and instance attributes.

    Args:
        exception: The exception to wrap
        info: Tracer information containing context about where the exception occurred

    Returns:
        A wrapped exception with enhanced traceback information. It is still
        an instance of ``type(exception)``.
    """
    if isinstance(exception, ExceptionWrapper):
        # If already wrapped, just update the info
        exception.__suppress_context__ = True  # Kills "... during handling ..."
        exception.__traceback__ = None
        exception.set_info(info)
        return exception

    exception_type = type(exception)

    def _init_wrapper_state(wrapped):
        # Set ExceptionWrapper's attributes directly instead of calling
        # ExceptionWrapper.__init__: its super().__init__() would re-run
        # BaseException.__init__ with no arguments and reset ``args`` to ().
        wrapped.original = exception
        wrapped.infos = []
        wrapped.set_info(info)

    # ExceptionWrapper is listed first so that a cooperative
    # ``super().__init__(...)`` inside exception_type.__init__ continues to
    # Exception, not to ExceptionWrapper.__init__ (whose signature differs).
    class NNsightException(ExceptionWrapper, exception_type):
        __qualname__ = "NNsightException"
        __module__ = "nnsight"

        def __init__(self, *args, **kwargs):
            exception_type.__init__(self, *args, **kwargs)
            _init_wrapper_state(self)

        def __str__(self):
            return ExceptionWrapper.__str__(self)

    try:
        wrapped = NNsightException(*exception.args)
    except TypeError:
        # exception_type.__init__ cannot be re-invoked with the stored args
        # (e.g. it accepts different parameters than it stores in ``args``).
        # Allocate without running __init__; attributes are copied below.
        wrapped = NNsightException.__new__(NNsightException, *exception.args)
        _init_wrapper_state(wrapped)

    wrapped.__dict__.update(exception.__dict__)
    return wrapped
class TracingDeffermentException(Exception):
    """Signal that tracing has hit a deferral point.

    Raised while walking the call stack when a frame that defines a
    ``__defer_capture__`` local variable is encountered. The class name keeps
    its historical spelling because it is part of the public interface.
    """
def get_non_nnsight_frame() -> FrameType:
    """
    Walk up the call stack to the first frame that is not nnsight-internal.

    Starting from this function's caller, each frame is inspected in turn:

    * if it has a ``__defer_capture__`` local, TracingDeffermentException is
      raised immediately;
    * otherwise the walk stops at the first frame whose filename contains an
      ``nnsight/tests`` (or ``nnsight\\tests``) directory, or contains no
      ``nnsight`` directory at all.

    Returns:
        The first qualifying frame, or None if the top of the stack is reached
        without finding one (despite the ``FrameType`` annotation).

    Raises:
        TracingDeffermentException: if a visited frame defines ``__defer_capture__``.
    """
    frame = inspect.currentframe()
    while frame:
        frame = frame.f_back
        if frame:
            filename = frame.f_code.co_filename
            # Deferral is checked before the filename test, so it also applies
            # to the frame that would otherwise be returned.
            if "__defer_capture__" in frame.f_locals:
                raise TracingDeffermentException()
            # Match if filename contains 'nnsight/tests' or 'nnsight\tests'
            # OR if it does NOT contain '/nnsight/' or '\nnsight\'
            if (
                re.search(r"[\\/]{1}nnsight[\\/]{1}tests", filename)
                or not re.search(r"[\\/]{1}nnsight[\\/]", filename)
            ):
                break
    return frame


def push_variables(frame: FrameType, variables: Dict) -> None:
    """
    Write ``variables`` into ``frame`` so code executing in it can see them.

    For frames whose code was generated by nnsight (filename starting with
    ``"<nnsight"``), entries whose names are not already locals of the frame
    are also written into the frame's globals. Every entry is then assigned
    into ``frame.f_locals`` and pushed to the frame's fast locals through the
    CPython C-API function ``PyFrame_LocalsToFast``.

    Args:
        frame: Target frame.
        variables: Mapping of variable name to value to inject.
    """
    is_generated_frame = frame.f_code.co_filename.startswith("<nnsight")
    if is_generated_frame:
        # Names the generated frame doesn't hold as locals go to its globals.
        global_variables = {k: v for k, v in variables.items() if k not in frame.f_locals}
        for key, value in global_variables.items():
            frame.f_globals[key] = value
        ctypes.pythonapi.PyFrame_LocalsToFast(
            ctypes.py_object(frame), ctypes.c_int(0)
        )
    for key, value in variables.items():
        frame.f_locals[key] = value
        # Flushed after every single write: presumably because each read of
        # frame.f_locals can re-sync the dict from the fast locals (CPython
        # < 3.13) and drop unflushed writes - TODO confirm.
        # NOTE(review): PEP 667 (Python 3.13) makes f_locals a write-through
        # proxy and deprecates PyFrame_LocalsToFast; confirm this symbol is
        # still available/needed on every supported Python version.
        ctypes.pythonapi.PyFrame_LocalsToFast(
            ctypes.py_object(frame), ctypes.c_int(0)
        )