Streaming#

Streaming enables users to apply functions and datasets locally during remote model execution. This allows users to stream results for immediate consumption (e.g., seeing tokens as they are generated) or to apply non-whitelisted local resources, such as model tokenizers, large local datasets, and more! A minimal sketch follows the list below.

  • The nnsight.local() context sends values from the server to the user’s local machine as soon as they are available

  • Downstream nodes of the intervention graph are then executed locally

  • Exiting the local context uploads the data back to the server

  • The @nnsight.trace function decorator enables custom functions to be added to the intervention graph when using nnsight.local()
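
Putting these pieces together, the flow looks roughly like this (a minimal sketch; it assumes a LanguageModel already loaded as llama, with the Llama-style module paths used in the examples below):

[ ]:
import nnsight

with llama.trace("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]  # computed on the remote server

    with nnsight.local():
        tracer.log(hs[0, 0, 0])  # value is streamed to, and logged on, your local machine

    out = llama.lm_head.output.save()  # execution returns to the remote server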

Setup#

[2]:
# if running in Google Colab, install nnsight
try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    !pip install -U nnsight

nnsight.local()#

You may sometimes want to locally access and manipulate values during remote execution. Using .local() on a proxy, you can send remote content to your local machine and apply local functions. The intervention graph is then executed locally on downstream nodes (until you send execution back to the remote server by exiting the .local() context).

There are a few use cases for streaming with .local(), including live chat generation and applying large datasets or non-whitelisted local functions to the intervention graph.

Now let’s explore how streaming works. We’ll start by grabbing some hidden states of the model and printing their value using tracer.log(). Without calling nnsight.local(), these operations will all occur remotely.

[4]:
from nnsight import CONFIG
from IPython.display import clear_output

if is_colab:
    # include your Hugging Face token and NDIF API key in your Colab secrets
    from google.colab import userdata
    NDIF_API = userdata.get('NDIF_API')
    HF_TOKEN = userdata.get('HF_TOKEN')

    CONFIG.set_default_api_key(NDIF_API)
    !huggingface-cli login --token {HF_TOKEN}

clear_output()
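
If you are not running on Colab, you can do the equivalent setup directly (a minimal sketch with placeholder values; it assumes you already have an NDIF API key and a Hugging Face token):

[ ]:
from nnsight import CONFIG
from huggingface_hub import login

CONFIG.set_default_api_key("YOUR_NDIF_API_KEY")  # placeholder NDIF API key
login(token="YOUR_HF_TOKEN")                     # placeholder Hugging Face token
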
[5]:
from nnsight import LanguageModel
llama = LanguageModel("meta-llama/Meta-Llama-3.1-70B")
[6]:
# This will give you a remote LOG response because the log executes on the remote server
with llama.trace("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]

    tracer.log(hs[0,0,0])

    out =  llama.lm_head.output.save()

print(out)
2025-03-17 15:10:19,417 270a6e4d-53a5-4929-9f1d-d82fdef7292d - RECEIVED: Your job has been received and is waiting approval.
2025-03-17 15:10:20,756 270a6e4d-53a5-4929-9f1d-d82fdef7292d - APPROVED: Your job was approved and is waiting to be run.
2025-03-17 15:10:21,638 270a6e4d-53a5-4929-9f1d-d82fdef7292d - RUNNING: Your job has started running.
2025-03-17 15:10:23,319 270a6e4d-53a5-4929-9f1d-d82fdef7292d - LOG: tensor(5.4688, device='cuda:2')
2025-03-17 15:10:25,060 270a6e4d-53a5-4929-9f1d-d82fdef7292d - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 514k/514k [00:00<00:00, 1.92MB/s]
tensor([[[ 6.3750,  8.6250, 13.0000,  ..., -4.1562, -4.1562, -4.1562],
         [10.5000,  2.6406,  4.7812,  ..., -8.8750, -8.8750, -8.8750]]],
       dtype=torch.bfloat16)

[7]:
import nnsight
# This will print locally because the log call runs inside nnsight.local()
with llama.trace("hello", remote=True) as tracer:

    with nnsight.local():
        hs = llama.model.layers[-1].output[0]
        tracer.log(hs[0,0,0])

    out =  llama.lm_head.output.save()

print(out)
2025-03-17 15:10:42,787 1f677938-114a-4efe-8b5b-eabbb6090ebf - RECEIVED: Your job has been received and is waiting approval.
2025-03-17 15:10:43,386 1f677938-114a-4efe-8b5b-eabbb6090ebf - APPROVED: Your job was approved and is waiting to be run.
2025-03-17 15:10:43,690 1f677938-114a-4efe-8b5b-eabbb6090ebf - RUNNING: Your job has started running.
tensor(5.4688, dtype=torch.bfloat16)
2025-03-17 15:10:44,819 1f677938-114a-4efe-8b5b-eabbb6090ebf - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 514k/514k [00:00<00:00, 1.77MB/s]
tensor([[[ 6.3750,  8.6250, 13.0000,  ..., -4.1562, -4.1562, -4.1562],
         [10.5000,  2.6406,  4.7812,  ..., -8.8750, -8.8750, -8.8750]]],
       dtype=torch.bfloat16)

@nnsight.trace function decorator#

We can also use the @nnsight.trace function decorator to create custom functions that can be used during nnsight.local() calls. This is a handy way to enable live streaming of a chat or to train probing classifiers on model hidden states.

Let’s try out @nnsight.trace with nnsight.local() to apply a custom function during remote execution.

[8]:
# first, let's define our function
@nnsight.trace # decorator that enables this function to be added to the intervention graph
def my_local_fn(value):
    return value * 0

# We use a local function to ablate some hidden states
# This downloads the data for the .local context, and then uploads it back to set the value.
with llama.generate("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]

    with nnsight.local():

        hs = my_local_fn(hs)

    llama.model.layers[-1].output[0][:] = hs

    out =  llama.lm_head.output.save()
2025-03-17 15:10:50,961 bf34e807-2ae2-4ab4-be77-3ac0836c7e28 - RECEIVED: Your job has been received and is waiting approval.
2025-03-17 15:10:54,240 bf34e807-2ae2-4ab4-be77-3ac0836c7e28 - APPROVED: Your job was approved and is waiting to be run.
2025-03-17 15:10:54,244 bf34e807-2ae2-4ab4-be77-3ac0836c7e28 - RUNNING: Your job has started running.
2025-03-17 15:10:54,842 bf34e807-2ae2-4ab4-be77-3ac0836c7e28 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 258k/258k [00:00<00:00, 1.14MB/s]
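
For a quick sanity check, you can inspect the ablated logits locally once the trace has finished; after the context exits, the saved out is just a tensor:

[ ]:
# Decode the most likely token at each position from the ablated logits.
predicted_ids = out.argmax(dim=-1)
print(llama.tokenizer.decode(predicted_ids[0]))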

Note that without calling .local(), the remote API does not know about my_local_fn and will throw a whitelist error. A whitelist error occurs because the function is not on the whitelist of functions that are allowed to run remotely.

[9]:
with llama.trace("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]

    hs = my_local_fn(hs) # no .local - will cause an error

    llama.model.layers[-1].output[0][:] = hs * 2

    out =  llama.lm_head.output.save()

print(out)
---------------------------------------------------------------------------
FunctionWhitelistError                    Traceback (most recent call last)
Cell In[9], line 1
----> 1 with llama.trace("hello", remote=True) as tracer:
      3     hs = llama.model.layers[-1].output[0]
      5     hs = my_local_fn(hs) # no .local - will cause an error

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/contexts/interleaving.py:96, in InterleavingTracer.__exit__(self, exc_type, exc_val, exc_tb)
     92     self.invoker.__exit__(None, None, None)
     94 self._model._envoy._reset()
---> 96 super().__exit__(exc_type, exc_val, exc_tb)

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/contexts/tracer.py:25, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
     21 from .globals import GlobalTracingContext
     23 GlobalTracingContext.try_deregister(self)
---> 25 return super().__exit__(exc_type, exc_val, exc_tb)

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/contexts/base.py:82, in Context.__exit__(self, exc_type, exc_val, exc_tb)
     78 graph = graph.stack.pop()
     80 graph.alive = False
---> 82 self.backend(graph)

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/backends/remote.py:77, in RemoteBackend.__call__(self, graph)
     72 def __call__(self, graph: Graph):
     74     if self.blocking:
     75
     76         # Do blocking request.
---> 77         result = self.blocking_request(graph)
     79     else:
     80
     81         # Otherwise we are getting the status / result of the existing job.
     82         result = self.non_blocking_request(graph)

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/backends/remote.py:289, in RemoteBackend.blocking_request(self, graph)
    280 sio.connect(
    281     self.ws_address,
    282     socketio_path="/ws/socket.io",
    283     transports=["websocket"],
    284     wait_timeout=10,
    285 )
    287 remote_graph = preprocess(graph)
--> 289 data, headers = self.request(remote_graph)
    291 headers["session_id"] = sio.sid
    293 # Submit request via

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/backends/remote.py:60, in RemoteBackend.request(self, graph)
     58 def request(self, graph: Graph) -> Tuple[bytes, Dict[str, str]]:
---> 60     data = RequestModel.serialize(graph, self.format, self.zlib)
     62     headers = {
     63         "model_key": self.model_key,
     64         "format": self.format,
   (...)
     67         "sent-timestamp": str(time.time()),
     68     }
     70     return data, headers

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/request.py:43, in RequestModel.serialize(graph, format, _zlib)
     38 @staticmethod
     39 def serialize(graph: Graph, format:str, _zlib:bool) -> bytes:
     41     if format == "json":
---> 43         data = RequestModel(graph=graph)
     45         json = data.model_dump(mode="json")
     47         data = msgspec.json.encode(json)

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/request.py:30, in RequestModel.__init__(self, memo, *args, **kwargs)
     28 def __init__(self, *args, memo: Dict = None, **kwargs):
---> 30     super().__init__(*args, memo=memo or dict(), **kwargs)
     32     if memo is None:
     34         self.memo = {**MEMO}

    [... skipping hidden 1 frame]

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/format/types.py:276, in GraphModel.to_model(value)
    273 @staticmethod
    274 def to_model(value: Graph) -> Self:
--> 276     return GraphModel(graph=value, nodes=value.nodes)

    [... skipping hidden 1 frame]

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/format/types.py:77, in memoized.<locals>.inner(value)
     75 def inner(value):
---> 77     model = fn(value)
     79     _id = id(value)
     81     MEMO[_id] = model

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/format/types.py:101, in NodeModel.to_model(value)
     97 @staticmethod
     98 @memoized
     99 def to_model(value: Node) -> Self:
--> 101     return NodeModel(target=value.target, args=value.args, kwargs=value.kwargs)

    [... skipping hidden 1 frame]

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/format/types.py:244, in FunctionModel.to_model(value)
    239 @staticmethod
    240 def to_model(value:FUNCTION):
    242     model = FunctionModel(function_name=get_function_name(value))
--> 244     FunctionModel.check_function_whitelist(model.function_name)
    246     return model

File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/format/types.py:251, in FunctionModel.check_function_whitelist(cls, qualname)
    248 @classmethod
    249 def check_function_whitelist(cls, qualname: str) -> str:
    250     if qualname not in FUNCTIONS_WHITELIST:
--> 251         raise FunctionWhitelistError(
    252             f"Function with name `{qualname}` not in function whitelist."
    253         )
    255     return qualname

FunctionWhitelistError: Function with name `__main__.my_local_fn` not in function whitelist.

Example: Live-streaming remote chat#

Now that we can access data from the tracing context on our local machine, we can apply non-whitelisted functions, such as the model’s tokenizer, during remote execution.

Let’s build a decoding function that will decode tokens into words and print the result.

[19]:
@nnsight.trace
def my_decoding_function(tokens, model, max_length=80, state=None):
    # Initialize state if not provided
    if state is None:
        state = {'current_line': '', 'current_line_length': 0}

    token = tokens[-1] # only use last token

    # Decode the token
    decoded_token = model.tokenizer.decode(token).encode("unicode_escape").decode()

    if (decoded_token == '\\n') or (decoded_token == '\n'):  # Handle explicit newline tokens
        # Print the current line and reset state
        print('',flush=True)
        state['current_line'] = ''
        state['current_line_length'] = 0
    else:
        # Check if adding the token would exceed the max line length
        if state['current_line_length'] + len(decoded_token) > max_length:
            print('', flush=True)  # wrap to a new line
            state['current_line'] = decoded_token  # Start a new line with the current token
            state['current_line_length'] = len(decoded_token)
            print(decoded_token, flush=True, end="")  # Print the token on the new line
        else:
            # Append the token to the current line
            state['current_line'] += decoded_token
            state['current_line_length'] += len(decoded_token)
            print(decoded_token, flush=True, end="")  # Print only the newly decoded token

    return state

Now we can decode and print our model outputs throughout token generation by accessing our decoding function through nnsight.local().

[23]:
import torch

nnsight.CONFIG.APP.REMOTE_LOGGING = False

prompt = "A press release is an official statement delivered to members of the news media for the purpose of"
prompt = "The Eiffel Tower is in the city of"

print("Prompt: ",prompt,'\n', end ="")

# Initialize the state for decoding
state = {'current_line': '', 'current_line_length': 0}

with llama.generate(prompt, remote=True, max_new_tokens = 20) as generator:
    # Call .all() to apply the interventions below to each newly generated token
    llama.all()

    all_tokens = nnsight.list().save()

    # Access model output
    out = llama.lm_head.output.save()

    # Apply softmax to obtain probabilities and save the result
    probs = torch.nn.functional.softmax(out, dim=-1)
    max_probs = torch.max(probs, dim=-1)
    tokens = max_probs.indices.cpu().tolist()
    all_tokens.append(tokens[0]).save()

    with nnsight.local():
        state = my_decoding_function(tokens[0], llama, max_length=12, state=state)
Prompt:  The Eiffel Tower is in the city of
 Paris, France. It is a very famous landmark. It is built in 1889. the
Downloading result: 100%|██████████| 258k/258k [00:00<00:00, 1.01MB/s]
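
If you want the NDIF job-status messages back for later remote runs, re-enable remote logging:

[ ]:
# Turn job-status log messages back on for subsequent remote requests.
nnsight.CONFIG.APP.REMOTE_LOGGING = True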