Source code for nnsight.intervention.tracing.backwards
from typing import TYPE_CHECKING, Any, Callable
import torch
from ...util import Patch
from ..interleaver import Interleaver, Mediator
from .invoker import Invoker
def wrap_grad(interleaver: Interleaver):
    """
    Create a property descriptor for gradient intervention.

    Returns:
        A property that intercepts reads and writes of ``Tensor.grad``.
    """

    def wrap(tensor: torch.Tensor):
        # Only wrap the tensor once
        if tensor._backward_hooks:
            return

        # We are providing the grad of the tensor
        provider = id(tensor)

        # We'll need to remove the hook afterwards
        hook = None

        # Called on backward for this tensor
        def inner(grad: torch.Tensor):
            # Inject the grad value,
            # possibly editing it in the process
            try:
                grad = interleaver.handle(f"{provider}.grad", grad)
            finally:
                hook.remove()

            return grad

        # Register the hook
        hook = tensor.register_hook(inner)

    def getter(tensor: torch.Tensor):
        wrap(tensor)

        requester = id(tensor)

        return interleaver.current.request(f"{requester}.grad")

    def setter(tensor: torch.Tensor, value: torch.Tensor):
        wrap(tensor)

        requester = id(tensor)

        return interleaver.current.swap(f"{requester}.grad", value)

    return property(getter, setter)
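
# A minimal, self-contained sketch of the hook pattern used above: ``Tensor.register_hook``
# fires when the gradient for a tensor is computed, and returning a new tensor from the
# hook replaces that gradient. The name ``_demo_grad_hook`` and the scaling by 2.0 are
# illustrative only and are not part of nnsight.
def _demo_grad_hook():
    x = torch.ones(3, requires_grad=True)

    def double_grad(grad: torch.Tensor):
        # Edit the incoming gradient, mirroring what ``inner`` does via the interleaver.
        return grad * 2.0

    handle = x.register_hook(double_grad)
    x.sum().backward()
    handle.remove()

    # x.grad is now [2., 2., 2.] instead of [1., 1., 1.]
    return x.grad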
class BackwardsTracer(Invoker):
    """Invoker that runs a backward pass while ``torch.Tensor.grad`` is patched
    so gradient reads and writes can be interleaved as interventions."""

    def __init__(
        self,
        tensor: torch.Tensor,
        fn: Callable,
        *args,
        **kwargs,
    ):
        super().__init__(None, *args, **kwargs)

        self.tensor = tensor
        self.fn = fn
    def execute(self, fn: Callable):
        mediator = BackwardsMediator(fn, self.info)

        interleaver = Interleaver([mediator], self)

        # Temporarily replace ``torch.Tensor.grad`` with the intervention property.
        grad_patch = Patch(torch.Tensor, wrap_grad(interleaver), "grad")

        try:
            grad_patch.patch()

            # Run the backward call (e.g. ``self.fn`` might be ``Tensor.backward``)
            # while the interleaver is active so the gradient hooks can fire.
            with interleaver:
                self.fn(self.tensor, *self.args, **self.kwargs)

            interleaver.check_dangling_mediators()

        finally:
            # Always restore the original ``grad`` attribute and shut the interleaver down.
            grad_patch.restore()

            interleaver.cancel()
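
# A rough, standalone sketch of the patch/restore pattern ``execute`` relies on, assuming
# ``Patch`` simply swaps a class attribute and later puts the original back. ``_PatchSketch``,
# ``Box``, and ``_demo_patch`` are hypothetical names used only for illustration.
class _PatchSketch:
    def __init__(self, parent, replacement, name):
        self.parent, self.replacement, self.name = parent, replacement, name
        self.original = getattr(parent, name, None)

    def patch(self):
        setattr(self.parent, self.name, self.replacement)

    def restore(self):
        setattr(self.parent, self.name, self.original)


def _demo_patch():
    # Route attribute access through a property, then undo it, mirroring how
    # ``torch.Tensor.grad`` is temporarily replaced above.
    class Box:
        value = 1

    sketch = _PatchSketch(Box, property(lambda self: 42), "value")
    sketch.patch()
    assert Box().value == 42
    sketch.restore()
    assert Box().value == 1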