Access LLMs with NDIF and NNsight#
NDIF is an inference service hosting large open-weight LLMs for use by researchers.
NNsight is a package for interpreting and manipulating internals of deep learning models.
Together, NDIF and NNsight work hand in hand, letting researchers easily run complex experiments on huge open models with fully transparent access.
Run an interactive version of this walkthrough in Google Colab
Install NNsight#
To start using NNsight, you can install it via pip.
[ ]:
!pip install nnsight
from IPython.display import clear_output
clear_output()
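If you’d like to confirm the installation before continuing, a quick sanity check (assuming the installed release exposes a __version__ attribute, as recent nnsight releases do) is:

import nnsight

# confirm the package imports and report the installed version
print(nnsight.__version__)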
Sign up for NDIF remote model access#
In order to remotely access LLMs through NDIF, users must sign up for an NDIF API key.
Register here for a free API key!#
Once you have a valid NDIF API key, you can configure nnsight as follows:
[ ]:
from nnsight import CONFIG
CONFIG.API.APIKEY = input("Enter your API key: ")
clear_output()
More about API key configuration
The above code saves your API key as the default in a config file alongside the nnsight installation. If you’re running this walkthrough using a local Python installation, this only needs to be run once. If you’re using Colab, we recommend saving your API key as a Colab Secret and configuring it as follows in your notebooks:
from nnsight import CONFIG
import importlib.util

# detect whether this notebook is running in Colab
is_colab = importlib.util.find_spec("google.colab") is not None

if is_colab:
    # read your NNsight API key from Colab Secrets
    from google.colab import userdata
    NDIF_API = userdata.get('NDIF_API')
    CONFIG.set_default_api_key(NDIF_API)
Choose a Model#
NDIF hosts multiple LLMs, including various sizes of the Llama 3.1 and DeepSeek-R1 models. You can view the full list of hosted models on our status page. All of our models are open for public use, except that you must apply for access to the Llama-3.1-405B models.
Apply for 405B access
If you have a clear research need for Llama-3.1-405B and would like more details about applying for access, please refer to this page!
For these exercises, we will explore how to access and modify the internal states of the Llama-3.1-70B model. This 70-billion-parameter model is roughly the largest you could run on a single A100 GPU with 80GB of VRAM, but since we are going to access it remotely on NDIF resources, you can run it from Colab or your laptop!
Note: Llama models are gated on HuggingFace
Llama models are gated and require you to register for access via HuggingFace. Check out their website for more information about registration with Meta.
If you are using a local Python installation, you can activate your HuggingFace token using the terminal:
huggingface-cli login --token YOUR_HF_TOKEN
If you are using Colab, you can add your HuggingFace token to your Secrets.
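If your token lives in a Colab Secret (assumed here to be named HF_TOKEN), one hedged sketch of logging in programmatically uses huggingface_hub:

from google.colab import userdata
from huggingface_hub import login

# log in to HuggingFace with a token stored in Colab Secrets
login(token=userdata.get('HF_TOKEN'))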
We will be using the LanguageModel subclass of NNsight to load the Llama-3.1-70B model and access its internal states.
About NNsight LanguageModel
The LanguageModel subclass of NNsight is a wrapper with special support for HuggingFace language models, including automatically loading a model from its HuggingFace ID together with the appropriate tokenizer.
This way there’s no need to pretokenize your input: you can simply pass in a string!
Note: ``LanguageModel`` models also accept tokenized inputs, including chat templates.
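As an illustration, here is a minimal sketch of a chat-templated, pre-tokenized input. It assumes the model llm is already instantiated (as in the next cell) and that you are using an instruct-tuned variant whose tokenizer defines a chat template:

# build a tokenized chat input; assumes the tokenizer has a chat template
messages = [{"role": "user", "content": "Where is the Eiffel Tower?"}]
tokens = llm.tokenizer.apply_chat_template(messages, return_tensors="pt")

# tokenized inputs can be passed to .trace() just like strings
with llm.trace(tokens, remote=True):
    output = llm.output.save()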
[ ]:
# instantiate the model using the LanguageModel class
from nnsight import LanguageModel
# don't worry, this won't load locally!
llm = LanguageModel("meta-llama/Meta-Llama-3.1-70B", device_map="auto")
print(llm)
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 8192)
(layers): ModuleList(
(0-79): 80 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=8192, out_features=8192, bias=False)
(k_proj): Linear(in_features=8192, out_features=1024, bias=False)
(v_proj): Linear(in_features=8192, out_features=1024, bias=False)
(o_proj): Linear(in_features=8192, out_features=8192, bias=False)
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=8192, out_features=28672, bias=False)
(up_proj): Linear(in_features=8192, out_features=28672, bias=False)
(down_proj): Linear(in_features=28672, out_features=8192, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((8192,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=8192, out_features=128256, bias=False)
(generator): Generator(
(streamer): Streamer()
)
)
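The printed module tree doubles as a map for interventions: every name in it corresponds to an attribute you can access on the model object. For example (a small illustrative sketch):

# module names in the printout map directly to attribute access
first_attn = llm.model.layers[0].self_attn  # attention block of decoder layer 0
print(first_attn)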
Access model internals#
Now that we’ve installed nnsight, configured our API key, and instantiated a model, we can run an experiment.
For this experiment, let’s try grabbing some of the LLM’s hidden states using nnsight’s tracing context, .trace().
Entering the tracing context allows us to customize how a neural network runs. By calling .trace(), we are telling the model to run with a given input and to collect and/or modify the internal model states based on user-defined code within the tracing context. We can also specify that we want to use an NDIF-hosted model instead of executing locally by setting remote=True.
To get started, let’s ask NNsight to collect the hidden states at the final decoder layer, along with the overall model output (the logits). NNsight needs to know which parts of the model we want to access after the run, so we mark each value we’d like to keep past the end of the tracing context with .save().
Note: You will not be able to access any values defined within a ``.trace()`` that aren’t saved with ``.save()`` after exiting the tracing context!
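To make this concrete, here is a hedged sketch of the failure mode: the unsaved value below is not available once the trace exits.

with llm.trace("Hello", remote=True):
    unsaved = llm.model.layers[-1].output[0]          # no .save()
    saved = llm.model.layers[-1].output[0].save()

print(saved.shape)      # works: this value was downloaded
# print(unsaved.shape)  # would fail: this value was never saved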
[ ]:
# remote=True means the model will execute on NDIF's shared resources
with llm.trace("The Eiffel Tower is in the city of", remote=True):
    # user-defined code to access internal model components
    hidden_states = llm.model.layers[-1].output[0].save()
    output = llm.output.save()

# after exiting the tracing context, we can access any values that were saved
print("Final Layer Hidden States: ", hidden_states[0])

output_logits = output["logits"]
print("Model Output Logits: ", output_logits[0])

# decode the final model output from output logits
max_probs, tokens = output_logits[0].max(dim=-1)
word = llm.tokenizer.decode(tokens.cpu()[-1])
print("Model Output: ", word)
2025-02-26 16:30:19,122 549c7e98-7fa2-4f82-8678-11f29098e96e - RECEIVED: Your job has been received and is waiting approval.
2025-02-26 16:30:19,429 549c7e98-7fa2-4f82-8678-11f29098e96e - APPROVED: Your job was approved and is waiting to be run.
2025-02-26 16:30:19,709 549c7e98-7fa2-4f82-8678-11f29098e96e - RUNNING: Your job has started running.
2025-02-26 16:30:20,948 549c7e98-7fa2-4f82-8678-11f29098e96e - COMPLETED: Your job has been completed.
Final Layer Hidden States:  tensor([[ 5.4688, -4.9062, 2.2188, ..., -3.6875, 0.9492, 1.2578],
[ 1.5469, -0.6250, -1.4531, ..., -1.1562, -0.1328, -2.1250],
[ 1.7969, -1.8828, -1.1875, ..., 0.1719, 0.9531, 0.5586],
...,
[ 0.9531, -0.3906, 1.3594, ..., 1.3984, -0.8086, -1.9297],
[-0.8906, 0.3691, 0.2578, ..., 2.4688, -0.4531, -0.6641],
[-1.6016, 1.0703, 1.7188, ..., 1.8594, -1.1328, -0.4922]],
dtype=torch.bfloat16)
Model Output Logits: tensor([[ 6.3750, 8.6250, 13.0000, ..., -4.1562, -4.1562, -4.1562],
[-2.8281, -2.2344, -3.0938, ..., -8.6250, -8.6250, -8.6250],
[ 8.9375, 3.6094, 4.5312, ..., -3.9375, -3.9375, -3.9375],
...,
[ 3.6250, 3.5000, 0.1455, ..., -6.5938, -6.5938, -6.5938],
[10.8750, 6.4062, 4.9375, ..., -3.9844, -3.9844, -3.9844],
[ 7.3125, 6.2188, 3.5781, ..., -4.7188, -4.7188, -4.7188]],
dtype=torch.bfloat16)
Model Output: Paris
What are we seeing here? NNsight logs the status of your remote job as it is received, approved, running, and completed.
Disabling remote logging notifications
If you prefer, you can disable NNsight’s remote logging notifications with the following code, although they can help troubleshoot any network issues.
from nnsight import CONFIG
CONFIG.APP.REMOTE_LOGGING = False
If you’d like to turn them back on, just set REMOTE_LOGGING = True:
from nnsight import CONFIG
CONFIG.APP.REMOTE_LOGGING = True
We are also seeing our printed results. After exiting the tracing context, NNsight downloads the saved results, which we can then operate on with ordinary Python code. Pretty simple!
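As a quick illustration of operating on the downloaded values, here is a hedged sketch that turns the saved logits from the cell above into the top five next-token candidates:

import torch

# softmax over the vocabulary at the final token position
probs = torch.softmax(output_logits[0][-1].float(), dim=-1)
top_probs, top_tokens = probs.topk(5)
for p, t in zip(top_probs, top_tokens):
    print(f"{llm.tokenizer.decode(t)!r}: {p.item():.3f}")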
Alter model internals#
Now that we’ve accessed the model’s internal layers, let’s try modifying them and see how that affects the output!
We can do this using in-place operations in NNsight, which alter the model’s state during execution. Let’s try setting every value of the output of layer 8 (index 7) to 4.
[ ]:
# remote=True means the model will execute on NDIF's shared resources
with llm.trace("The Eiffel Tower is in the city of", remote=True):
    # in-place operation that overwrites a single layer's output values
    llm.model.layers[7].output[0][:] = 4
    output = llm.output.save()

# after exiting the tracing context, we can access any values that were saved
output_logits = output["logits"]
print("Model Output Logits: ", output_logits[0])

# decode the final model output from output logits
max_probs, tokens = output_logits[0].max(dim=-1)
word = llm.tokenizer.decode(tokens.cpu()[-1])
print("Model Output: ", word)
2025-02-26 16:50:08,852 ccd16638-811d-429a-8584-a371825430db - RECEIVED: Your job has been received and is waiting approval.
2025-02-26 16:50:10,618 ccd16638-811d-429a-8584-a371825430db - APPROVED: Your job was approved and is waiting to be run.
2025-02-26 16:50:11,496 ccd16638-811d-429a-8584-a371825430db - RUNNING: Your job has started running.
2025-02-26 16:50:13,316 ccd16638-811d-429a-8584-a371825430db - COMPLETED: Your job has been completed.
Model Output Logits: tensor([[-1.4766, 0.9492, -0.9688, ..., -0.3086, -0.3086, -0.3086],
[-1.4766, 0.9492, -0.9688, ..., -0.3086, -0.3086, -0.3086],
[-1.4766, 0.9492, -0.9688, ..., -0.3086, -0.3086, -0.3086],
...,
[-1.4766, 0.9492, -0.9688, ..., -0.3086, -0.3086, -0.3086],
[-1.4766, 0.9492, -0.9688, ..., -0.3086, -0.3086, -0.3086],
[-1.4766, 0.9492, -0.9688, ..., -0.3086, -0.3086, -0.3086]],
dtype=torch.bfloat16)
Model Output: Bounty
Okay! The output for “The Eiffel Tower is in the city of” is now “Bounty”. Our intervention on the hidden states of the 8th layer changed the model’s output!
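Overwriting an entire layer with a constant is a blunt instrument. A gentler variation (a hedged sketch reusing the same assignment pattern) scales the layer’s activations instead, so you can watch the output drift rather than collapse:

# scale, rather than overwrite, layer 8's activations
with llm.trace("The Eiffel Tower is in the city of", remote=True):
    layer_out = llm.model.layers[7].output[0]
    llm.model.layers[7].output[0][:] = layer_out * 0.5  # damp by half
    output = llm.output.save()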
Are you ready for something a little more complicated? Let’s take the model’s state while answering which city the London Bridge is in, and swap it into the model’s final layer while answering the Eiffel Tower question! We can do this using NNsight’s invoking contexts, which batch different inputs into the same run through the model.
Values defined in one invoking context can be accessed from another within the same trace, allowing us to swap model states between different inputs. Let’s try it out!
[ ]:
# remote=True means the model will execute on NDIF's shared resources
with llm.trace(remote=True) as tracer:
    with tracer.invoke("The London Bridge is in the city of"):
        hidden_states = llm.model.layers[-1].output[0]  # no .save() needed

    with tracer.invoke("The Eiffel Tower is in the city of"):
        # swap in the hidden states from the London Bridge prompt
        llm.model.layers[-1].output[0][:] = hidden_states  # accessible without .save()!
        output = llm.output.save()

output_logits = output["logits"]
print("Model Output Logits: ", output_logits[0])

# decode the final model output from output logits
max_probs, tokens = output_logits[0].max(dim=-1)
word = llm.tokenizer.decode(tokens.cpu()[-1])
print("Model Output: ", word)
2025-02-26 16:59:16,986 7b5bed21-30ab-48ed-922c-a0bf6f344a85 - RECEIVED: Your job has been received and is waiting approval.
2025-02-26 16:59:18,078 7b5bed21-30ab-48ed-922c-a0bf6f344a85 - APPROVED: Your job was approved and is waiting to be run.
2025-02-26 16:59:19,011 7b5bed21-30ab-48ed-922c-a0bf6f344a85 - RUNNING: Your job has started running.
2025-02-26 16:59:21,891 7b5bed21-30ab-48ed-922c-a0bf6f344a85 - COMPLETED: Your job has been completed.
Model Output Logits: tensor([[-1.0859, -1.3203, -0.2852, ..., -0.1680, -0.1689, -0.1680],
[-0.9805, -1.2109, -0.2520, ..., -0.2461, -0.2471, -0.2461],
[ 6.3750, 8.6250, 13.0000, ..., -4.1562, -4.1562, -4.1562],
...,
[ 1.0547, 1.1641, -3.6094, ..., -5.9375, -5.9375, -5.9375],
[10.9375, 6.4062, 4.9688, ..., -4.1562, -4.1562, -4.1562],
[ 7.2500, 4.2188, 2.7812, ..., -5.1562, -5.1562, -5.1562]],
dtype=torch.bfloat16)
Model Output: London
Awesome, looks like it worked! The model output London instead of Paris when asked about the location of the Eiffel Tower.
Next steps: Run your own experiment with NDIF and NNsight#
This is just a quick overview of some of NNsight’s functionality for working with remote models. To learn more, we recommend taking a deeper dive into these resources:
📚 Get a comprehensive overview of the library with the NNsight Walkthrough
🔎 Check out some NNsight implementations of common LLM interpretability techniques
💬 Join the conversation with the NDIF Discord community
Want to scale up your research? Apply for access to Llama-3.1-405B!