Source code for nnsight.modeling.vllm.sampling
import copy
from typing import Dict, List, Optional, Tuple

import torch
from vllm.model_executor.sampling_metadata import (
    SamplingMetadata,
    SamplingMetadataCache,
    _prepare_seq_groups,
)
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d

from nnsight.intervention.graph import InterventionGraph

class NNsightSamplingParams(SamplingParams):
    """vLLM ``SamplingParams`` extended to carry nnsight intervention state
    (the intervention graph, its batch groups, and the invoker group index)
    alongside each request."""

    intervention_graph: Optional[InterventionGraph] = None
    nns_batch_groups: Optional[List[Tuple[int, int]]] = None
    invoker_group: Optional[int] = None

    is_default_param: bool = True
    def clone(self) -> "SamplingParams":
        """Deep copy, excluding LogitsProcessor objects.

        LogitsProcessor objects are excluded because they may contain an
        arbitrary, nontrivial amount of data.
        See https://github.com/vllm-project/vllm/issues/3087

        The intervention graph is excluded for the same reason: seeding it
        into ``memo`` makes ``deepcopy`` reuse the existing object instead
        of copying it.
        """
        memo = {}

        if self.logits_processors is not None:
            for lp in self.logits_processors:
                memo[id(lp)] = lp

        if self.intervention_graph is not None:
            memo[id(self.intervention_graph)] = self.intervention_graph

        return copy.deepcopy(self, memo=memo)
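
    # A minimal illustration (hypothetical names) of the sharing behavior of
    # clone(): because the graph is seeded into `memo`, deepcopy reuses it
    # rather than copying it.
    #
    #   params = NNsightSamplingParams(temperature=0.0)
    #   params.intervention_graph = graph  # assume an InterventionGraph
    #   cloned = params.clone()
    #   assert cloned.intervention_graph is params.intervention_graph
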
class NNsightSamplingMetadata(SamplingMetadata):
    """vLLM ``SamplingMetadata`` extended with the intervention graph and the
    batch group bookkeeping nnsight needs while executing the model."""

    intervention_graph: Optional[InterventionGraph] = None
    nns_batch_groups: Optional[List[Tuple[int, int]]] = None
    batch_groups: Optional[List[Tuple[int, int]]] = None

    def __init__(
        self,
        *args,
        intervention_graph: Optional[InterventionGraph] = None,
        nns_batch_groups: Optional[List[Tuple[int, int]]] = None,
        batch_groups: Optional[List[Tuple[int, int]]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.intervention_graph = intervention_graph
        self.nns_batch_groups = nns_batch_groups
        self.batch_groups = batch_groups
    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: List[int],
        device: str,
        pin_memory: bool,
        generators: Optional[Dict[str, torch.Generator]] = None,
        cache: Optional[SamplingMetadataCache] = None,
    ) -> "SamplingMetadata":
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(
            seq_group_metadata_list, seq_lens, query_lens, device, generators, cache
        )
        selected_token_indices = async_tensor_h2d(
            selected_token_indices,
            dtype=torch.long,
            target_device=device,
            pin_memory=pin_memory,
        )
        categorized_sample_indices = {
            t: async_tensor_h2d(
                seq_ids,
                dtype=torch.int,
                target_device=device,
                pin_memory=pin_memory,
            )
            for t, seq_ids in categorized_sample_indices.items()
        }

        ### NNSIGHT ###########################################
        # Everything above mirrors vLLM's stock SamplingMetadata.prepare();
        # the block below is the nnsight-specific extension.

        intervention_graphs = []
        nns_batch_groups = []
        batch_groups = []
        batch_groups_offset = 0
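
        # The loop below walks the scheduled sequence groups and, for each one
        # carrying NNsightSamplingParams, (a) collects its intervention graph
        # the first time it is seen and (b) accumulates a contiguous
        # (batch_start, batch_size) token range per invoker group, offset so
        # that groups belonging to different graphs do not collide.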
        for idx, seq_group in enumerate(seq_group_metadata_list):

            if isinstance(seq_group.sampling_params, NNsightSamplingParams):

                seq_group_intervention_graph = (
                    seq_group.sampling_params.intervention_graph
                )
                seq_group_nns_batch_groups = seq_group.sampling_params.nns_batch_groups

                if isinstance(seq_group_intervention_graph, InterventionGraph):

                    # First occurrence of this graph: register it and remember
                    # where its batch groups start.
                    if seq_group_intervention_graph not in intervention_graphs:
                        intervention_graphs.append(seq_group_intervention_graph)
                        nns_batch_groups.append(seq_group_nns_batch_groups)
                        batch_groups_offset = len(batch_groups)

                    seq_group_batch_group = (
                        seq_group.sampling_params.invoker_group + batch_groups_offset
                    )

                    batch_size = query_lens[idx]

                    if seq_group_batch_group >= len(batch_groups):
                        # New batch group: it starts where the previous one ends.
                        batch_start = (
                            sum(batch_groups[-1]) if len(batch_groups) > 0 else 0
                        )
                        batch_groups.append((batch_start, batch_size))
                    else:
                        # Existing batch group: grow it by this group's query length.
                        batch_start, seq_group_batch_size = batch_groups[
                            seq_group_batch_group
                        ]
                        batch_size += seq_group_batch_size
                        batch_groups[seq_group_batch_group] = (batch_start, batch_size)
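
        # Illustrative example (assumed values): two invoker groups sharing one
        # graph, with query_lens == [3, 5], produce batch_groups == [(0, 3), (3, 5)];
        # the second group starts at sum((0, 3)) == 3 and spans 5 tokens.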
        n_graphs = len(intervention_graphs)

        if n_graphs == 0:
            intervention_graph = None
            nns_batch_groups = None
        elif n_graphs == 1:
            intervention_graph = intervention_graphs[0]
            nns_batch_groups = nns_batch_groups[0]
        # NOTE: merging multiple graphs is not implemented; with more than one
        # distinct graph, `intervention_graph` would be left unbound here.
        # else:
        #     intervention_graph = MultiGraph(intervention_graphs.values())
        #     InterventionProtocol.shift(intervention_graph)

        ###########################################
        sampling_metadata = NNsightSamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
            intervention_graph=intervention_graph,
            nns_batch_groups=nns_batch_groups,
            batch_groups=batch_groups,
        )

        return sampling_metadata
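
# Minimal usage sketch (hypothetical values): vLLM's model runner would call
# prepare() with the scheduled sequence groups, whose sampling params are
# NNsightSamplingParams instances.
#
#   metadata = NNsightSamplingMetadata.prepare(
#       seq_group_metadata_list,
#       seq_lens=[3, 5],
#       query_lens=[3, 5],
#       device="cuda",
#       pin_memory=True,
#   )
#   # metadata.batch_groups -> [(0, 3), (3, 5)] for two invoker groups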