vLLM Support#

Summary#

vLLM is a popular library for fast LLM inference. By leveraging PagedAttention, continuous batching, and Hugging Face model integration, vLLM makes inference more efficient and scalable for real-world applications.

Starting with version 0.5.13, NNsight supports investigations of models run with vLLM.

# Instantiating a vLLM model
from nnsight.modeling.vllm import VLLM

vllm = VLLM("model_ID")

When to Use#

vLLM is useful for performance speed-ups, particularly for experiments with multiple batches or generations.
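For example, several prompts can be batched into a single trace. The sketch below assumes the vllm model instantiated above and that the tracer.invoke batching pattern from LanguageModel carries over to the vLLM wrapper; the prompts are only illustrative.

# A minimal sketch of batching multiple prompts into one trace.
# Assumes the `vllm` model instantiated above and that tracer.invoke
# behaves as it does for LanguageModel.
prompts = [
    "The Eiffel Tower is in the city of",
    "Madison Square Garden is located in the city of",
]

with vllm.trace() as tracer:
    outputs = list().save()
    for prompt in prompts:
        # Each invoke adds one prompt to the same batched run.
        with tracer.invoke(prompt):
            outputs.append(vllm.output)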

A few considerations when choosing to use vLLM for your experiments:

  • NNsight supports vLLM text-generation models. You can find a list of supported models here. Our support currently does not extend to multimodal or image generation models.

  • Be aware that vLLM results may differ from those of the base Hugging Face Transformers model, even for the same experiment.

  • Note that vLLM models do not support gradients; if you want to research gradient-based methods, use LanguageModel instead.

More info:

vLLM speeds up inference with its PagedAttention mechanism and is built for inference only, so gradient access and backward passes are not supported for vLLM models. As such, calling gradient operations on the nnsight vLLM wrapper will throw an error.
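If you are unsure whether part of a trace relies on gradients, one way to check is to attempt the operation and catch the resulting error. The snippet below is a minimal sketch, assuming a vllm model instantiated as above; the exact exception type and message depend on your nnsight and vllm versions.

# Minimal sketch: backward passes are unsupported on vLLM models, so this
# is expected to raise an error (the exact exception varies by version).
# Assumes `vllm` was instantiated as shown above.
try:
    with vllm.trace("The truth is the"):
        hidden = vllm.model.layers[0].mlp.output.save()
        hidden.sum().backward()  # gradient operation, not supported with vLLM
except Exception as error:
    print(f"Gradients are unavailable with vLLM: {error}")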

How to Use#

Setup#

To use vLLM with NNsight, you will need nnsight >= 0.5.13, vllm >= 0.12, and triton==3.5.0.

[1]:
from IPython.display import clear_output
from pprint import pprint

%pip install -U nnsight triton==3.5.0 vllm==0.13.0 numpy==2.2.4

clear_output()

Instantiating vLLM Models#

Next, let’s load in our NNsight vLLM model (list of vLLM-supported models & their IDs here).

For this exercise, we will use meta-llama/Llama-3.1-8B. Note that Meta gates access to Llama models on Hugging Face, so you will need an HF_TOKEN with approved access, or you can use an ungated model instead.

vLLM models require a supported GPU/backend to run.
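Before loading, you can optionally confirm that a CUDA device is visible and that you are authenticated with Hugging Face. A minimal sketch, assuming your token is stored in the HF_TOKEN environment variable:

# Optional sanity checks before loading a gated model on GPU.
# Assumes your Hugging Face token is stored in the HF_TOKEN environment variable.
import os

import torch
from huggingface_hub import login

assert torch.cuda.is_available(), "No CUDA device visible to PyTorch."
login(token=os.environ["HF_TOKEN"])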

[ ]:
from nnsight.modeling.vllm import VLLM

# vLLM supports explicit parallelism
vllm = VLLM("meta-llama/Llama-3.1-8B", dispatch=True, tensor_parallel_size=1, gpu_memory_utilization=0.8)

clear_output()

print(vllm)
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (qkv_proj): QKVParallelLinear(in_features=4096, output_features=6144, bias=False, tp_size=1, gather_output=False)
          (o_proj): RowParallelLinear(in_features=4096, output_features=4096, bias=False, tp_size=1, reduce_results=True)
          (rotary_emb): Llama3RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=131072, base=500000.0, is_neox_style=True)
          (attn): Attention(head_size=128, num_heads=32, num_kv_heads=8, scale=0.08838834764831845, backend=FlashAttentionImpl)
        )
        (mlp): LlamaMLP(
          (gate_up_proj): MergedColumnParallelLinear(in_features=4096, output_features=28672, bias=False, tp_size=1, gather_output=False)
          (down_proj): RowParallelLinear(in_features=14336, output_features=4096, bias=False, tp_size=1, reduce_results=True)
          (act_fn): SiluAndMul()
        )
        (input_layernorm): RMSNorm(hidden_size=4096, eps=1e-05)
        (post_attention_layernorm): RMSNorm(hidden_size=4096, eps=1e-05)
      )
    )
    (norm): RMSNorm(hidden_size=4096, eps=1e-05)
  )
  (lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)
  (logits_processor): LogitsProcessor(vocab_size=128256, org_vocab_size=128256, scale=1.0, logits_as_input=False)
  (logits): WrapperModule()
  (samples): WrapperModule()
  (generator): WrapperModule()
)

We now have a vLLM model that runs with nnsight.

Interventions on vLLM models#

You can access and intervene on model internals in vLLM models just like you do for LanguageModel models through nnsight’s get and set operations.
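For example, saving ("getting") a hidden state from the vLLM model looks the same as it does with LanguageModel. A minimal sketch, assuming the vllm model loaded above:

# Minimal "get" example: save an activation from the vLLM model.
# Assumes the `vllm` model loaded above.
with vllm.trace("The Eiffel Tower is in the city of"):
    hidden_state = vllm.model.layers[16].mlp.down_proj.input.save()

print(hidden_state.shape)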

⚠️ Note: As mentioned earlier, because of differences in vLLM inference settings and other implementation details, results may differ from Transformers, even for the same intervention!

Let’s load a LanguageModel instance of the same model so we can compare the two backends. Here, we’re loading Llama-3.1-8B and intervening on identified antonym neurons, which should change the output to the antonym of the expected completion.

[ ]:
# Use the HuggingFace transformers backend for comparison
from nnsight import LanguageModel

neurons = [394, 5490, 8929]
prompt = "The truth is the"

# Use CUDA_VISIBLE_DEVICES in your env, not tensor_parallel_size
lm = LanguageModel("meta-llama/Llama-3.1-8B", dispatch=True, device_map="auto")
mlp = lm.model.layers[16].mlp.down_proj

with lm.trace(prompt):
    mlp.input[:, -1, neurons] = 10                # batch dimension
    out = lm.output.save()
    last = out["logits"][:, -1].argmax()          # dict of tensors
    prediction = lm.tokenizer.decode(last).save()

print(f"Prediction with transformers: '{prediction}'")
Prediction with transformers: ' lie'

Great, the antonym neurons appeared to do their job.

Now, let’s intervene on the same neurons for the vLLM model and see how the result changes.

[ ]:
neurons = [394, 5490, 8929]
prompt = "The truth is the"

mlp = vllm.model.layers[16].mlp.down_proj

with vllm.trace(prompt):
    mlp.input = mlp.input.clone()
    mlp.input[-1, neurons] = 10               # no batch dimension
    out = vllm.output.save()
    last = out[:, -1].argmax()                # returns a tensor
    prediction = vllm.tokenizer.decode(last).save()

print(f"Prediction with vLLM: '{prediction}'")

Prediction with vLLM: '!'

As expected, the results were different, indicating that these models are not interchangeable. Keep these differences in mind when working with vLLM vs Transformers models and making comparisons between the two.

Sampled Token Traceability#

vLLM provides functionality to configure how each sequence samples its next token. Here’s an example of how you can trace a model’s token sampling operations with the nnsight vLLM wrapper.

[4]:
with vllm.trace("Madison Square Garden is located in the city of", temperature=0.8, top_p=0.95, max_tokens=3) as tracer:
    samples = list().save()
    logits = list().save()

    with tracer.iter[:3]:
        logits.append(vllm.logits.output)
        samples.append(vllm.samples.output)

pprint(samples)
pprint(logits) # may differ from the sampled tokens under the current sampling parameters
[tensor([[1561]], device='cuda:0', dtype=torch.int32),
 tensor([[4356]], device='cuda:0', dtype=torch.int32),
 tensor([[11]], device='cuda:0', dtype=torch.int32)]
[tensor([[ 5.2812,  4.0312,  2.0312,  ..., -5.5000, -5.5000, -5.5000]],
       device='cuda:0', dtype=torch.bfloat16),
 tensor([[ 7.6875,  6.3750,  4.0312,  ..., -5.3438, -5.3438, -5.3438]],
       device='cuda:0', dtype=torch.bfloat16),
 tensor([[11.9375,  5.9375,  4.7188,  ..., -2.2031, -2.2031, -2.2031]],
       device='cuda:0', dtype=torch.bfloat16)]
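To see how sampling differs from greedy decoding, you can decode both the sampled tokens and the argmax of the saved logits after the trace:

# Compare each sampled token against the greedy (argmax) choice
# from the logits saved at the same generation step.
for step, (sample, step_logits) in enumerate(zip(samples, logits)):
    sampled = vllm.tokenizer.decode(sample[0])
    greedy = vllm.tokenizer.decode(step_logits[0].argmax())
    print(f"Step {step}: sampled = {sampled!r} | greedy = {greedy!r}")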

Other features#

Intervening on generated token iterations with .all() and .iter[]#

NNsight supports iterating over generated tokens via .all() and .iter[].

[5]:
with vllm.trace("Hello World!", max_tokens=10) as tracer:
    outputs = list().save()

    # will iterate over all 10 tokens
    with tracer.all():
        out = vllm.output[:, -1]
        outputs.append(out)

print(len(outputs))
print("".join([vllm.tokenizer.decode(output.argmax()) for output in outputs]))
10
!!!!!!!!!!
[6]:
prompt = 'The Eiffel Tower is in the city of'
mlp = vllm.model.layers[16].mlp.down_proj
n_new_tokens = 50

with vllm.trace(prompt, max_tokens=n_new_tokens) as tracer:
    hidden_states = list().save() # Initialize & .save() list

    # Use .iter[] to apply the intervention only to specific generated tokens
    with tracer.iter[2:5]:

        # Apply intervention - set to zero
        mlp.input = mlp.input.clone()
        mlp.input[-1] = 0

        # Append hidden state post-intervention
        hidden_states.append(mlp.input) # no need to call .save

print("Hidden state length: ",len(hidden_states))
pprint(hidden_states)
Hidden state length:  3
[tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16)]