Iterative Interventions

A plain Python for loop with a Tracer context inside it creates a new intervention graph at every iteration, which does not scale.

The Iterator context allows us to run an intervention loop at scale. It iteratively executes and updates a single intervention graph.
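The difference between rebuilding a graph on every pass and reusing one graph can be sketched with a toy model in plain Python. This is an illustration only: `ToyGraph`, `rebuild_every_iteration`, and `reuse_single_graph` are hypothetical names invented here, and nothing below reflects how nnsight implements its intervention graphs.

```python
# Toy illustration only: these names are hypothetical and are not part of nnsight.


class ToyGraph:
    """Stands in for an intervention graph: an ordered list of recorded operations."""

    built = 0  # class-level counter of how many graphs have been constructed

    def __init__(self, ops):
        ToyGraph.built += 1
        self.ops = list(ops)

    def run(self, value):
        # Execute the recorded operations in order on the given input.
        for op in self.ops:
            value = op(value)
        return value


def rebuild_every_iteration(items):
    """Analogue of a plain for loop that opens a fresh tracing context per item."""
    return [ToyGraph([lambda x: x + 1, lambda x: x * 2]).run(item) for item in items]


def reuse_single_graph(items):
    """Analogue of the Iterator context: build the graph once, execute it per item."""
    graph = ToyGraph([lambda x: x + 1, lambda x: x * 2])
    return [graph.run(item) for item in items]


ToyGraph.built = 0
naive = rebuild_every_iteration([0, 1, 2])
naive_graphs = ToyGraph.built    # one graph per item

ToyGraph.built = 0
reused = reuse_single_graph([0, 1, 2])
reused_graphs = ToyGraph.built   # a single graph for the whole loop
```

Both functions produce the same results, but the first constructs three graphs and the second constructs one. In this toy model the construction cost grows with the number of iterations only in the first case.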

Use a session to define the Iterator context, and pass in the sequence of items you want to loop over:

[1]:
import nnsight
from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto')


with model.session() as session:

  with session.iter([0, 1, 2]) as item:
    # define intervention body here ...

    with model.trace("_"):
      # define interventions here ...
      pass

    with model.trace("_"):
      # define interventions here ...
      pass

The Iterator context extends all of nnsight's graph-based functionality while closely mimicking Python's conventional for loop statement, so it supports all kinds of iterative operations through the as item syntax:

[2]:

with model.session() as session:

  li = nnsight.list() # an NNsight built-in list object
  [li.append([num]) for num in range(0, 3)] # adding [0], [1], [2] to the list
  li2 = nnsight.list().save()

  # You can create nested Iterator contexts
  with session.iter(li) as item:
    with session.iter(item) as item_2:
      li2.append(item_2)

print("\nList: ", li2)

List:  [0, 1, 2]
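For comparison, here is a plain-Python counterpart of the nested iteration above, written with ordinary for loops and built-in lists. It is only a sketch of the control flow being mimicked; `flatten_nested` is a hypothetical helper introduced for this illustration and does not involve nnsight or an intervention graph.

```python
def flatten_nested(outer):
    """Visit every element of every inner list, in order, and collect them."""
    flat = []
    for item in outer:          # corresponds to the outer session.iter(li)
        for item_2 in item:     # corresponds to the nested session.iter(item)
            flat.append(item_2)
    return flat


li = [[num] for num in range(0, 3)]  # [[0], [1], [2]]
li2 = flatten_nested(li)
print("\nList: ", li2)
```

The collected list matches the one shown in the notebook output above, which is the sense in which the nested Iterator contexts behave like nested for loops.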

We can also expose the Iterator context object via the return_context flag. You can then use it to exit the iteration loop early and to log intermediate outputs within the loop:

[3]:
with model.session() as session:

  with session.iter([0, 1, 2, 3], return_context=True) as (item, iterator):

    # log the current item on every iteration
    iterator.log(item)

    # leave the loop once the item equals 2
    with iterator.cond(item == 2):
      iterator.exit()

0
1
2
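The logged output above corresponds to what an ordinary Python loop with a conditional break would record. The sketch below shows that analogy in plain Python; `log_until` is a hypothetical helper written for this illustration, and the mapping to `iterator.log`, `iterator.cond`, and `iterator.exit` in the comments is an informal reading of the example, not a description of nnsight internals.

```python
def log_until(items, stop_value):
    """Record each item, stopping right after the first one equal to stop_value."""
    logged = []
    for item in items:
        logged.append(item)       # plays the role of iterator.log(item)
        if item == stop_value:    # plays the role of iterator.cond(item == 2)
            break                 # plays the role of iterator.exit()
    return logged


logged = log_until([0, 1, 2, 3], stop_value=2)
for value in logged:
    print(value)
```

Because the item is logged before the condition is checked, the value that triggers the exit is still recorded, and the final item 3 is never reached.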