Source code for nnsight.intervention.contexts.globals

"""
Global patching lets nnsight trace operations that would otherwise be un-traceable: each such operation is replaced with a version that uses the GLOBAL_TRACING_CONTEXT to add the operation to the current graph.
"""

from __future__ import annotations

from inspect import getmembers, isclass

import torch
from torch.utils import data

from ...tracing.contexts.globals import (
    GlobalTracingContext,
    global_patch,
    global_patch_method,
)
from ...tracing.graph.proxy import proxy_patch
from . import InterventionTracer

# Torch classes (plus the DataLoader): register a global patch for each class.
for _patched_cls in (torch.nn.Parameter, torch.nn.Linear, data.DataLoader):
    global_patch(_patched_cls)

# Tensor creation (and concatenation) functions. Patches are registered in the
# order listed here.
for _patched_fn in (
    torch.arange,
    torch.empty,
    torch.eye,
    torch.full,
    torch.linspace,
    torch.logspace,
    torch.ones,
    torch.rand,
    torch.randint,
    torch.randn,
    torch.randperm,
    torch.zeros,
    torch.cat,
):
    global_patch(_patched_fn)

# Module methods: patch ``zero_grad`` on ``torch.nn.Module``.
global_patch_method(torch.nn.Module, torch.nn.Module.zero_grad)

# Every optimizer class exposed by ``torch.optim``. ``issubclass`` is reflexive,
# so the ``torch.optim.Optimizer`` base class itself is patched as well.
for _, _optimizer_cls in getmembers(torch.optim, isclass):
    if not issubclass(_optimizer_cls, torch.optim.Optimizer):
        continue
    global_patch(_optimizer_cls)

import math
from inspect import getmembers, isbuiltin, isfunction

import einops

def _proxy_patch_members(module, predicate):
    """Rebind each member of ``module`` selected by ``predicate`` to ``proxy_patch(member)``.

    ``getmembers`` returns a list snapshot, so rebinding attributes while
    iterating does not change which members are visited.

    NOTE(review): only the attribute on ``module`` itself is rebound; any
    reference to the same function already bound elsewhere keeps pointing at
    the unwrapped object — confirm this is intended.
    """
    for member_name, member in getmembers(module, predicate):
        setattr(module, member_name, proxy_patch(member))


# Einops: every plain Python function reachable as an attribute of
# ``einops.einops`` (``isfunction`` also matches functions imported into that
# namespace, not only ones defined there).
_proxy_patch_members(einops.einops, isfunction)

# math: every built-in function exposed by the ``math`` module.
_proxy_patch_members(math, isbuiltin)


class GlobalInterventionTracingContext(GlobalTracingContext, InterventionTracer):
    """Global tracing context that also provides the ``InterventionTracer`` methods.

    Combines ``GlobalTracingContext`` and ``InterventionTracer`` through
    multiple inheritance. An instance of this class is installed at module
    import time as ``GlobalTracingContext.GLOBAL_TRACING_CONTEXT``.
    """

    # Class-level annotation only (no value): tells type checkers that the
    # shared context is of this intervention-aware type.
    GLOBAL_TRACING_CONTEXT: GlobalInterventionTracingContext


# Install the intervention-aware context as the shared global tracing context,
# so the global patches registered in this module get the InterventionTracer
# methods.
GlobalTracingContext.GLOBAL_TRACING_CONTEXT = GlobalInterventionTracingContext()