Source code for nnsight.modeling.language

from __future__ import annotations

import json
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Generic,
    List,
    Optional,
    Protocol,
    Tuple,
    Type,
    Union,
)

import torch
from torch.nn.modules import Module
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BatchEncoding,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.models.auto import modeling_auto
from transformers.models.llama.configuration_llama import LlamaConfig
from typing_extensions import Self

from ..intervention import Envoy
from ..intervention.contexts import InterleavingTracer
from ..intervention.graph import InterventionNodeType, InterventionProxyType
from ..util import WrapperModule
from .mixins import RemoteableMixin


class LanguageModel(RemoteableMixin):
    """LanguageModels are NNsight wrappers around transformers language models.

    Inputs can be in the form of:
        Prompt: (str)
        Prompts: (List[str])
        Batched prompts: (List[List[str]])
        Tokenized prompt: (Union[List[int], torch.Tensor])
        Tokenized prompts: (Union[List[List[int]], torch.Tensor])
        Direct input: (Dict[str,Any])
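
    For example, a single prompt may be passed either as text or already tokenized
    (the token ids below are illustrative placeholders)::

        with model.trace("The Eiffel Tower is in the city of"):
            ...

        with model.trace([464, 36, 733, 417, 8765]):
            ...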

    If using a custom model, you also need to provide the tokenizer, e.g. ``LanguageModel(custom_model, tokenizer=tokenizer)``.

    Calls to generate pass arguments downstream to :func:`GenerationMixin.generate`.

    Attributes:
        config (PretrainedConfig): Hugging Face config loaded from the repository or checkpoint.
        tokenizer (PreTrainedTokenizer): Tokenizer for the language model.
        automodel (Type): AutoModel class from the ``transformers`` auto models.
        model (PreTrainedModel): Meta version of the underlying auto model.
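
    Example:

        A minimal usage sketch (the repo id and the ``transformer.h`` module path
        below are illustrative and depend on the underlying architecture)::

            model = LanguageModel("openai-community/gpt2")

            with model.trace("The Eiffel Tower is in the city of"):
                hidden_states = model.transformer.h[-1].output[0].save()
                logits = model.output.logits.save()

            with model.generate("The Eiffel Tower is in the city of", max_new_tokens=3):
                tokens = model.generator.output.save()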

    """

    __methods__ = {"generate": "_generate"}

    tokenizer: PreTrainedTokenizer

    class Generator(WrapperModule):

        class Streamer(WrapperModule):

            def put(self, *args):
                return self(*args)

            def end(self):
                pass

        def __init__(self) -> None:
            super().__init__()

            self.streamer = LanguageModel.Generator.Streamer()
    def __init__(
        self,
        *args,
        config: Optional[PretrainedConfig] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        automodel: Type[AutoModel] = AutoModelForCausalLM,
        **kwargs,
    ) -> None:

        # Resolve the auto model class if it was given by name (e.g. "AutoModelForCausalLM").
        self.automodel = (
            automodel
            if not isinstance(automodel, str)
            else getattr(modeling_auto, automodel)
        )

        self.config = config
        self.tokenizer = tokenizer
        self.repo_id: str = None

        super().__init__(*args, **kwargs)

        self.generator: Envoy[InterventionProxyType, InterventionNodeType] = (
            LanguageModel.Generator()
        )

    def _load_config(self, repo_id: str, **kwargs):

        if self.config is None:

            self.config = AutoConfig.from_pretrained(repo_id, **kwargs)

    def _load_tokenizer(self, repo_id: str, **kwargs):

        if self.tokenizer is None:

            # Decoder-only models are padded on the left by default.
            if "padding_side" not in kwargs:
                kwargs["padding_side"] = "left"

            self.tokenizer = AutoTokenizer.from_pretrained(
                repo_id, config=self.config, **kwargs
            )

            if getattr(self.tokenizer, "pad_token", None) is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_meta(
        self,
        repo_id: str,
        tokenizer_kwargs: Optional[Dict[str, Any]] = {},
        patch_llama_scan: bool = True,
        **kwargs,
    ) -> Module:

        self.repo_id = repo_id

        self._load_config(repo_id, **kwargs)

        self._load_tokenizer(repo_id, **tokenizer_kwargs)

        # Optionally patch the rope scaling type to "default" while building the meta model.
        if (
            patch_llama_scan
            and isinstance(self.config, LlamaConfig)
            and isinstance(self.config.rope_scaling, dict)
            and "rope_type" in self.config.rope_scaling
        ):
            self.config.rope_scaling["rope_type"] = "default"

        model = self.automodel.from_config(self.config, trust_remote_code=True)

        return model

    def _load(
        self,
        repo_id: str,
        tokenizer_kwargs: Optional[Dict[str, Any]] = {},
        patch_llama_scan: bool = True,
        **kwargs,
    ) -> PreTrainedModel:

        self._load_config(repo_id, **kwargs)

        self._load_tokenizer(repo_id, **tokenizer_kwargs)

        # Restore the rope scaling type to "llama3" before loading the pretrained weights.
        if (
            patch_llama_scan
            and isinstance(self.config, LlamaConfig)
            and isinstance(self.config.rope_scaling, dict)
            and "rope_type" in self.config.rope_scaling
        ):
            self.config.rope_scaling["rope_type"] = "llama3"

        model = self.automodel.from_pretrained(repo_id, config=self.config, **kwargs)

        return model

    def _tokenize(
        self,
        inputs: Union[
            str,
            List[str],
            List[List[str]],
            List[int],
            List[List[int]],
            torch.Tensor,
            Dict[str, Any],
        ],
        **kwargs,
    ):

        # Wrap a single prompt or a single token sequence into a batch of one.
        if isinstance(inputs, str) or (
            isinstance(inputs, list) and isinstance(inputs[0], int)
        ):
            inputs = [inputs]

        if isinstance(inputs, torch.Tensor) and inputs.ndim == 1:
            inputs = inputs.unsqueeze(0)

        # Already-tokenized inputs only need to be padded into a rectangular batch.
        if not isinstance(inputs[0], str):
            inputs = [{"input_ids": ids} for ids in inputs]
            return self.tokenizer.pad(inputs, return_tensors="pt", **kwargs)

        return self.tokenizer(inputs, return_tensors="pt", padding=True, **kwargs)

    def _prepare_input(
        self,
        *inputs: Tuple[
            Union[
                str,
                List[str],
                List[List[str]],
                List[int],
                List[List[int]],
                torch.Tensor,
                List[torch.Tensor],
                Dict[str, Any],
                BatchEncoding,
            ]
        ],
        input_ids: Union[
            List[int], List[List[int]], torch.Tensor, List[torch.Tensor]
        ] = None,
        labels: Any = None,
        **kwargs,
    ) -> Tuple[BatchEncoding, int]:

        if input_ids is not None:

            assert len(inputs) == 0

            inputs = (input_ids,)

        assert len(inputs) == 1

        inputs = inputs[0]

        if isinstance(inputs, dict):
            inputs = BatchEncoding(inputs)
        elif isinstance(inputs, BatchEncoding):
            pass
        else:
            inputs = self._tokenize(inputs, **kwargs)

        if labels is not None:
            labels = self._tokenize(labels, **kwargs)["input_ids"]

        return ((inputs,), {"labels": labels}), len(inputs["input_ids"])

    def _batch(
        self,
        batched_inputs: Optional[Tuple[Tuple[BatchEncoding], Dict[str, Any]]],
        input: BatchEncoding,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[Dict[str, Any]]:

        if batched_inputs is None:
            return ((input,), {"labels": labels})

        batched_labels = batched_inputs[1]["labels"]
        batched_inputs = batched_inputs[0][0]

        attention_mask = batched_inputs["attention_mask"]

        # Concatenate the existing batch with the new input and re-pad into one rectangular batch.
        batched_inputs = [
            {"input_ids": ids}
            for ids in [
                *batched_inputs["input_ids"].tolist(),
                *input["input_ids"].tolist(),
            ]
        ]

        batched_inputs = self.tokenizer.pad(batched_inputs, return_tensors="pt")

        if labels is not None:
            batched_labels = torch.cat((batched_labels, labels))

        # Splice the original attention mask back in on the padded side.
        if self.tokenizer.padding_side == "left":
            batched_inputs["attention_mask"][
                : attention_mask.shape[0], -attention_mask.shape[1] :
            ] = attention_mask
        else:
            batched_inputs["attention_mask"][
                : attention_mask.shape[0], : attention_mask.shape[1]
            ] = attention_mask

        return ((batched_inputs,), {"labels": batched_labels})

    def _execute(self, inputs: BatchEncoding, **kwargs) -> Any:

        inputs = inputs.to(self.device)

        return self._model(
            **inputs,
            **kwargs,
        )

    def _generate(
        self,
        inputs: BatchEncoding,
        max_new_tokens=1,
        streamer: Any = None,
        **kwargs,
    ):

        if streamer is None:
            streamer = self.generator.streamer

        inputs = inputs.to(self.device)

        output = self._model.generate(
            **inputs,
            **kwargs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
        )

        # Pass the generated tokens through the generator Envoy so they can be accessed downstream.
        self.generator(output)

        return output

    def _remoteable_model_key(self) -> str:

        return json.dumps(
            {"repo_id": self.repo_id}  # , "torch_dtype": str(self._model.dtype)}
        )

    @classmethod
    def _remoteable_from_model_key(cls, model_key: str, **kwargs) -> Self:

        kwargs = {**json.loads(model_key), **kwargs}

        repo_id = kwargs.pop("repo_id")

        return LanguageModel(repo_id, **kwargs)


if TYPE_CHECKING:

    class LanguageModel(LanguageModel, PreTrainedModel):

        def generate(self, *args, **kwargs) -> InterleavingTracer:
            pass