Custom Functions#

Everything within the tracing context operates on the intervention graph. Therefore, for nnsight to trace a function, that function must also be part of the intervention graph.

Out of the box, nnsight supports PyTorch functions and methods, all operators, as well as the einops library. We don’t need to do anything special to use them.

For custom functions we can use nnsight.apply() to add them to the intervention graph:

[1]:
import nnsight
from nnsight import LanguageModel
import torch

model = LanguageModel('openai-community/gpt2', device_map='auto')

# We define a simple custom function that sums all the elements of a tensor
def tensor_sum(tensor):
    flat = tensor.flatten()
    total = 0
    for element in flat:
        total += element.item()

    return torch.tensor(total)

with model.trace("The Eiffel Tower is in the city of") as tracer:

    # Pass the function and its arguments (comma-separated) to nnsight.apply to add the call to the intervention graph
    custom_sum = nnsight.apply(tensor_sum, model.transformer.h[0].output[0]).save()
    pytorch_sum = model.transformer.h[0].output[0].sum().save()


print("\nPyTorch sum: ", pytorch_sum)
print("Our sum: ", custom_sum)

PyTorch sum:  tensor(191.2442, device='mps:0', grad_fn=<SumBackward0>)
Our sum:  tensor(191.2442)

nnsight.apply() adds the wrapped function call to the intervention graph and returns its result as a Proxy object. We can then use this Proxy object as we would any other.
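Outside the tracing context, the custom function is just ordinary PyTorch code and can be sanity-checked on its own, independent of nnsight:

```python
import torch

# The same custom function from the example above
def tensor_sum(tensor):
    flat = tensor.flatten()
    total = 0
    for element in flat:
        total += element.item()
    return torch.tensor(total)

# On a small tensor we can verify it agrees with torch's built-in sum
t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
print(tensor_sum(t))                            # tensor(15.)
print(torch.allclose(tensor_sum(t), t.sum()))   # True
```

Because the function only sees plain tensors at execution time, anything that works on a tensor outside a trace will also work when wrapped with nnsight.apply inside one.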

The applications of nnsight.apply are wide: it can wrap any custom function, as well as functions from libraries that nnsight does not support out of the box.
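As an illustration, here is a hypothetical NumPy-based helper (`np_softmax` is our own function, not part of nnsight or any library) that nnsight cannot trace natively but that could be wrapped the same way. The trace usage is shown as a comment since it requires a loaded model; the standalone check below runs with only NumPy and PyTorch:

```python
import numpy as np
import torch

# A NumPy-based function: nnsight cannot trace NumPy ops directly,
# but nnsight.apply can wrap the whole call
def np_softmax(tensor):
    arr = tensor.detach().cpu().numpy()
    exp = np.exp(arr - arr.max(axis=-1, keepdims=True))
    return torch.from_numpy(exp / exp.sum(axis=-1, keepdims=True))

# Inside a trace, it would be wrapped exactly like tensor_sum above:
# with model.trace("The Eiffel Tower is in the city of"):
#     probs = nnsight.apply(np_softmax, model.transformer.h[0].output[0]).save()

# Standalone, it matches PyTorch's own softmax
x = torch.randn(2, 4)
print(torch.allclose(np_softmax(x), torch.softmax(x, dim=-1), atol=1e-6))  # True
```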