Source code for nnsight.intervention.graph.proxy

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Collection, TypeVar, Union

import torch
from typing_extensions import Self

from ... import util
from ...tracing.graph import Proxy
from .. import protocols

if TYPE_CHECKING:
    from . import InterventionNode


[docs] class InterventionProxy(Proxy): def __init__(self, node: "InterventionNode") -> None: super().__init__(node) self.__dict__["_grad"] = None self._grad: Self self.node: "InterventionNode" @property def grad(self) -> Self: """ Calling denotes the user wishes to get the grad of proxy tensor and therefore we create a Proxy of that request. Only generates a proxy the first time it is references otherwise return the already set one. Returns: Proxy: Grad proxy. """ self.__dict__["_grad"] = protocols.GradProtocol.add(self.node.graph, self.node, fake_value=self.node.fake_value) return self._grad @grad.setter def grad(self, value: Union[InterventionProxy, Any]) -> None: """ Calling denotes the user wishes to set the grad of this proxy tensor and therefore we create a Proxy of that request via a SwapProtocol. Args: value (Union[InterventionProxy, Any]): Value to set output to. """ protocols.SwapProtocol.add(self.node.graph, self._grad, value) def __setattr__( self, key: Union[InterventionProxy, Any], value: Union[Self, Any] ) -> None: # We catch setting .grad as that is a special Protocol vs. setting attributes generally. if key == "grad": return getattr(self.__class__, key).fset(self, value) return super().__setattr__(key, value) @property def shape(self) -> Collection[torch.Size]: """Property to retrieve the shape of the traced proxy value or real value. Returns: Union[torch.Size,Collection[torch.Size]]: Proxy value shape or collection of shapes. """ if not self.node.attached: return util.apply(self.value, lambda x: x.shape, torch.Tensor) # If we haven't scanned in a proxy_value, just return a proxy to get the attribute. if self.node.fake_value is inspect._empty: return super().__getattr__("shape") return util.apply(self.node.fake_value, lambda x: x.shape, torch.Tensor) @property def device(self) -> Collection[torch.device]: """Property to retrieve the device of the traced proxy value or real value. Returns: Union[torch.Size,Collection[torch.device]]: Proxy value device or collection of devices. """ if not self.node.attached: return util.apply(self.value, lambda x: x.device, torch.Tensor) # If we haven't scanned in a proxy_value, just return a proxy to get the attribute. if self.node.fake_value is inspect._empty: return super().__getattr__("device") return util.apply(self.node.fake_value, lambda x: x.device, torch.Tensor) @property def dtype(self) -> Collection[torch.device]: """Property to retrieve the dtype of the traced proxy value or real value. Returns: Union[torch.Size,Collection[torch.dtype]]: Proxy value dtype or collection of dtypes. """ if not self.node.attached: return util.apply(self.value, lambda x: x.dtype, torch.Tensor) # If we haven't scanned in a proxy_value, just return a proxy to get the attribute. if self.node.fake_value is inspect._empty: return super().__getattr__("dtype") return util.apply(self.node.fake_value, lambda x: x.dtype, torch.Tensor) @classmethod def __torch_function__(cls, orig_method, types, args=None, kwargs=None) -> Self: if args is None: args = list() if kwargs is None: kwargs = dict() proxy: Proxy = None def get_proxy(arg): nonlocal proxy proxy = arg util.apply((args, kwargs), get_proxy, Proxy) return proxy.node.create( orig_method, *args, **kwargs, )
InterventionProxyType = TypeVar("InterventionProxyType", bound=InterventionProxy)