Iterative Interventions
A simple for loop with a Tracer context inside it creates a new intervention graph at every iteration, which is not scalable.
The Iterator context allows us to run an intervention loop at scale. It iteratively executes and updates a single intervention graph.
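As a rough, library-free illustration of why reusing one graph matters, the sketch below counts how many graph objects are built under each pattern. ToyGraph and make_graph are hypothetical stand-ins invented for this illustration. They are not part of nnsight and make no claim about how nnsight builds or executes its graphs.

```python
# Toy analogue only: ToyGraph is a hypothetical stand-in, not an nnsight class.
class ToyGraph:
    """Records a sequence of operations once, then replays them on any input."""

    built = 0  # class-level counter of how many graphs have been constructed

    def __init__(self):
        ToyGraph.built += 1
        self.ops = []

    def add(self, fn):
        # Record one operation and return self so calls can be chained.
        self.ops.append(fn)
        return self

    def run(self, value):
        # Replay every recorded operation, in order, on the given input.
        for fn in self.ops:
            value = fn(value)
        return value


def make_graph():
    # Two recorded steps: double the input, then add one.
    return ToyGraph().add(lambda x: x * 2).add(lambda x: x + 1)


items = [0, 1, 2]

# Naive pattern: a fresh graph is built on every iteration.
ToyGraph.built = 0
naive_results = [make_graph().run(item) for item in items]
naive_builds = ToyGraph.built

# Iterator-style pattern: one graph is built, then re-executed for each item.
ToyGraph.built = 0
graph = make_graph()
reused_results = [graph.run(item) for item in items]
reused_builds = ToyGraph.built

print("naive: ", naive_results, "graphs built:", naive_builds)
print("reused:", reused_results, "graphs built:", reused_builds)
```

Both patterns compute the same values. Only the number of graphs constructed differs, and that cost grows with the number of iterations in the naive pattern.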
Use a session to define the Iterator context and pass in the sequence of items that you want to loop over:
[1]:
import nnsight
from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto')

with model.session() as session:

    with session.iter([0, 1, 2]) as item:
        # define intervention body here ...

        with model.trace("_"):
            # define interventions here ...
            pass

        with model.trace("_"):
            # define interventions here ...
            pass
The Iterator context supports all of nnsight's graph-based functionality while closely mimicking Python's conventional for loop statement, so it can express all kinds of iterative operations through the as item syntax:
[2]:
with model.session() as session:

    li = nnsight.list() # an NNsight built-in list object
    [li.append([num]) for num in range(0, 3)] # adding [0], [1], [2] to the list
    li2 = nnsight.list().save()

    # You can create nested Iterator contexts
    with session.iter(li) as item:
        with session.iter(item) as item_2:
            li2.append(item_2)

print("\nList: ", li2)
List: [0, 1, 2]
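For comparison, the nested Iterator contexts above behave like the following nested for loops in plain Python. This is only an analogue of the control flow, with ordinary lists standing in for the nnsight list objects. It runs eagerly and does not build or execute an intervention graph.

```python
# Plain-Python analogue of the nested Iterator example (no nnsight involved).
outer = [[num] for num in range(0, 3)]  # [[0], [1], [2]], mirrors li

flattened = []  # mirrors li2
for item in outer:        # outer loop, like session.iter(li)
    for item_2 in item:   # inner loop, like session.iter(item)
        flattened.append(item_2)

print("List: ", flattened)
```

Each inner list holds a single element, so the nested loops flatten the three one-element lists into one list of three values.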
We can also expose the iterator context object via the return_context flag. You can then use it to exit the iteration loop early and to log intermediate outputs within the loop:
[3]:
with model.session() as session:

    with session.iter([0, 1, 2, 3], return_context=True) as (item, iterator):

        iterator.log(item)

        with iterator.cond(item == 2):
            iterator.exit()
0
1
2
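The output above stops at 2 because the item is logged before the exit condition is checked. A plain-Python analogue of that control flow, with a list named logged standing in for the log output (an illustrative name, not an nnsight object), might look like this:

```python
# Plain-Python analogue of the early-exit example (no nnsight involved).
logged = []

for item in [0, 1, 2, 3]:
    logged.append(item)  # stands in for iterator.log(item)
    if item == 2:        # stands in for iterator.cond(item == 2)
        break            # stands in for iterator.exit()

print(logged)
```

The item 3 is never reached, which matches the three logged values shown above.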