vLLM Support

vLLM is a popular library for fast LLM inference and serving. By leveraging PagedAttention, continuous batching, and Hugging Face model integration, vLLM makes inference more efficient and scalable for real-world applications.

Starting with NNsight 0.5.X, NNsight supports inspecting and intervening on models that run on vLLM.

Setup

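To use vLLM with NNsight, install NNsight from the vllmv1 branch of its GitHub repository (nnsight@vllmv1), along with vllm==0.12.0 and triton==3.5.0. The cell below does this automatically on Colab and also pins numpy==2.2.4. In any other environment, install these packages yourself.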

[ ]:
from IPython.display import clear_output
from pprint import pprint

try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    %pip install -U git+https://github.com/ndif-team/nnsight@vllmv1 triton==3.5.0 vllm==0.12.0 numpy==2.2.4
clear_output()
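The install cell only runs pip on Colab. If you set up your own environment, you can use the following minimal sketch to confirm that it matches the pins above. It relies only on the standard library's importlib.metadata. EXPECTED and check_pins are illustrative names and are not part of NNsight. NNsight itself comes from the vllmv1 git branch, so it is left out of the version comparison.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the install command above (nnsight is installed from a git
# branch, so it has no fixed version string to compare against here)
EXPECTED = {"vllm": "0.12.0", "triton": "3.5.0", "numpy": "2.2.4"}


def check_pins(expected: dict[str, str]) -> dict[str, str]:
    """Return 'ok', 'missing', or a mismatch message for each pinned package."""
    status = {}
    for name, wanted in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            status[name] = "missing"
            continue
        status[name] = "ok" if installed == wanted else f"found {installed}, expected {wanted}"
    return status


for pkg, state in check_pins(EXPECTED).items():
    print(f"{pkg}: {state}")
```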

Next, let’s load our NNsight vLLM model. You can find vLLM-supported models here. For this exercise, we will use Llama 3.1 8B. (Meta gates access to Llama models on Hugging Face. As usual, you will need an HF_TOKEN and approved access. Alternatively, use an ungated model.)

Running a model with vLLM requires a GPU, or another hardware backend, that vLLM supports.

[ ]:
from nnsight.modeling.vllm import VLLM

# vLLM handles parallelism itself: tensor_parallel_size=2 shards the model across 2 GPUs
vllm = VLLM("meta-llama/Llama-3.1-8B", dispatch=True, tensor_parallel_size=2)

clear_output()

print(vllm)
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (qkv_proj): QKVParallelLinear(in_features=4096, output_features=6144, bias=False, tp_size=1, gather_output=False)
          (o_proj): RowParallelLinear(in_features=4096, output_features=4096, bias=False, tp_size=1, reduce_results=True)
          (rotary_emb): Llama3RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=131072, base=500000.0, is_neox_style=True)
          (attn): Attention(head_size=128, num_heads=32, num_kv_heads=8, scale=0.08838834764831845, backend=FlashAttentionImpl)
        )
        (mlp): LlamaMLP(
          (gate_up_proj): MergedColumnParallelLinear(in_features=4096, output_features=28672, bias=False, tp_size=1, gather_output=False)
          (down_proj): RowParallelLinear(in_features=14336, output_features=4096, bias=False, tp_size=1, reduce_results=True)
          (act_fn): SiluAndMul()
        )
        (input_layernorm): RMSNorm(hidden_size=4096, eps=1e-05)
        (post_attention_layernorm): RMSNorm(hidden_size=4096, eps=1e-05)
      )
    )
    (norm): RMSNorm(hidden_size=4096, eps=1e-05)
  )
  (lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)
  (logits_processor): LogitsProcessor(vocab_size=128256, org_vocab_size=128256, scale=1.0, logits_as_input=False)
  (logits): WrapperModule()
  (samples): WrapperModule()
  (generator): WrapperModule()
)
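The sizes in this printout follow from the model's configuration. The short sketch below does arithmetic only on numbers copied from the printout. It checks how vLLM's fused layers get their widths. It also mimics SiluAndMul, assuming that layer behaves like vLLM's reference implementation: split the fused gate_up output in half and return silu(gate) * up. This split is why down_proj takes a 14336-dimensional input. That input is the MLP's hidden layer, which is where the neuron indices in the next section point.

```python
import math

# Sizes copied from the printed module tree above
head_size = 128
num_heads = 32          # query heads
num_kv_heads = 8        # key/value heads (grouped-query attention)
intermediate_size = 14336

# QKVParallelLinear fuses the Q, K, and V projections into one matmul
q_size = num_heads * head_size          # 4096
kv_size = num_kv_heads * head_size      # 1024 each for K and V
qkv_out = q_size + 2 * kv_size          # 6144, matching output_features

# MergedColumnParallelLinear stacks gate_proj and up_proj side by side
gate_up_out = 2 * intermediate_size     # 28672, matching output_features


def silu(v: float) -> float:
    """SiLU (swish): v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def silu_and_mul(fused: list[float]) -> list[float]:
    """Split a fused [gate | up] vector in half and return silu(gate) * up."""
    d = len(fused) // 2
    gate, up = fused[:d], fused[d:]
    return [silu(g) * u for g, u in zip(gate, up)]


print(qkv_out, gate_up_out)
# Toy fused vector with a hidden size of 2 instead of 14336
print(silu_and_mul([0.0, 1.0, 5.0, 2.0]))
```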

Interventions on vLLM models

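We now have a vLLM model that runs with NNsight. Let’s try applying some interventions to it. Below, we set three neurons in the input to layer 16's MLP down-projection (down_proj) to 10 at the last token, then check how the next-token prediction changes.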

[ ]:
neurons = [394, 5490, 8929]
prompt = "The truth is the"

mlp = vllm.model.layers[16].mlp.down_proj

with vllm.trace(prompt, remote=False):
    mlp.input = mlp.input.clone()
    mlp.input[-1, neurons] = 10               # no batch dimension: rows are tokens
    out = vllm.logits.output.save()           # a tensor: [num_sequences, vocab_size]
    last = out[-1].argmax()                   # last sequence's most likely next token
    prediction = vllm.tokenizer.decode(last).save()

print(f"Prediction with vLLM: '{prediction}'")

Prediction with vLLM: '!'

⚠️ Note: Because of differences in inference settings and other implementation details, results may differ from those obtained with Transformers, even for the same intervention!

[ ]:
# Use the HuggingFace transformers backend for comparison
from nnsight import LanguageModel

neurons = [394, 5490, 8929]
prompt = "The truth is the"

# Transformers has no tensor_parallel_size here: device_map="auto" lets Accelerate
# spread layers across the GPUs exposed via CUDA_VISIBLE_DEVICES
lm = LanguageModel("meta-llama/Llama-3.1-8B", dispatch=True, device_map="auto")
mlp = lm.model.layers[16].mlp.down_proj

with lm.trace(prompt, remote=False):
    mlp.input[:, -1, neurons] = 10                # batch dimension
    out = lm.output.save()
    last = out["logits"][:, -1].argmax()          # dict of tensors
    prediction = lm.tokenizer.decode(last).save()

print(f"Prediction with transformers: '{prediction}'")
Prediction with transformers: ' lie'

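Keep these differences in mind when switching backends. The vLLM wrapper indexes tokens without a batch dimension and returns plain tensors. Transformers keeps a batch dimension and returns a dict-like output. The numerical results themselves can also differ.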
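The two cells index the same token differently. Transformers keeps a [batch, seq, hidden] layout. The vLLM cell indexes rows directly, because vLLM packs the tokens of every scheduled sequence into one flat token dimension. In a packed tensor that holds several sequences, row -1 is therefore only the final token of the last sequence. The sketch below shows the index bookkeeping in plain Python. It assumes each sequence's tokens are stored contiguously, without padding, in a known order. last_token_rows and flat_index are hypothetical helpers, not NNsight or vLLM APIs.

```python
from itertools import accumulate


def last_token_rows(seq_lens: list[int]) -> list[int]:
    """Packed row index of each sequence's final token."""
    # Running totals give each sequence's end offset; its last row is one before that.
    return [end - 1 for end in accumulate(seq_lens)]


def flat_index(seq_lens: list[int], seq: int, pos: int) -> int:
    """Map a Transformers-style (sequence, position) pair to a packed row index."""
    if not 0 <= pos < seq_lens[seq]:
        raise IndexError("position out of range for this sequence")
    return sum(seq_lens[:seq]) + pos


# Two prompts of 5 and 3 tokens packed into 8 rows: rows 0 to 4, then rows 5 to 7
lens = [5, 3]
print(last_token_rows(lens))   # [4, 7]
print(flat_index(lens, 1, 2))  # 7
```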

Sampled Token Traceability

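vLLM lets you configure how each sequence samples its next token through sampling parameters such as temperature, top_p, and max_tokens, which you can pass directly to .trace(). Using the NNsight VLLM wrapper, the example below records two things at each of three generation steps: the sampled token (vllm.samples.output) and the next-token logits (vllm.logits.output).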

[ ]:
with vllm.trace("Madison Square Garden is located in the city of", temperature=0.8, top_p=0.95, max_tokens=3) as tracer:
    samples = list().save()
    logits = list().save()

    for _ in range(3):
        tracer.next()
        samples.append(vllm.samples.output)
        tracer.next()
        logits.append(vllm.logits.output)

pprint(samples)
pprint(logits) # raw logits; with temperature/top_p sampling, the sampled token need not be their argmax
[SamplerOutput(sampled_token_ids=tensor([[4356]], device='cuda:0', dtype=torch.int32),
               logprobs_tensors=None),
 SamplerOutput(sampled_token_ids=tensor([[323]], device='cuda:0', dtype=torch.int32),
               logprobs_tensors=None),
 SamplerOutput(sampled_token_ids=tensor([[279]], device='cuda:0', dtype=torch.int32),
               logprobs_tensors=None)]
[tensor([[11.9375,  5.9062,  4.7188,  ..., -2.1875, -2.1875, -2.1875]],
       device='cuda:0', dtype=torch.bfloat16),
 tensor([[ 2.5156,  1.6328,  0.2578,  ..., -6.0938, -6.0938, -6.0938]],
       device='cuda:0', dtype=torch.bfloat16),
 tensor([[ 1.9531,  4.6875,  2.4844,  ..., -6.7812, -6.7812, -6.7812]],
       device='cuda:0', dtype=torch.bfloat16)]
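The second printed list holds raw logits. The first holds the tokens vLLM actually sampled with temperature=0.8 and top_p=0.95. The sketch below applies the standard temperature plus nucleus (top-p) recipe to a toy logit vector in plain Python. It illustrates why a sampled token need not be the argmax. This is a simplified illustration, not vLLM's sampler code, and nucleus_distribution and sample are hypothetical names.

```python
import math
import random


def nucleus_distribution(logits: list[float], temperature: float, top_p: float) -> dict[int, float]:
    """Temperature-scaled softmax restricted to the smallest top-p set of tokens, renormalized."""
    scaled = [x / temperature for x in logits]      # temperature < 1 sharpens, > 1 flattens
    peak = max(scaled)                              # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the most likely tokens until their cumulative probability reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Renormalize over the surviving tokens.
    return {i: probs[i] / mass for i in kept}


def sample(dist: dict[int, float], rng: random.Random) -> int:
    """Draw one token id from a {token_id: probability} distribution."""
    tokens = list(dist)
    return rng.choices(tokens, weights=[dist[t] for t in tokens], k=1)[0]


toy_logits = [2.0, 1.5, 0.3, -1.0]
dist = nucleus_distribution(toy_logits, temperature=0.8, top_p=0.95)
print(dist)                                     # token 3 falls outside the nucleus
rng = random.Random(0)
print([sample(dist, rng) for _ in range(10)])   # sampled ids depend on the seed
```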

Note: gradients are not supported with vLLM

vLLM is an inference-only engine. It executes models under PyTorch's inference mode, and its optimized kernels, such as PagedAttention, are built for forward passes only. As a result, gradients and backward passes are not available for vLLM models, and calling gradient operations through NNsight's vLLM wrapper will throw an error.
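The exact error you see depends on the NNsight and vLLM code paths involved, but plain PyTorch reproduces the basic mechanism. This minimal sketch assumes that the forward pass runs under torch.inference_mode(). A tiny Linear layer stands in for a real model.

```python
import torch

layer = torch.nn.Linear(4, 2)       # parameters require grad by default
x = torch.randn(1, 4)

# Inside inference mode, autograd records nothing, so outputs have no grad_fn.
with torch.inference_mode():
    out = layer(x)

print(out.requires_grad)            # False

backward_failed = False
try:
    out.sum().backward()            # there is no graph to differentiate through
except RuntimeError as err:
    backward_failed = True
    print("backward() raised:", err)
```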

Other features

Intervening on generated token iterations with .all() and .iter[]

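NNsight supports iterating over generation steps with two methods. tracer.all() applies the enclosed code to every generated token. tracer.iter[...] applies it only to the steps selected by an index or slice.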

[ ]:
with vllm.trace("Hello World!", max_tokens=10) as tracer:
    outputs = list().save()

    # will iterate over all 10 tokens
    with tracer.all():
        out = vllm.logits.output[-1]          # next-token logits at this step
        outputs.append(out)

print(len(outputs))
print("".join([vllm.tokenizer.decode(output.argmax()) for output in outputs]))
10
!!!!!!!!!!
[ ]:
prompt = 'The Eiffel Tower is in the city of'
mlp = vllm.model.layers[16].mlp.down_proj
n_new_tokens = 50

with vllm.trace(prompt, max_tokens=n_new_tokens) as tracer:
    hidden_states = list().save() # Initialize & .save() list

    # Use tracer.iter[2:5] to apply the intervention only to generation steps 2, 3, and 4
    with tracer.iter[2:5]:

        # Apply intervention: zero the last token's down_proj input
        mlp.input = mlp.input.clone()
        mlp.input[-1] = 0

        # Append the post-intervention input (the list itself was already saved)
        hidden_states.append(mlp.input)

print("Hidden state length: ",len(hidden_states))
pprint(hidden_states)
Hidden state length:  3
[tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16)]
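It can help to model in plain Python which forward passes an iteration selector covers. The sketch below is a hypothetical helper, not an NNsight API. It assumes generation steps are numbered from 0, with one step per generated token. Under that assumption, it reproduces the counts printed above: 10 outputs for tracer.all() with max_tokens=10, and 3 hidden states for tracer.iter[2:5].

```python
def selected_steps(n_steps: int, selector: "str | slice") -> list[int]:
    """Generation steps covered by an iteration selector.

    "all" mirrors tracer.all(); a slice mirrors tracer.iter[start:stop].
    Steps are numbered from 0, one per generated token (an assumption).
    """
    steps = list(range(n_steps))
    if selector == "all":
        return steps
    if isinstance(selector, slice):
        return steps[selector]
    raise TypeError("selector must be 'all' or a slice")


print(selected_steps(10, "all"))          # every step: 10 entries
print(selected_steps(50, slice(2, 5)))    # [2, 3, 4]
```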