Streaming#
Streaming enables users to apply functions and datasets locally during remote model execution. This allows users to stream results for immediate consumption (i.e., seeing tokens as they are generated) or to apply local resources that aren't on the remote whitelist, such as model tokenizers, large local datasets, and more!
- nnsight.local() context sends values immediately to the user's local machine from the server
- Intervention graph is executed locally on downstream nodes
- Exiting the local context uploads data back to the server
- @nnsight.trace function decorator enables custom functions to be added to the intervention graph when using nnsight.local()
Setup#
[2]:
# if running in Google Colab, install nnsight
try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    !pip install -U nnsight
nnsight.local()#
You may sometimes want to locally access and manipulate values during remote execution. Using .local() on a proxy, you can send remote content to your local machine and apply local functions. The intervention graph is then executed locally on downstream nodes (until you send execution back to the remote server by exiting the .local() context).

There are a few use cases for streaming with .local(), including live chat generation and applying large datasets or non-whitelisted local functions to the intervention graph.
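Putting these pieces together, the basic pattern looks like this (a minimal sketch of the cells that follow, assuming the NDIF API key and HuggingFace access are configured as in the setup above):

import nnsight
from nnsight import LanguageModel

llama = LanguageModel("meta-llama/Meta-Llama-3.1-70B")

with llama.trace("hello", remote=True) as tracer:
    hs = llama.model.layers[-1].output[0]   # computed on the remote server
    with nnsight.local():
        tracer.log(hs[0, 0, 0])             # streamed to and printed on your local machine
    out = llama.lm_head.output.save()       # execution returns to the remote server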
Now let’s explore how streaming works. We’ll start by grabbing some hidden states of the model and printing their value using tracer.log(). Without calling nnsight.local(), these operations will all occur remotely.
[4]:
from nnsight import CONFIG
from IPython.display import clear_output

if is_colab:
    # include your HuggingFace Token and NNsight API key on Colab secrets
    from google.colab import userdata
    NDIF_API = userdata.get('NDIF_API')
    HF_TOKEN = userdata.get('HF_TOKEN')

    CONFIG.set_default_api_key(NDIF_API)
    !huggingface-cli login --token $HF_TOKEN
    clear_output()
[5]:
from nnsight import LanguageModel
llama = LanguageModel("meta-llama/Meta-Llama-3.1-70B")
[6]:
# This will give you a remote LOG response because it's coming from the remote server
with llama.trace("hello", remote=True) as tracer:
    hs = llama.model.layers[-1].output[0]
    tracer.log(hs[0,0,0])
    out = llama.lm_head.output.save()

print(out)
2025-03-17 15:10:19,417 270a6e4d-53a5-4929-9f1d-d82fdef7292d - RECEIVED: Your job has been received and is waiting approval.
2025-03-17 15:10:20,756 270a6e4d-53a5-4929-9f1d-d82fdef7292d - APPROVED: Your job was approved and is waiting to be run.
2025-03-17 15:10:21,638 270a6e4d-53a5-4929-9f1d-d82fdef7292d - RUNNING: Your job has started running.
2025-03-17 15:10:23,319 270a6e4d-53a5-4929-9f1d-d82fdef7292d - LOG: tensor(5.4688, device='cuda:2')
2025-03-17 15:10:25,060 270a6e4d-53a5-4929-9f1d-d82fdef7292d - COMPLETED: Your job has been completed.
Downloading result: 0%| | 0.00/514k [00:00<?, ?B/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Downloading result: 100%|██████████| 514k/514k [00:00<00:00, 1.92MB/s]
tensor([[[ 6.3750, 8.6250, 13.0000, ..., -4.1562, -4.1562, -4.1562],
[10.5000, 2.6406, 4.7812, ..., -8.8750, -8.8750, -8.8750]]],
dtype=torch.bfloat16)
[7]:
import nnsight

# This will print locally because the log call runs inside nnsight.local()
with llama.trace("hello", remote=True) as tracer:
    with nnsight.local():
        hs = llama.model.layers[-1].output[0]
        tracer.log(hs[0,0,0])

    out = llama.lm_head.output.save()

print(out)
2025-03-17 15:10:42,787 1f677938-114a-4efe-8b5b-eabbb6090ebf - RECEIVED: Your job has been received and is waiting approval.
2025-03-17 15:10:43,386 1f677938-114a-4efe-8b5b-eabbb6090ebf - APPROVED: Your job was approved and is waiting to be run.
2025-03-17 15:10:43,690 1f677938-114a-4efe-8b5b-eabbb6090ebf - RUNNING: Your job has started running.
tensor(5.4688, dtype=torch.bfloat16)
2025-03-17 15:10:44,819 1f677938-114a-4efe-8b5b-eabbb6090ebf - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 514k/514k [00:00<00:00, 1.77MB/s]
tensor([[[ 6.3750, 8.6250, 13.0000, ..., -4.1562, -4.1562, -4.1562],
[10.5000, 2.6406, 4.7812, ..., -8.8750, -8.8750, -8.8750]]],
dtype=torch.bfloat16)
@nnsight.trace function decorator#
We can also use function decorators to create custom functions to be used during .local() calls. This is a handy way to enable live streaming of a chat or to train probing classifiers on model hidden states.

Let’s try out @nnsight.trace and nnsight.local() to access a custom function during remote execution.
[8]:
# first, let's define our function
@nnsight.trace # decorator that enables this function to be added to the intervention graph
def my_local_fn(value):
    return value * 0

# We use a local function to ablate some hidden states
# This downloads the data for the .local context, and then uploads it back to set the value.
with llama.generate("hello", remote=True) as tracer:
    hs = llama.model.layers[-1].output[0]

    with nnsight.local():
        hs = my_local_fn(hs)

    llama.model.layers[-1].output[0][:] = hs

    out = llama.lm_head.output.save()
2025-03-17 15:10:50,961 bf34e807-2ae2-4ab4-be77-3ac0836c7e28 - RECEIVED: Your job has been received and is waiting approval.
2025-03-17 15:10:54,240 bf34e807-2ae2-4ab4-be77-3ac0836c7e28 - APPROVED: Your job was approved and is waiting to be run.
2025-03-17 15:10:54,244 bf34e807-2ae2-4ab4-be77-3ac0836c7e28 - RUNNING: Your job has started running.
2025-03-17 15:10:54,842 bf34e807-2ae2-4ab4-be77-3ac0836c7e28 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 258k/258k [00:00<00:00, 1.14MB/s]
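The same pattern covers the probing-classifier use case mentioned earlier: stream hidden states to your local machine and feed them to a local model that the remote whitelist knows nothing about. Below is a minimal sketch (not part of the original example); probe is a hypothetical untrained local module, and 8192 is Llama-3.1-70B's hidden size:

import torch

# Hypothetical local probe (untrained here, just to show the mechanics)
probe = torch.nn.Linear(8192, 2)

@nnsight.trace # required so the call can be added to the intervention graph
def apply_probe(hidden_states):
    # Executes on your local machine inside nnsight.local()
    return probe(hidden_states.float())

with llama.trace("hello", remote=True) as tracer:
    hs = llama.model.layers[-1].output[0]

    with nnsight.local():
        logits = apply_probe(hs)  # hidden states are downloaded and probed locally
        tracer.log(logits.shape)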
Note that without calling .local, the remote API does not know about my_local_fn and will throw a whitelist error. A whitelist error occurs because the function is not on the list of functions approved for remote execution.
[9]:
with llama.trace("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]

    hs = my_local_fn(hs) # no .local - will cause an error

    llama.model.layers[-1].output[0][:] = hs * 2

    out = llama.lm_head.output.save()

print(out)
---------------------------------------------------------------------------
FunctionWhitelistError Traceback (most recent call last)
Cell In[9], line 1
----> 1 with llama.trace("hello", remote=True) as tracer:
3 hs = llama.model.layers[-1].output[0]
5 hs = my_local_fn(hs) # no .local - will cause an error
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/contexts/interleaving.py:96, in InterleavingTracer.__exit__(self, exc_type, exc_val, exc_tb)
92 self.invoker.__exit__(None, None, None)
94 self._model._envoy._reset()
---> 96 super().__exit__(exc_type, exc_val, exc_tb)
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/contexts/tracer.py:25, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
21 from .globals import GlobalTracingContext
23 GlobalTracingContext.try_deregister(self)
---> 25 return super().__exit__(exc_type, exc_val, exc_tb)
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/tracing/contexts/base.py:82, in Context.__exit__(self, exc_type, exc_val, exc_tb)
78 graph = graph.stack.pop()
80 graph.alive = False
---> 82 self.backend(graph)
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/backends/remote.py:77, in RemoteBackend.__call__(self, graph)
72 def __call__(self, graph: Graph):
74 if self.blocking:
75
76 # Do blocking request.
---> 77 result = self.blocking_request(graph)
79 else:
80
81 # Otherwise we are getting the status / result of the existing job.
82 result = self.non_blocking_request(graph)
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/backends/remote.py:289, in RemoteBackend.blocking_request(self, graph)
280 sio.connect(
281 self.ws_address,
282 socketio_path="/ws/socket.io",
283 transports=["websocket"],
284 wait_timeout=10,
285 )
287 remote_graph = preprocess(graph)
--> 289 data, headers = self.request(remote_graph)
291 headers["session_id"] = sio.sid
293 # Submit request via
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/intervention/backends/remote.py:60, in RemoteBackend.request(self, graph)
58 def request(self, graph: Graph) -> Tuple[bytes, Dict[str, str]]:
---> 60 data = RequestModel.serialize(graph, self.format, self.zlib)
62 headers = {
63 "model_key": self.model_key,
64 "format": self.format,
(...)
67 "sent-timestamp": str(time.time()),
68 }
70 return data, headers
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/request.py:43, in RequestModel.serialize(graph, format, _zlib)
38 @staticmethod
39 def serialize(graph: Graph, format:str, _zlib:bool) -> bytes:
41 if format == "json":
---> 43 data = RequestModel(graph=graph)
45 json = data.model_dump(mode="json")
47 data = msgspec.json.encode(json)
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/request.py:30, in RequestModel.__init__(self, memo, *args, **kwargs)
28 def __init__(self, *args, memo: Dict = None, **kwargs):
---> 30 super().__init__(*args, memo=memo or dict(), **kwargs)
32 if memo is None:
34 self.memo = {**MEMO}
[... skipping hidden 1 frame]
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/format/types.py:276, in GraphModel.to_model(value)
273 @staticmethod
274 def to_model(value: Graph) -> Self:
--> 276 return GraphModel(graph=value, nodes=value.nodes)
[... skipping hidden 1 frame]
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/format/types.py:77, in memoized.<locals>.inner(value)
75 def inner(value):
---> 77 model = fn(value)
79 _id = id(value)
81 MEMO[_id] = model
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/format/types.py:101, in NodeModel.to_model(value)
97 @staticmethod
98 @memoized
99 def to_model(value: Node) -> Self:
--> 101 return NodeModel(target=value.target, args=value.args, kwargs=value.kwargs)
[... skipping hidden 1 frame]
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/format/types.py:244, in FunctionModel.to_model(value)
239 @staticmethod
240 def to_model(value:FUNCTION):
242 model = FunctionModel(function_name=get_function_name(value))
--> 244 FunctionModel.check_function_whitelist(model.function_name)
246 return model
File /opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/nnsight/schema/format/types.py:251, in FunctionModel.check_function_whitelist(cls, qualname)
248 @classmethod
249 def check_function_whitelist(cls, qualname: str) -> str:
250 if qualname not in FUNCTIONS_WHITELIST:
--> 251 raise FunctionWhitelistError(
252 f"Function with name `{qualname}` not in function whitelist."
253 )
255 return qualname
FunctionWhitelistError: Function with name `__main__.my_local_fn` not in function whitelist.
Example: Live-streaming remote chat#
Now that we can access data from the tracing context on our local computer, we can apply non-whitelisted functions, such as the model’s tokenizer, during remote execution.

Let’s build a decoding function that will decode tokens into words and print the result.
[19]:
@nnsight.trace
def my_decoding_function(tokens, model, max_length=80, state=None):
    # Initialize state if not provided
    if state is None:
        state = {'current_line': '', 'current_line_length': 0}

    token = tokens[-1] # only use last token

    # Decode the token
    decoded_token = model.tokenizer.decode(token).encode("unicode_escape").decode()

    if (decoded_token == '\\n') or (decoded_token == '\n'): # Handle explicit newline tokens
        # End the current line and reset state
        print('', flush=True)
        state['current_line'] = ''
        state['current_line_length'] = 0
    else:
        # Check if adding the token would exceed the max length
        if state['current_line_length'] + len(decoded_token) > max_length:
            # Wrap to a new line and start it with the current token
            print('', flush=True)
            state['current_line'] = decoded_token
            state['current_line_length'] = len(decoded_token)
        else:
            # Append the token to the current line
            state['current_line'] += decoded_token
            state['current_line_length'] += len(decoded_token)
        print(decoded_token, flush=True, end="") # Print only the newly decoded token

    return state
Now we can decode and print our model outputs throughout token generation by accessing our decoding function through nnsight.local().
[23]:
import torch

nnsight.CONFIG.APP.REMOTE_LOGGING = False

prompt = "A press release is an official statement delivered to members of the news media for the purpose of"
prompt = "The Eiffel Tower is in the city of"

print("Prompt: ", prompt, '\n', end="")

# Initialize the state for decoding
state = {'current_line': '', 'current_line_length': 0}

with llama.generate(prompt, remote=True, max_new_tokens=20) as generator:
    # Call .all() to apply to each new token
    llama.all()

    all_tokens = nnsight.list().save()

    # Access model output
    out = llama.lm_head.output.save()

    # Apply softmax to obtain probabilities and save the result
    probs = torch.nn.functional.softmax(out, dim=-1)
    max_probs = torch.max(probs, dim=-1)
    tokens = max_probs.indices.cpu().tolist()
    all_tokens.append(tokens[0]).save()

    with nnsight.local():
        state = my_decoding_function(tokens[0], llama, max_length=12, state=state)
Prompt: The Eiffel Tower is in the city of
Paris, France. It is a very famous landmark. It is built in 1889. the
Downloading result: 100%|██████████| 258k/258k [00:00<00:00, 1.01MB/s]
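Although the tokens were printed live during generation, the values saved inside the trace are still available locally once the job completes. As a rough follow-up sketch (assuming greedy decoding, so the argmax at the last position of each step is the token generated next), you can rebuild the streamed text from all_tokens:

# all_tokens holds one list of argmax token ids per generation step;
# the last entry of each step is the newly predicted token.
generated_ids = [step[-1] for step in all_tokens]
print(llama.tokenizer.decode(generated_ids))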