Source code for nnsight.tracing.contexts.globals
from __future__ import annotations
import inspect
from contextlib import AbstractContextManager
from functools import wraps
from types import FunctionType, MethodType
from typing import Any, Type, Union
from ... import util
from ..graph import Graph
from . import Tracer
def global_patch_class(cls: type) -> util.Patch:
if cls.__new__ is object.__new__:
def super_new(cls, *args, **kwargs):
return object.__new__(cls)
cls.__new__ = super_new
fn = cls.__new__
@wraps(fn)
def inner(cls, *args, **kwargs):
if not GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
return cls(*args, **kwargs)
return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(cls, *args, **kwargs)
return util.Patch(cls, inner, "__new__")
def global_patch_fn(fn: FunctionType) -> util.Patch:
@wraps(fn)
def inner(*args, **kwargs):
if not GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
return fn(*args, **kwargs)
return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(fn, *args, **kwargs)
return util.Patch(inspect.getmodule(fn), inner, fn.__name__)
def global_patch_method(cls: type, fn: MethodType) -> None:
@wraps(fn)
def inner(*args, **kwargs):
if not GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
return fn(*args, **kwargs)
return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(fn, *args, **kwargs)
patch = util.Patch(cls, inner, fn.__name__)
GlobalTracingContext.PATCHER.add(patch)
def global_patch(obj: Union[FunctionType, Type]):
if isinstance(obj, type):
patch = global_patch_class(obj)
else:
patch = global_patch_fn(obj)
GlobalTracingContext.PATCHER.add(patch)
[docs]
class GlobalTracingContext(Tracer):
"""The Global Tracing Context handles adding tracing operations globally without reference to a given `GraphBasedContext`.
There should only be one of these and that is `GlobalTracingContext.GLOBAL_TRACING_CONTEXT`.
`GlobalTracingContext.TORCH_HANDLER` handles adding torch functions without reference to a given `GraphBasedContext`.
"""
GLOBAL_TRACING_CONTEXT: GlobalTracingContext
PATCHER: util.Patcher = util.Patcher()
[docs]
class GlobalTracingExit(AbstractContextManager):
def __enter__(self) -> Any:
GlobalTracingContext.PATCHER.__exit__(None, None, None)
return self
def __exit__(self, exc_type, exc_val, traceback):
GlobalTracingContext.PATCHER.__enter__()
if isinstance(exc_val, BaseException):
raise exc_val
def __init__(self) -> None:
"""We create an empty `GraphBasedContext` by default."""
self.graph: Graph = None
@staticmethod
def exit_global_tracing_context():
return GlobalTracingContext.GlobalTracingExit()
[docs]
@staticmethod
def try_register(graph_based_context: Tracer) -> bool:
"""Attempts to register a `Graph` globally.]
Will not if one is already registered.
Args:
graph_based_context (GraphBasedContext): `GraphBasedContext` to register.
Returns:
bool: True if registering ws successful, False otherwise.
"""
if GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
return False
GlobalTracingContext.register(graph_based_context)
return True
[docs]
@staticmethod
def try_deregister(graph_based_context: Tracer) -> bool:
"""Attempts to deregister a `Graph` globally.
Will not if `graph_based_context` does not have the same `Graph` as the currently registered one.
Args:
graph_based_context (GraphBasedContext): `GraphBasedContext` to deregister.
Returns:
bool: True if deregistering ws successful, False otherwise.
"""
if (
not GlobalTracingContext.GLOBAL_TRACING_CONTEXT
or graph_based_context.graph
is not GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph
):
return False
GlobalTracingContext.deregister()
return True
[docs]
@staticmethod
def register(graph_based_context: Tracer) -> None:
"""Register `GraphBasedContext` globally.
Args:
graph_based_context (GraphBasedContext): GraphBasedContext to register.
"""
assert GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph is None
GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph = graph_based_context.graph
GlobalTracingContext.PATCHER.__enter__()
[docs]
@staticmethod
def deregister() -> None:
"""Deregister `GraphBasedContext` globally.
Args:
graph_based_context (GraphBasedContext): GraphBasedContext to deregister.
"""
assert GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph is not None
GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph = None
GlobalTracingContext.PATCHER.__exit__(None, None, None)
def __bool__(self) -> bool:
"""True if there is a `GraphBasedContext` registered globally. False otherwise."""
return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph is not None
GlobalTracingContext.GLOBAL_TRACING_CONTEXT = GlobalTracingContext()