vLLM Support¶
Summary¶
vLLM is a high-performance inference library that uses PagedAttention, continuous
batching, and a custom execution engine to serve HuggingFace-compatible language
models at much higher throughput than a plain transformers run. NNsight ships with
a full vLLM backend so you can observe and intervene on model internals without
giving up vLLM's speed.
This guide walks through all the ways to use the integration:
- Instantiating the VLLM wrapper (sync and async)
- Reading and writing module activations inside a trace
- Batching multiple prompts with tracer.invoke() loops
- The logits and samples properties for inspecting and modifying sampling
- Multi-token generation with tracer.all() and tracer.iter[...]
- Efficient activation extraction with tracer.cache()
- Async mode (mode="async") for streaming per-request saves
- Tensor parallelism and the Ray distributed executor for multi-GPU / multi-node
- Known limitations and gotchas
When to use¶
VLLM is the right tool when you need high-throughput generation or want to
run experiments over many prompts with the same interventions.
A few considerations before choosing VLLM over LanguageModel:
- Text generation only. NNsight supports vLLM's text-generation models (list here). Multimodal and image-generation models aren't supported yet.
- vLLM results may differ from transformers. vLLM's custom kernels, quantization defaults, and batching produce slightly different numerical outputs than transformers, so comparisons across the two backends are never exact.
- No gradients. vLLM's paged-attention path doesn't support autograd, so VLLM models can't be used for backward-pass experiments. Use LanguageModel if you need gradients.
- One prompt per invoke. Unlike LanguageModel, each tracer.invoke(...) in a vLLM trace must receive exactly one prompt. Batching is done by putting many invokes inside a single trace; vLLM's scheduler takes care of batching them on the GPU.
Setup¶
Install NNsight alongside vllm and triton. zstandard is used by the
async backend for transport compression.
pip install nnsight "vllm==0.18.0" "triton>=3.1" zstandard
from IPython.display import clear_output
from pprint import pprint
Instantiating a vLLM model¶
from nnsight.modeling.vllm import VLLM
vllm = VLLM(
"openai-community/gpt2",
tensor_parallel_size=1,
gpu_memory_utilization=0.15,
dispatch=True,
)
print(vllm)
/disk/u/jadenfk/miniconda3/envs/nn6/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
INFO 04-14 18:14:28 [parallel_state.py:1395] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://127.0.0.1:60485 backend=gloo
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
INFO 04-14 18:14:29 [model.py:533] Resolved architecture: GPT2LMHeadModel
INFO 04-14 18:14:29 [model.py:1917] Downcasting torch.float32 to torch.bfloat16.
INFO 04-14 18:14:29 [model.py:1582] Using max model len 1024
2026-04-14 18:14:29,538 INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
INFO 04-14 18:14:29 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 04-14 18:14:29 [vllm.py:754] Asynchronous scheduling is enabled.
WARNING 04-14 18:14:29 [vllm.py:788] Enforce eager set, disabling torch.compile and CUDAGraphs. This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none
WARNING 04-14 18:14:29 [vllm.py:799] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored.
INFO 04-14 18:14:29 [vllm.py:964] Cudagraph is disabled under eager mode
INFO 04-14 18:14:29 [compilation.py:289] Enabled custom fusions: norm_quant, act_quant
INFO 04-14 18:14:29 [parallel_state.py:1717] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank N/A, EPLB rank N/A
INFO 04-14 18:14:30 [cuda.py:317] Using FLASH_ATTN attention backend out of potential backends: ['FLASH_ATTN', 'FLASHINFER', 'TRITON_ATTN', 'FLEX_ATTENTION'].
INFO 04-14 18:14:30 [flash_attn.py:598] Using FlashAttention version 2
INFO 04-14 18:14:30 [utils.py:233] non-default args: {'gpu_memory_utilization': 0.15, 'disable_log_stats': True, 'enforce_eager': True, 'worker_cls': 'nnsight.modeling.vllm.workers.GPUWorker.NNsightGPUWorker', 'model': 'openai-community/gpt2'}
INFO 04-14 18:14:30 [model.py:533] Resolved architecture: GPT2LMHeadModel
INFO 04-14 18:14:30 [model.py:1917] Downcasting torch.float32 to torch.bfloat16.
INFO 04-14 18:14:30 [model.py:1582] Using max model len 1024
INFO 04-14 18:14:30 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 04-14 18:14:30 [vllm.py:788] Enforce eager set, disabling torch.compile and CUDAGraphs. This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none
WARNING 04-14 18:14:30 [vllm.py:799] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored.
INFO 04-14 18:14:30 [vllm.py:964] Cudagraph is disabled under eager mode
WARNING 04-14 18:14:31 [system_utils.py:152] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing for more information. Reasons: CUDA is initialized
(EngineCore pid=651098) INFO 04-14 18:14:36 [core.py:103] Initializing a V1 LLM engine (v0.18.0) with config: model='openai-community/gpt2', speculative_config=None, tokenizer='openai-community/gpt2', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1024, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, decode_context_parallel_size=1, dcp_comm_backend=ag_rs, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, enable_return_routed_experts=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=0, served_model_name=openai-community/gpt2, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'mode': <CompilationMode.NONE: 0>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['all'], 'splitting_ops': [], 'compile_mm_encoder': False, 'compile_sizes': [], 'compile_ranges_endpoints': [8192], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': [], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': True, 'fuse_act_quant': True, 'fuse_attn_quant': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False}, 'max_cudagraph_capture_size': 0, 'dynamic_shapes_config': {'type': <DynamicShapesType.BACKED: 'backed'>, 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []}
(EngineCore pid=651098) INFO 04-14 18:14:38 [worker_base.py:269] Injected <class 'vllm_lens._worker_ext.HiddenStatesExtension'> into <class 'nnsight.modeling.vllm.workers.GPUWorker.NNsightGPUWorker'> for extended collective_rpc calls ['clear_captured_states', 'clear_steering_data', 'get_captured_states', 'install_hooks', 'set_steering_data']
(EngineCore pid=651098) INFO 04-14 18:14:38 [parallel_state.py:1395] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://10.201.16.108:47357 backend=nccl (EngineCore pid=651098) INFO 04-14 18:14:38 [parallel_state.py:1717] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank N/A, EPLB rank N/A
(EngineCore pid=651098) INFO 04-14 18:14:39 [gpu_model_runner.py:4481] Starting to load model openai-community/gpt2...
(EngineCore pid=651098) INFO 04-14 18:14:39 [cuda.py:317] Using FLASH_ATTN attention backend out of potential backends: ['FLASH_ATTN', 'FLASHINFER', 'TRITON_ATTN', 'FLEX_ATTENTION']. (EngineCore pid=651098) INFO 04-14 18:14:39 [flash_attn.py:598] Using FlashAttention version 2
(EngineCore pid=651098) INFO 04-14 18:14:39 [weight_utils.py:618] No model.safetensors.index.json found in remote.
(EngineCore pid=651098) Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s] (EngineCore pid=651098) Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 5.28it/s] (EngineCore pid=651098) Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 5.27it/s] (EngineCore pid=651098)
(EngineCore pid=651098) INFO 04-14 18:14:40 [default_loader.py:384] Loading weights took 0.21 seconds
(EngineCore pid=651098) INFO 04-14 18:14:40 [gpu_model_runner.py:4566] Model loading took 0.24 GiB memory and 0.799149 seconds
(EngineCore pid=651098) INFO 04-14 18:14:42 [gpu_worker.py:456] Available KV cache memory: 10.99 GiB (EngineCore pid=651098) INFO 04-14 18:14:42 [kv_cache_utils.py:1316] GPU KV cache size: 320,112 tokens (EngineCore pid=651098) INFO 04-14 18:14:42 [kv_cache_utils.py:1321] Maximum concurrency for 1,024 tokens per request: 312.61x (EngineCore pid=651098) INFO 04-14 18:14:42 [core.py:281] init engine (profile, create kv cache, warmup model) took 1.42 seconds
(EngineCore pid=651098) INFO 04-14 18:14:42 [vllm.py:754] Asynchronous scheduling is enabled. (EngineCore pid=651098) WARNING 04-14 18:14:42 [vllm.py:788] Enforce eager set, disabling torch.compile and CUDAGraphs. This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none (EngineCore pid=651098) WARNING 04-14 18:14:42 [vllm.py:799] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored. (EngineCore pid=651098) INFO 04-14 18:14:42 [vllm.py:964] Cudagraph is disabled under eager mode (EngineCore pid=651098) INFO 04-14 18:14:42 [compilation.py:289] Enabled custom fusions: norm_quant, act_quant INFO 04-14 18:14:42 [llm.py:391] Supported tasks: ['generate']
GPT2LMHeadModel(
(transformer): GPT2Model(
(wte): VocabParallelEmbedding(num_embeddings=50304, embedding_dim=768, org_vocab_size=50257, num_embeddings_padded=50304, tp_size=1)
(wpe): Embedding(1024, 768)
(h): ModuleList(
(0-11): 12 x GPT2Block(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): GPT2Attention(
(c_attn): QKVParallelLinear(in_features=768, output_features=2304, bias=True, tp_size=1, gather_output=False)
(c_proj): RowParallelLinear(in_features=768, output_features=768, bias=True, tp_size=1, reduce_results=True)
(attn): Attention(head_size=64, num_heads=12, num_kv_heads=12, scale=0.125, backend=FlashAttentionImpl)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): ColumnParallelLinear(in_features=768, output_features=3072, bias=True, tp_size=1, gather_output=False)
(c_proj): RowParallelLinear(in_features=3072, output_features=768, bias=True, tp_size=1, reduce_results=True)
(act): NewGELU()
)
)
)
(ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(lm_head): ParallelLMHead(num_embeddings=50304, embedding_dim=768, org_vocab_size=50257, num_embeddings_padded=50304, tp_size=1)
(logits_processor): LogitsProcessor(vocab_size=50257, org_vocab_size=50257, scale=1.0, logits_as_input=False)
(logits): Logits
(samples): Sampled token ids
)
The wrapper looks and behaves like any other NNsight model: the underlying
vLLM engine is reachable via vllm.vllm_entrypoint, and the module envoys
(vllm.transformer.h[...], vllm.lm_head, etc.) are what you intervene on.
Two wrapper-only attributes are worth knowing about upfront:
- vllm.logits: the final logits tensor before sampling
- vllm.samples: the sampled token ids after temperature / top-p
We'll use both later on.
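As a quick preview, both can be read and saved directly inside a trace. A minimal sketch (with temperature=0.0 the sampled id should equal the logits argmax):
with vllm.trace("The capital of France is", temperature=0.0, top_p=1) as tracer:
    logits = vllm.logits.save()    # pre-sampling logits for this step
    sample = vllm.samples.save()   # token id the sampler actually chose
print("logits shape:", tuple(logits.shape))
print("sampled token:", repr(vllm.tokenizer.decode(sample)))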
Basic interventions¶
You can read or write any module's inputs/outputs inside a model.trace(...)
block, just like with LanguageModel:
prompt = "The Eiffel Tower is in the city of"
with vllm.trace(prompt, temperature=0.0, top_p=1) as tracer:
# Read: middle-layer MLP output
mlp_out = vllm.transformer.h[6].mlp.output.save()
# Write: zero out the last hidden state going into lm_head. vLLM runs
# under InferenceMode so tensors returned from modules are read-only;
# clone first, mutate, then assign the whole tensor back.
ln_out = vllm.transformer.ln_f.output.clone()
ln_out[-1, :] = 0
vllm.transformer.ln_f.output = ln_out
# Grab the first predicted token
next_token = vllm.logits.argmax(dim=-1).save()
print("mlp_out shape:", mlp_out.shape)
print("predicted token:", repr(vllm.tokenizer.decode(next_token)))
Rendering prompts: 0%| | 0/1 [00:00<?, ?it/s]
Rendering prompts: 100%|██████████| 1/1 [00:00<00:00, 51.74it/s]
Processed prompts: 0%| | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 2.32it/s, est. speed input: 23.15 toks/s, output: 37.04 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 2.32it/s, est. speed input: 23.15 toks/s, output: 37.04 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 2.31it/s, est. speed input: 23.15 toks/s, output: 37.04 toks/s]
mlp_out shape: torch.Size([10, 768]) predicted token: '!'
A few things to notice:
- Flat token layout. vLLM concatenates all tokens from all prompts into a single [total_tokens, hidden] tensor; there's no batch dimension. That's why the slice above is [-1, :] and not [:, -1, :] like you'd write for LanguageModel. NNsight narrows the right slice for you automatically when you're inside a single-prompt invoke, so you still work in "per-request" coordinates.
- Clone before writing. vLLM executes inside torch.inference_mode(), so module outputs are read-only. Grab a clone, mutate it, and assign the whole tensor back; NNsight will pick up the assigned value and feed it to the next layer.
- Save works the same. .save() (or nns.save(...)) marks a value to be transported back to your process once generation finishes. Anything not saved is discarded.
Batching multiple prompts with invoke loops¶
With LanguageModel you can pass a list of prompts to a single invoke. With
vLLM you can't — each invoke is one request. Instead you put a loop of
tracer.invoke(...) calls inside a single trace() context:
prompts = [
"The Eiffel Tower is in the city of",
"Madison Square Garden is in the city of",
"The Colosseum is in the city of",
]
with vllm.trace(temperature=0.0, top_p=1) as tracer:
predictions = list().save()
for prompt in prompts:
with tracer.invoke(prompt):
token_id = vllm.logits.argmax(dim=-1)
predictions.append(vllm.tokenizer.decode(token_id))
for prompt, pred in zip(prompts, predictions):
print(f"{prompt!r:<46} → {pred!r}")
Rendering prompts: 0%| | 0/3 [00:00<?, ?it/s]
Rendering prompts: 100%|██████████| 3/3 [00:00<00:00, 2766.08it/s]
Processed prompts: 0%| | 0/3 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 33%|███▎ | 1/3 [00:00<00:00, 9.58it/s, est. speed input: 95.83 toks/s, output: 153.31 toks/s]
Processed prompts: 100%|██████████| 3/3 [00:00<00:00, 9.58it/s, est. speed input: 266.24 toks/s, output: 456.39 toks/s]
Processed prompts: 100%|██████████| 3/3 [00:00<00:00, 28.44it/s, est. speed input: 266.24 toks/s, output: 456.39 toks/s]
'The Eiffel Tower is in the city of' → ' Paris' 'Madison Square Garden is in the city of' → ' New' 'The Colosseum is in the city of' → ' T'
Every invoke runs its own intervention code, but vLLM batches the underlying
forward passes for efficiency. You can pass different sampling params per
invoke too (e.g. tracer.invoke(prompt, temperature=0.8)).
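For example, here is a minimal sketch that mixes sampling params across invokes in one batched trace (prompts and values are illustrative):
with vllm.trace(temperature=0.0, top_p=1) as tracer:
    decoded = list().save()
    # This invoke uses the trace-level greedy params...
    with tracer.invoke("The capital of France is"):
        decoded.append(vllm.tokenizer.decode(vllm.logits.argmax(dim=-1)))
    # ...while this one overrides them with a high temperature.
    with tracer.invoke("The capital of France is", temperature=1.2, top_p=0.9):
        decoded.append(vllm.tokenizer.decode(vllm.samples))
print(decoded)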
Collecting results across invokes¶
Variables defined at trace scope are shared across every invoke. That makes it easy to build up a single data structure from many prompts:
prompts = [
"The Eiffel Tower is in",
"Madison Square Garden is in",
"The Colosseum is in",
]
with vllm.trace(temperature=0.0, top_p=1, max_tokens=3) as tracer:
# Shared, trace-scope state — .save() on the container itself.
all_tokens = [list() for _ in range(len(prompts))].save()
for i, prompt in enumerate(prompts):
with tracer.invoke(prompt):
# tracer.all() fires on every generation step.
with tracer.all():
all_tokens[i].append(vllm.samples.item())
for prompt, toks in zip(prompts, all_tokens):
print(f"{prompt!r:<32} → {vllm.tokenizer.decode(toks)!r}")
Rendering prompts: 0%| | 0/3 [00:00<?, ?it/s]
Rendering prompts: 100%|██████████| 3/3 [00:00<00:00, 3239.68it/s]
Processed prompts: 0%| | 0/3 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 3/3 [00:00<00:00, 98.45it/s, est. speed input: 623.81 toks/s, output: 295.44 toks/s]
Processed prompts: 100%|██████████| 3/3 [00:00<00:00, 97.47it/s, est. speed input: 623.81 toks/s, output: 295.44 toks/s]
'The Eiffel Tower is in' → ' the middle of' 'Madison Square Garden is in' → ' the heart of' 'The Colosseum is in' → ' the middle of'
Key points:
- .save() the container, not the list contents. The shared all_tokens list is saved once at trace scope; each invoke mutates it in place.
- Per-invoke sampling params override the trace-level ones. You can do tracer.invoke(prompt, temperature=0.8) inside a trace that was opened with temperature=0.0.
Accessing logits and sampled tokens¶
vllm.logits and vllm.samples are synthesized envoys that sit at the end
of the pipeline:
- vllm.logits: the pre-sampling logits for the current generation step (shape [vocab_size] per invoke)
- vllm.samples: the scalar token id that was actually sampled
with vllm.trace(
"Madison Square Garden is located in",
temperature=0.8,
top_p=0.95,
max_tokens=3,
) as tracer:
step_logits = list().save()
step_samples = list().save()
with tracer.all():
step_logits.append(vllm.logits)
step_samples.append(vllm.samples.item())
for i, (l, s) in enumerate(zip(step_logits, step_samples)):
print(f"step {i}: top-1 via argmax={l.argmax().item():5d} sampled={s:5d}"
f" ({vllm.tokenizer.decode(s)!r})")
Rendering prompts: 0%| | 0/1 [00:00<?, ?it/s]
Rendering prompts: 100%|██████████| 1/1 [00:00<00:00, 2045.00it/s]
Processed prompts: 0%| | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 26.90it/s, est. speed input: 161.49 toks/s, output: 80.72 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 26.64it/s, est. speed input: 161.49 toks/s, output: 80.72 toks/s]
step 0: top-1 via argmax= 262 sampled= 262 (' the')
step 1: top-1 via argmax= 2612 sampled= 2612 (' heart')
step 2: top-1 via argmax= 286 sampled= 286 (' of')
Because we're sampling with temperature=0.8, the sampled token often
doesn't match argmax(logits). You can also write into vllm.logits
or vllm.samples to force specific sampling decisions.
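A hedged sketch of the write path: bump a chosen token's logit so that, under greedy sampling, it becomes the output. This assumes " Paris" encodes to a single GPT-2 token; check with the tokenizer for other strings.
paris_id = vllm.tokenizer.encode(" Paris")[0]
with vllm.trace("Madison Square Garden is located in the city of",
                temperature=0.0, top_p=1) as tracer:
    # Clone-mutate-assign, the same pattern as any other output under inference mode.
    logits = vllm.logits.clone()
    logits[..., paris_id] = logits.max() + 10.0   # make " Paris" the argmax
    vllm.logits = logits
    forced = vllm.samples.save()
print(repr(vllm.tokenizer.decode(forced)))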
Multi-token generation with .all() and .iter[...]¶
tracer.all() applies its body to every generation step:
with vllm.trace("Hello world", max_tokens=5) as tracer:
sampled = list().save()
with tracer.all():
sampled.append(vllm.samples.item())
print("tokens:", sampled)
print("decoded:", vllm.tokenizer.decode(sampled))
Rendering prompts: 0%| | 0/1 [00:00<?, ?it/s]
Rendering prompts: 100%|██████████| 1/1 [00:00<00:00, 2048.00it/s]
Processed prompts: 0%| | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 35.47it/s, est. speed input: 71.00 toks/s, output: 177.44 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 35.10it/s, est. speed input: 71.00 toks/s, output: 177.44 toks/s]
tokens: [2644, 33011, 26842, 756, 404] decoded: ...olla rayontop
tracer.iter[slice] runs its body only on specific generation steps. This is
useful when you want to intervene on (or observe) a window of tokens:
prompt = "The Eiffel Tower is in the city of"
mlp = vllm.transformer.h[6].mlp
with vllm.trace(prompt, max_tokens=6) as tracer:
hidden_states = list().save()
    # Zero out the MLP output on steps 2, 3, and 4 (the slice 2:5 is inclusive-exclusive).
with tracer.iter[2:5]:
masked = mlp.output.clone()
masked[-1] = 0
mlp.output = masked
hidden_states.append(mlp.output)
print(f"captured {len(hidden_states)} steps")
Rendering prompts: 0%| | 0/1 [00:00<?, ?it/s]
Rendering prompts: 100%|██████████| 1/1 [00:00<00:00, 2175.47it/s]
Processed prompts: 0%| | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 30.46it/s, est. speed input: 304.87 toks/s, output: 182.85 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 30.17it/s, est. speed input: 304.87 toks/s, output: 182.85 toks/s]
captured 3 steps
Inside tracer.all() / tracer.iter[...] you can both read (append to a
saved list) and write (mutate .output). Writes apply only on the steps
the iterator selects.
Efficient activation extraction with tracer.cache()¶
The patterns above all involve explicitly appending values to a saved list.
For dense per-layer activation capture — e.g. "give me the residual stream at
layer 6 for every token of every request" — NNsight provides
tracer.cache(), which registers a persistent hook on each target module and
collects activations into a CacheDict keyed by module path.
target_layer = vllm.transformer.h[6]
with vllm.trace(
"The Eiffel Tower is in the city of",
temperature=0.0,
max_tokens=8,
) as tracer:
cache = tracer.cache(modules=[target_layer]).save()
import torch
print("cached module paths:", list(cache.keys()))
# The CacheDict is keyed by Envoy path. For GPT-2 under vLLM that's
# "model.transformer.h.6" (the "model" prefix is vLLM's wrapper namespace).
# When a module fires multiple times (prefill + each decode step) the entry
# is a list of Entry objects — concatenate their hidden states to get every
# captured token in one tensor.
entry = cache[next(iter(cache.keys()))]
entries = entry if isinstance(entry, list) else [entry]
def hidden_states(e):
out = e.output
return out[0] if isinstance(out, tuple) else out
all_hs = torch.cat([hidden_states(e) for e in entries], dim=0)
print(f"captured {len(entries)} forward passes")
print(f"total hidden-state shape: {tuple(all_hs.shape)} "
f"(prefill tokens + one row per decode step)")
Rendering prompts: 0%| | 0/1 [00:00<?, ?it/s]
Rendering prompts: 100%|██████████| 1/1 [00:00<00:00, 2141.04it/s]
Processed prompts: 0%| | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 26.22it/s, est. speed input: 262.36 toks/s, output: 209.83 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 25.99it/s, est. speed input: 262.36 toks/s, output: 209.83 toks/s]
cached module paths: ['model.transformer.h.6'] captured 8 forward passes total hidden-state shape: (17, 768) (prefill tokens + one row per decode step)
A few things to know about tracer.cache():
- Multiple modules. Pass a list (modules=[layer_a, layer_b]) or leave it out to cache every module in the model.
- Inputs too, if you want them. tracer.cache(include_inputs=True) captures each module's (args, kwargs) alongside its output.
- Device / dtype controls. Pass device=torch.device("cpu") or dtype=torch.float32 to move/cast the captured tensors as they're collected. CPU is the default. (See the sketch below for these options in one call.)
- Per-request transport. Inside the async backend, the cache is zstd-compressed and pickled per request instead of in one giant blob, which matters a lot at high concurrency.
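A small sketch combining those options (the layer choices and kwarg values are arbitrary):
import torch

with vllm.trace("The Eiffel Tower is in", temperature=0.0, max_tokens=4) as tracer:
    cache = tracer.cache(
        modules=[vllm.transformer.h[3], vllm.transformer.h[9]],  # two target layers
        include_inputs=True,             # also capture each module's (args, kwargs)
        device=torch.device("cpu"),      # move captured tensors off the GPU
        dtype=torch.float32,             # cast as they're collected
    ).save()
print("cached module paths:", list(cache.keys()))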
tracer.cache() is the API the
nnsight-vllm-lens-comparison
benchmark uses to extract residual streams at scale.
Async mode (mode="async")¶
Pass mode="async" to VLLM(...) to load the model against vLLM's
AsyncLLM engine instead of the sync LLM. The trace-writing API is the
same, but the result is an async generator that streams RequestOutput
objects as tokens land.
import nnsight as nns
from nnsight.modeling.vllm import VLLM
async_vllm = VLLM(
"openai-community/gpt2",
tensor_parallel_size=1,
gpu_memory_utilization=0.15,
dispatch=True,
mode="async",
)
async def run_one(prompt: str):
with async_vllm.trace(prompt, temperature=0.0, max_tokens=5) as tracer:
cache = tracer.cache(modules=[async_vllm.transformer.h[6]])
nns.save(cache)
# tracer.backend is an AsyncVLLMBackend instance you iterate as an
# async generator. Each yielded output is a vLLM RequestOutput with
# a .saves dict attached once the request finishes.
final = None
async for output in tracer.backend:
if output.finished:
final = output
return final
# Jupyter already runs an asyncio loop — just `await` directly. In a plain
# Python script you'd wrap this in `asyncio.run(run_one(...))`.
final = await run_one("The Eiffel Tower is in")
print("finished:", final.finished)
print("decoded:", final.outputs[0].text)
print("saves:", list(final.saves.keys()))
INFO 04-14 18:14:43 [parallel_state.py:1395] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://127.0.0.1:48869 backend=gloo
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
INFO 04-14 18:14:43 [model.py:533] Resolved architecture: GPT2LMHeadModel
INFO 04-14 18:14:43 [model.py:1917] Downcasting torch.float32 to torch.bfloat16.
INFO 04-14 18:14:43 [model.py:1582] Using max model len 1024
INFO 04-14 18:14:43 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
WARNING 04-14 18:14:43 [vllm.py:788] Enforce eager set, disabling torch.compile and CUDAGraphs. This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none
WARNING 04-14 18:14:43 [vllm.py:799] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored.
INFO 04-14 18:14:43 [vllm.py:964] Cudagraph is disabled under eager mode
INFO 04-14 18:14:43 [model.py:533] Resolved architecture: GPT2LMHeadModel
INFO 04-14 18:14:44 [model.py:1917] Downcasting torch.float32 to torch.bfloat16.
INFO 04-14 18:14:44 [model.py:1582] Using max model len 1024
INFO 04-14 18:14:44 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
WARNING 04-14 18:14:44 [vllm.py:788] Enforce eager set, disabling torch.compile and CUDAGraphs. This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none
WARNING 04-14 18:14:44 [vllm.py:799] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored.
INFO 04-14 18:14:44 [vllm.py:964] Cudagraph is disabled under eager mode
(EngineCore pid=652060) INFO 04-14 18:14:49 [core.py:103] Initializing a V1 LLM engine (v0.18.0) with config: model='openai-community/gpt2', speculative_config=None, tokenizer='openai-community/gpt2', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1024, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, decode_context_parallel_size=1, dcp_comm_backend=ag_rs, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, enable_return_routed_experts=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=0, served_model_name=openai-community/gpt2, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'mode': <CompilationMode.NONE: 0>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['all'], 'splitting_ops': [], 'compile_mm_encoder': False, 'compile_sizes': [], 'compile_ranges_endpoints': [2048], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': [], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': True, 'fuse_act_quant': True, 'fuse_attn_quant': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False}, 'max_cudagraph_capture_size': 0, 'dynamic_shapes_config': {'type': <DynamicShapesType.BACKED: 'backed'>, 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []}
(EngineCore pid=652060) INFO 04-14 18:14:51 [worker_base.py:269] Injected <class 'vllm_lens._worker_ext.HiddenStatesExtension'> into <class 'nnsight.modeling.vllm.workers.GPUWorker.NNsightGPUWorker'> for extended collective_rpc calls ['clear_captured_states', 'clear_steering_data', 'get_captured_states', 'install_hooks', 'set_steering_data'] (EngineCore pid=652060) INFO 04-14 18:14:51 [parallel_state.py:1395] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://10.201.16.108:53515 backend=nccl
(EngineCore pid=652060) INFO 04-14 18:14:51 [parallel_state.py:1717] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank N/A, EPLB rank N/A
(EngineCore pid=652060) INFO 04-14 18:14:52 [gpu_model_runner.py:4481] Starting to load model openai-community/gpt2...
(EngineCore pid=652060) INFO 04-14 18:14:52 [cuda.py:317] Using FLASH_ATTN attention backend out of potential backends: ['FLASH_ATTN', 'FLASHINFER', 'TRITON_ATTN', 'FLEX_ATTENTION']. (EngineCore pid=652060) INFO 04-14 18:14:52 [flash_attn.py:598] Using FlashAttention version 2
(EngineCore pid=652060) INFO 04-14 18:14:52 [weight_utils.py:618] No model.safetensors.index.json found in remote. (EngineCore pid=652060) INFO 04-14 18:14:52 [default_loader.py:384] Loading weights took 0.19 seconds
(EngineCore pid=652060) Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s] (EngineCore pid=652060) Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 5.96it/s] (EngineCore pid=652060) Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 5.95it/s] (EngineCore pid=652060)
(EngineCore pid=652060) INFO 04-14 18:14:53 [gpu_model_runner.py:4566] Model loading took 0.24 GiB memory and 0.756701 seconds
(EngineCore pid=652060) INFO 04-14 18:14:55 [gpu_worker.py:456] Available KV cache memory: 11.09 GiB (EngineCore pid=652060) INFO 04-14 18:14:55 [kv_cache_utils.py:1316] GPU KV cache size: 322,864 tokens (EngineCore pid=652060) INFO 04-14 18:14:55 [kv_cache_utils.py:1321] Maximum concurrency for 1,024 tokens per request: 315.30x (EngineCore pid=652060) INFO 04-14 18:14:55 [core.py:281] init engine (profile, create kv cache, warmup model) took 1.40 seconds
(EngineCore pid=652060) INFO 04-14 18:14:55 [vllm.py:754] Asynchronous scheduling is enabled. (EngineCore pid=652060) WARNING 04-14 18:14:55 [vllm.py:788] Enforce eager set, disabling torch.compile and CUDAGraphs. This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none (EngineCore pid=652060) WARNING 04-14 18:14:55 [vllm.py:799] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored. (EngineCore pid=652060) INFO 04-14 18:14:55 [vllm.py:964] Cudagraph is disabled under eager mode (EngineCore pid=652060) INFO 04-14 18:14:55 [compilation.py:289] Enabled custom fusions: norm_quant, act_quant WARNING 04-14 18:14:55 [input_processor.py:227] Passing raw prompts to InputProcessor is deprecated and will be removed in v0.18. You should instead pass the outputs of Renderer.render_cmpl() or Renderer.render_chat().
finished: True decoded: the middle of the city saves: ['cache']
The async path is the right choice when you're firing many requests
concurrently — e.g. using asyncio.gather to submit a whole dataset at once.
vLLM's scheduler forms multi-request batches, and each request's saves are
transported back independently (zstd-compressed) as it finishes, rather than
waiting on one giant pickle at the end.
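A sketch of that pattern, reusing the run_one coroutine defined above:
import asyncio

prompts = [
    "The Eiffel Tower is in",
    "Madison Square Garden is in",
    "The Colosseum is in",
]
# In a notebook, await directly; in a plain script, wrap this in asyncio.run(...).
finals = await asyncio.gather(*(run_one(p) for p in prompts))
for prompt, final in zip(prompts, finals):
    print(f"{prompt!r} → {final.outputs[0].text!r}")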
Sync vs. async at a glance:
| | Sync (mode="sync", default) | Async (mode="async") |
|---|---|---|
| Engine class | vllm.LLM | vllm.v1.engine.async_llm.AsyncLLM |
| Best for | Notebook experiments, small batches | High-throughput concurrent workloads |
| Trace API | with model.trace(...) as t: | same, but iterate t.backend after |
| Saves transport | One pickle per finished request | zstd-compressed per-request saves on every streamed output |
| Multi-token intervention | tracer.iter[...], tracer.all() | same |
Tensor parallelism and multi-GPU¶
VLLM(..., tensor_parallel_size=N) shards the model across N GPUs on a
single node. NNsight handles the sharded-tensor semantics automatically:
intervention code always sees the full gathered tensor (not a shard),
and writes are re-sharded before being passed back to vLLM. From your
perspective as a user, vllm.trace(...) looks identical to the tp=1
case — you just see the full hidden-state tensor.
The code below isn't executed inline because this notebook runs on a single
GPU, but the pattern is drop-in: if you have two GPUs visible via
CUDA_VISIBLE_DEVICES, just set tensor_parallel_size=2.
from nnsight.modeling.vllm import VLLM
# 2-way tensor parallelism on a single machine
vllm_tp2 = VLLM(
"facebook/opt-6.7b",
tensor_parallel_size=2,
gpu_memory_utilization=0.85,
dispatch=True,
)
# Use it exactly like a tp=1 model — NNsight gathers sharded tensors
# before your intervention code runs and re-shards after.
with vllm_tp2.trace("Hello world", max_tokens=5) as tracer:
mid = vllm_tp2.model.decoder.layers[16].output.save()
Under the hood, NNsight's VLLMBatcher registers pre/post hooks on every
parallel linear layer (ColumnParallelLinear, RowParallelLinear). On entry
to a sharded layer it gathers the sharded input/output with
tensor_model_parallel_all_gather, runs the intervention code against the
full tensor on every rank, and on exit re-splits the tensor so vLLM's
forward pass sees the same shape it would have without the hooks. Every
GPU runs the same intervention on the same complete tensor.
Multi-node with Ray¶
For tensor parallelism across multiple machines, switch the executor to Ray:
vllm_multinode = VLLM(
"meta-llama/Llama-3.1-70B",
tensor_parallel_size=8,
distributed_executor_backend="ray",
dispatch=True,
)
Before instantiation, point RAY_ADDRESS at an existing Ray cluster's GCS
address (e.g. RAY_ADDRESS=head-node:6379, not ray://host:port).
NNsight's NNsightRayExecutor joins the cluster as a driver-only node and
places workers on the available machines. You don't need to change any of
your tracing code — interventions run identically against a multi-node Ray
executor.
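For example (the address is a placeholder for your cluster's head node):
import os

# GCS address of the Ray head node, not a ray:// client URI.
os.environ["RAY_ADDRESS"] = "head-node:6379"
# ...then construct the VLLM model as in the snippet above; tracing code is unchanged.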
A full multi-node Docker example, including the Dockerfile and cluster
test harness, lives at
nnsight/src/nnsight/modeling/vllm/examples/multi_node_with_ray/.
Limitations and gotchas¶
A short list of things that surprise people:
- No gradients. vLLM's paged-attention kernels don't retain a computation graph, so .backward(), .grad, and any gradient-based operations aren't supported on VLLM models. Use LanguageModel for those.
- vLLM ≠ transformers numerically. Fused kernels, different attention implementations, quantization defaults, and temperature-0 argmax being deterministic within an engine (but not across engines at different batch sizes) all mean you shouldn't expect exact-match outputs between VLLM and LanguageModel for the same input. The interventions are correct in both; the baselines just aren't identical.
- One prompt per invoke. As noted above, tracer.invoke(prompt) takes exactly one string / token-id list. To batch, loop invokes inside a single trace; vLLM handles the batching.
- tracer.cache() under preempt + recompute. When thousands of concurrent requests exceed the KV-cache budget, vLLM evicts some in-flight requests and re-runs their full prefill on resume. The cache hook fires for both the original run and the recompute, so those requests end up with ~2× their expected token rows in the CacheDict. If you're running at a scale where preemption can happen (check for the engine's Preempted log lines), provision enough gpu_memory_utilization / max_num_seqs to keep everything resident (see the sketch after this list).
- Async engine requires asyncio. mode="async" traces can only be iterated inside a coroutine. In a notebook, use await directly (Jupyter runs an event loop for you) or asyncio.run(...) from a plain Python script.
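A hedged sketch of the kind of headroom settings that help, assuming the wrapper forwards these engine args to vLLM (the values are illustrative and depend on your hardware and workload):
from nnsight.modeling.vllm import VLLM

vllm_headroom = VLLM(
    "openai-community/gpt2",
    gpu_memory_utilization=0.90,   # larger KV-cache budget
    max_num_seqs=128,              # cap how many sequences are scheduled concurrently
    dispatch=True,
)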
See also¶
Other feature guides:
- Tracer fundamentals: how traces, invokes, and .save() work in general
- Multi-token generation: deeper coverage of tracer.all() / tracer.iter[...]
End-to-end examples:
- ndif-team/nnsight-vllm-demos: runnable demos of the vLLM integration, including a streaming chat UI and a Llama-Scope-SAE steering example that modifies activations mid-generation behind a real chat server
- JadenFiotto-Kaufman/nnsight-vllm-lens-comparison: reproducible benchmark pitting tracer.cache() against the vllm-lens plugin on the same OPT workload (OPT-30B, 1000 Alpaca prompts, 1024 max tokens, TP=2); includes a validate.py that bitwise-compares the extracted activations
Contributor-facing internals:
- ndif-team/nnsight: source for the full vLLM integration
- src/nnsight/modeling/vllm/README.md: architecture document covering mediator transport via extra_args, batch-group management (flat-token vs. prompt-level), the three interleaving phases (forward / logits / sampling), tensor-parallel gather/reshard, the Ray distributed executor, and the async engine. Read this if you're modifying the integration or trying to understand why a specific intervention behaves the way it does inside vLLM.
- src/nnsight/modeling/vllm/examples/multi_node_with_ray/: Docker-based multi-node Ray example with a test harness for verifying TP across machines