from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
import torch
from torch._subclasses.fake_tensor import FakeCopyMode, FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from ... import util
from ...tracing.contexts import GlobalTracingContext
from ...tracing.graph import Node, Proxy
from ...tracing.protocols import Protocol
from ..protocols import EntryPoint
if TYPE_CHECKING:
from . import InterventionGraph
class InterventionNode(Node):
    """Node subclass used by intervention graphs.

    Every instance carries a ``fake_value``: a preview of the value the node
    is expected to produce, available before the graph is actually executed.

    This class is also documented as extending ``Node.prepare_inputs`` to
    handle Tensors. NOTE(review): no ``prepare_inputs`` override is defined in
    this class body as shown — confirm where that behavior is provided.
    """

    def __init__(
        self, *args, fake_value: Optional[Any] = inspect._empty, **kwargs
    ) -> None:
        """Create the node.

        Args:
            *args: Forwarded unchanged to ``Node.__init__``.
            fake_value: Preview of this node's future value. The default,
                ``inspect._empty``, is the sentinel for "not provided".
            **kwargs: Forwarded unchanged to ``Node.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.fake_value = fake_value

    def update_dependencies(self):
        """Release this node's claim on each of its dependencies.

        Each dependency has its ``remaining_listeners`` count decremented and
        is destroyed once it reports itself ``redundant``. While the graph's
        ``defer_stack`` is non-empty, a dependency is skipped (left untouched)
        when its ``index`` is below the last entry on the defer stack, or when
        it is an entry point whose ``graph`` is not this node's graph.
        """
        for dep in self.dependencies:
            if len(self.graph.defer_stack) > 0:
                # Dependencies indexed before the innermost deferral point
                # are not released here.
                if dep.index < self.graph.defer_stack[-1]:
                    continue
                # Entry-point dependencies that belong to a different graph
                # are also skipped.
                if (
                    EntryPoint.is_entrypoint(dep.target)
                    and dep.graph is not self.graph
                ):
                    continue

            dep.remaining_listeners -= 1

            if dep.redundant:
                dep.destroy()
# Type variable bounded by InterventionNode; presumably used to annotate
# generic APIs that accept/return a specific InterventionNode subclass.
InterventionNodeType = TypeVar("InterventionNodeType", bound=InterventionNode)
class ValidatingInterventionNode(InterventionNode):
    """Intervention node that eagerly computes its own ``fake_value``.

    On construction, the node's target is executed against fake versions of
    its inputs (see ``validate``) to obtain a preview of the value the node
    will produce. This is skipped when the node is not attached to a graph,
    when a ``fake_value`` was already supplied by whoever created the node,
    or when the target is a ``Protocol``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        if (
            self.attached
            and self.fake_value is inspect._empty
            and not Protocol.is_protocol(self.target)
        ):
            # ``validate`` is a staticmethod of this class, not a module-level
            # function; a bare ``validate(...)`` call raises NameError here.
            self.fake_value = self.validate(
                self.target, *self.args, **self.kwargs
            )

    @staticmethod
    def backwards_check(target: Callable, *args) -> bool:
        """Report whether ``target``/``args`` describe a ``.backward`` call.

        Such a call is recognised as ``Proxy.call`` whose first argument is a
        ``Node`` whose target is ``util.fetch_attr`` with ``"backward"`` as
        its second argument.

        Args:
            target: The callable the node will invoke.
            *args: The positional arguments the node will pass to ``target``.

        Returns:
            True for a ``.backward`` call, otherwise False. The result is also
            False when ``args`` is empty or the fetched node has fewer than
            two arguments; previously those cases raised IndexError.
        """
        if target is not Proxy.call or not args:
            return False

        node = args[0]
        if not isinstance(node, Node):
            return False

        if (
            node.target is util.fetch_attr
            and len(node.args) > 1
            and node.args[1] == "backward"
        ):
            return True

        return False

    @staticmethod
    def validate(target: Callable, *args, **kwargs):
        """Run ``target`` on fake versions of its inputs and return the result.

        The call runs under a ``FakeTensorMode`` whose ``ShapeEnv`` assumes
        static shapes by default and which accepts non-fake inputs. It also
        runs with ``FakeCopyMode`` active and with the global tracing context
        exited. Tensor operations therefore yield fake tensors that carry
        metadata rather than real data.

        Args:
            target: The callable to execute.
            *args: Positional inputs, converted via ``prepare_inputs(...,
                fake=True)`` before the call.
            **kwargs: Keyword inputs, converted the same way.

        Returns:
            The (fake) result of ``target(*args, **kwargs)``. Returns ``None``,
            without executing anything, when ``backwards_check`` identifies
            the call as a ``.backward`` call.
        """
        # Enter FakeMode.
        with FakeTensorMode(
            allow_non_fake_inputs=True,
            shape_env=ShapeEnv(assume_static_by_default=True),
        ) as fake_mode:
            with FakeCopyMode(fake_mode):
                with GlobalTracingContext.exit_global_tracing_context():
                    # ``backwards_check`` is a sibling staticmethod; it must
                    # be reached through the class, not as a bare name.
                    if ValidatingInterventionNode.backwards_check(
                        target, *args
                    ):
                        return None

                    # NOTE(review): relies on a ``prepare_inputs`` that accepts
                    # ``fake=``; no such override is visible on
                    # InterventionNode in this file section — confirm.
                    args, kwargs = InterventionNode.prepare_inputs(
                        (args, kwargs), fake=True
                    )

                    return target(
                        *args,
                        **kwargs,
                    )