import logging
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import (
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_NUM_OUTPUTS,
)
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.legacy.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.legacy.prompts.base import PromptTemplate
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode

DEFAULT_HUGGINGFACE_MODEL = "StabilityAI/stablelm-tuned-alpha-3b"

if TYPE_CHECKING:
    try:
        from huggingface_hub import AsyncInferenceClient, InferenceClient
        from huggingface_hub.hf_api import ModelInfo
        from huggingface_hub.inference._types import ConversationalOutput
    except ModuleNotFoundError:
        AsyncInferenceClient = Any
        InferenceClient = Any
        ConversationalOutput = dict
        ModelInfo = Any

logger = logging.getLogger(__name__)

class HuggingFaceLLM(CustomLLM):
    """HuggingFace LLM."""
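    # A minimal usage sketch (illustrative, not part of the original source).
    # The model/tokenizer names below reuse the module default; any causal LM
    # from the Hugging Face Hub that fits in local memory works the same way:
    #
    #     llm = HuggingFaceLLM(
    #         model_name="StabilityAI/stablelm-tuned-alpha-3b",
    #         tokenizer_name="StabilityAI/stablelm-tuned-alpha-3b",
    #         context_window=2048,
    #         max_new_tokens=256,
    #         device_map="auto",
    #     )
    #     print(llm.complete("Hello").text)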
    model_name: str = Field(
        default=DEFAULT_HUGGINGFACE_MODEL,
        description=(
            "The model name to use from HuggingFace. "
            "Unused if `model` is passed in directly."
        ),
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of tokens available for input.",
        gt=0,
    )
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    system_prompt: str = Field(
        default="",
        description=(
            "The system prompt, containing any extra instructions or context. "
            "The model card on HuggingFace should specify if this is needed."
        ),
    )
    query_wrapper_prompt: PromptTemplate = Field(
        default=PromptTemplate("{query_str}"),
        description=(
            "The query wrapper prompt, containing the query placeholder. "
            "The model card on HuggingFace should specify if this is needed. "
            "Should contain a `{query_str}` placeholder."
        ),
    )
    tokenizer_name: str = Field(
        default=DEFAULT_HUGGINGFACE_MODEL,
        description=(
            "The name of the tokenizer to use from HuggingFace. "
            "Unused if `tokenizer` is passed in directly."
        ),
    )
    device_map: str = Field(
        default="auto", description="The device_map to use. Defaults to 'auto'."
    )
    stopping_ids: List[int] = Field(
        default_factory=list,
        description=(
            "The stopping ids to use. "
            "Generation stops when these token IDs are predicted."
        ),
    )
    tokenizer_outputs_to_remove: list = Field(
        default_factory=list,
        description=(
            "The outputs to remove from the tokenizer. "
            "Sometimes huggingface tokenizers return extra inputs that cause errors."
        ),
    )
    tokenizer_kwargs: dict = Field(
        default_factory=dict, description="The kwargs to pass to the tokenizer."
    )
    model_kwargs: dict = Field(
        default_factory=dict,
        description="The kwargs to pass to the model during initialization.",
    )
    generate_kwargs: dict = Field(
        default_factory=dict,
        description="The kwargs to pass to the model during generation.",
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_chat_model"].field_info.description
            + " Be sure to verify that you either pass an appropriate tokenizer "
            "that can convert prompts to properly formatted chat messages or a "
            "`messages_to_prompt` that does so."
        ),
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _stopping_criteria: Any = PrivateAttr()
    def __init__(
        self,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        query_wrapper_prompt: Union[str, PromptTemplate] = "{query_str}",
        tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
        model_name: str = DEFAULT_HUGGINGFACE_MODEL,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        device_map: Optional[str] = "auto",
        stopping_ids: Optional[List[int]] = None,
        tokenizer_kwargs: Optional[dict] = None,
        tokenizer_outputs_to_remove: Optional[list] = None,
        model_kwargs: Optional[dict] = None,
        generate_kwargs: Optional[dict] = None,
        is_chat_model: Optional[bool] = False,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: str = "",
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Initialize params."""
        try:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                StoppingCriteria,
                StoppingCriteriaList,
            )
        except ImportError as exc:
            raise ImportError(
                f"{type(self).__name__} requires torch and transformers packages.\n"
                "Please install both with `pip install transformers[torch]`."
            ) from exc

        model_kwargs = model_kwargs or {}
        self._model = model or AutoModelForCausalLM.from_pretrained(
            model_name, device_map=device_map, **model_kwargs
        )

        # check context_window
        config_dict = self._model.config.to_dict()
        model_context_window = int(
            config_dict.get("max_position_embeddings", context_window)
        )
        if model_context_window and model_context_window < context_window:
            logger.warning(
                f"Supplied context_window {context_window} is greater "
                f"than the model's max input size {model_context_window}. "
                "Disable this warning by setting a lower context_window."
            )
            context_window = model_context_window

        tokenizer_kwargs = tokenizer_kwargs or {}
        if "max_length" not in tokenizer_kwargs:
            tokenizer_kwargs["max_length"] = context_window

        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(
            tokenizer_name, **tokenizer_kwargs
        )

        if tokenizer_name != model_name:
            logger.warning(
                f"The model `{model_name}` and tokenizer `{tokenizer_name}` "
                f"are different, please ensure that they are compatible."
            )

        # setup stopping criteria
        stopping_ids_list = stopping_ids or []

        class StopOnTokens(StoppingCriteria):
            def __call__(
                self,
                input_ids: torch.LongTensor,
                scores: torch.FloatTensor,
                **kwargs: Any,
            ) -> bool:
                for stop_id in stopping_ids_list:
                    if input_ids[0][-1] == stop_id:
                        return True
                return False

        self._stopping_criteria = StoppingCriteriaList([StopOnTokens()])

        if isinstance(query_wrapper_prompt, str):
            query_wrapper_prompt = PromptTemplate(query_wrapper_prompt)

        messages_to_prompt = messages_to_prompt or self._tokenizer_messages_to_prompt

        super().__init__(
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            query_wrapper_prompt=query_wrapper_prompt,
            tokenizer_name=tokenizer_name,
            model_name=model_name,
            device_map=device_map,
            stopping_ids=stopping_ids or [],
            tokenizer_kwargs=tokenizer_kwargs or {},
            tokenizer_outputs_to_remove=tokenizer_outputs_to_remove or [],
            model_kwargs=model_kwargs or {},
            generate_kwargs=generate_kwargs or {},
            is_chat_model=is_chat_model,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )
    @classmethod
    def class_name(cls) -> str:
        return "HuggingFace_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
            is_chat_model=self.is_chat_model,
        )

    def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Use the tokenizer to convert messages to prompt. Fallback to generic."""
        if hasattr(self._tokenizer, "apply_chat_template"):
            messages_dict = [
                {"role": message.role.value, "content": message.content}
                for message in messages
            ]
            tokens = self._tokenizer.apply_chat_template(messages_dict)
            return self._tokenizer.decode(tokens)

        return generic_messages_to_prompt(messages)
    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Completion endpoint."""
        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"

        inputs = self._tokenizer(full_prompt, return_tensors="pt")
        inputs = inputs.to(self._model.device)

        # remove keys from the tokenizer if needed, to avoid HF errors
        for key in self.tokenizer_outputs_to_remove:
            if key in inputs:
                inputs.pop(key, None)

        tokens = self._model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )
        completion_tokens = tokens[0][inputs["input_ids"].size(1) :]
        completion = self._tokenizer.decode(completion_tokens, skip_special_tokens=True)

        return CompletionResponse(text=completion, raw={"model_output": tokens})
    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming completion endpoint."""
        from transformers import TextIteratorStreamer

        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"

        inputs = self._tokenizer(full_prompt, return_tensors="pt")
        inputs = inputs.to(self._model.device)

        # remove keys from the tokenizer if needed, to avoid HF errors
        for key in self.tokenizer_outputs_to_remove:
            if key in inputs:
                inputs.pop(key, None)

        streamer = TextIteratorStreamer(
            self._tokenizer,
            skip_prompt=True,
            decode_kwargs={"skip_special_tokens": True},
        )
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )

        # generate in background thread
        # NOTE/TODO: token counting doesn't work with streaming
        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()

        # create generator based off of streamer
        def gen() -> CompletionResponseGen:
            text = ""
            for x in streamer:
                text += x
                yield CompletionResponse(text=text, delta=x)

        return gen()
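    # Illustrative only (not part of the original source): the generator returned
    # by `stream_complete` yields partial responses that can be consumed
    # incrementally, e.g.
    #
    #     for partial in llm.stream_complete("Tell me a joke"):
    #         print(partial.delta, end="", flush=True)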
    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)

def chat_messages_to_conversational_kwargs(
    messages: Sequence[ChatMessage],
) -> Dict[str, Any]:
    """Convert ChatMessages to keyword arguments for Inference API conversational."""
    if len(messages) % 2 != 1:
        raise NotImplementedError("Messages passed in must be of odd length.")
    last_message = messages[-1]
    kwargs: Dict[str, Any] = {
        "text": last_message.content,
        **last_message.additional_kwargs,
    }
    if len(messages) != 1:
        kwargs["past_user_inputs"] = []
        kwargs["generated_responses"] = []
        for user_msg, assistant_msg in zip(messages[::2], messages[1::2]):
            if (
                user_msg.role != MessageRole.USER
                or assistant_msg.role != MessageRole.ASSISTANT
            ):
                raise NotImplementedError(
                    "Didn't handle when messages aren't ordered in alternating"
                    f" pairs of {(MessageRole.USER, MessageRole.ASSISTANT)}."
                )
            kwargs["past_user_inputs"].append(user_msg.content)
            kwargs["generated_responses"].append(assistant_msg.content)
    return kwargs
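# Illustrative example (not from the original source) of the mapping performed by
# `chat_messages_to_conversational_kwargs`: an alternating user/assistant history
# ending on a user turn, e.g.
#
#     [ChatMessage(role=MessageRole.USER, content="Hi"),
#      ChatMessage(role=MessageRole.ASSISTANT, content="Hello!"),
#      ChatMessage(role=MessageRole.USER, content="How are you?")]
#
# becomes
#
#     {"text": "How are you?",
#      "past_user_inputs": ["Hi"],
#      "generated_responses": ["Hello!"]}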

class HuggingFaceInferenceAPI(CustomLLM):
    """
    Wrapper on the Hugging Face's Inference API.

    Overview of the design:
    - Synchronous uses InferenceClient, asynchronous uses AsyncInferenceClient
    - chat uses the conversational task: https://huggingface.co/tasks/conversational
    - complete uses the text generation task: https://huggingface.co/tasks/text-generation

    Note: some models that support the text generation task can leverage Hugging
    Face's optimized deployment toolkit called text-generation-inference (TGI).
    Use InferenceClient.get_model_status to check if TGI is being used.

    Relevant links:
    - General Docs: https://huggingface.co/docs/api-inference/index
    - API Docs: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client
    - Source: https://github.com/huggingface/huggingface_hub/tree/main/src/huggingface_hub/inference
    """
    @classmethod
    def class_name(cls) -> str:
        return "HuggingFaceInferenceAPI"

    # Corresponds with huggingface_hub.InferenceClient
    model_name: Optional[str] = Field(
        default=None,
        description=(
            "The model to run inference with. Can be a model id hosted on the Hugging"
            " Face Hub, e.g. bigcode/starcoder or a URL to a deployed Inference"
            " Endpoint. Defaults to None, in which case a recommended model is"
            " automatically selected for the task (see Field below)."
        ),
    )
    token: Union[str, bool, None] = Field(
        default=None,
        description=(
            "Hugging Face token. Will default to the locally saved token. Pass "
            "token=False if you don't want to send your token to the server."
        ),
    )
    timeout: Optional[float] = Field(
        default=None,
        description=(
            "The maximum number of seconds to wait for a response from the server."
            " Loading a new model in Inference API can take up to several minutes."
            " Defaults to None, meaning it will loop until the server is available."
        ),
    )
    headers: Dict[str, str] = Field(
        default=None,
        description=(
            "Additional headers to send to the server. By default only the"
            " authorization and user-agent headers are sent. Values in this dictionary"
            " will override the default values."
        ),
    )
    cookies: Dict[str, str] = Field(
        default=None, description="Additional cookies to send to the server."
    )
    task: Optional[str] = Field(
        default=None,
        description=(
            "Optional task to pick Hugging Face's recommended model, used when"
            " model_name is left as default of None."
        ),
    )

    _sync_client: "InferenceClient" = PrivateAttr()
    _async_client: "AsyncInferenceClient" = PrivateAttr()
    _get_model_info: "Callable[..., ModelInfo]" = PrivateAttr()

    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=(
            LLMMetadata.__fields__["context_window"].field_info.description
            + " This may be looked up in a model's `config.json`."
        ),
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description=LLMMetadata.__fields__["num_output"].field_info.description,
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_chat_model"].field_info.description
            + " Unless chat templating is intentionally applied, Hugging Face models"
            " are not chat models."
        ),
    )
    is_function_calling_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_function_calling_model"].field_info.description
            + " As of 10/17/2023, Hugging Face doesn't support function calling"
            " messages."
        ),
    )
    def _get_inference_client_kwargs(self) -> Dict[str, Any]:
        """Extract the Hugging Face InferenceClient construction parameters."""
        return {
            "model": self.model_name,
            "token": self.token,
            "timeout": self.timeout,
            "headers": self.headers,
            "cookies": self.cookies,
        }

    def __init__(self, **kwargs: Any) -> None:
        """Initialize.

        Args:
            kwargs: See the class-level Fields.
        """
        try:
            from huggingface_hub import (
                AsyncInferenceClient,
                InferenceClient,
                model_info,
            )
        except ModuleNotFoundError as exc:
            raise ImportError(
                f"{type(self).__name__} requires huggingface_hub with its inference"
                " extra, please run `pip install huggingface_hub[inference]>=0.19.0`."
            ) from exc
        if kwargs.get("model_name") is None:
            task = kwargs.get("task", "")
            # NOTE: task being None or empty string leads to ValueError,
            # which ensures model is present
            kwargs["model_name"] = InferenceClient.get_recommended_model(task=task)
            logger.debug(
                f"Using Hugging Face's recommended model {kwargs['model_name']}"
                f" given task {task}."
            )
        if kwargs.get("task") is None:
            task = "conversational"
        else:
            task = kwargs["task"].lower()

        super().__init__(**kwargs)  # Populate pydantic Fields
        self._sync_client = InferenceClient(**self._get_inference_client_kwargs())
        self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs())
        self._get_model_info = model_info
    def validate_supported(self, task: str) -> None:
        """
        Confirm the contained model_name is deployed on the Inference API service.

        Args:
            task: Hugging Face task to check within. A list of all tasks can be
                found here: https://huggingface.co/tasks
        """
        all_models = self._sync_client.list_deployed_models(frameworks="all")
        try:
            if self.model_name not in all_models[task]:
                raise ValueError(
                    "The Inference API service doesn't have the model"
                    f" {self.model_name!r} deployed."
                )
        except KeyError as exc:
            raise KeyError(
                f"Input task {task!r} not in possible tasks {list(all_models.keys())}."
            ) from exc

    def get_model_info(self, **kwargs: Any) -> "ModelInfo":
        """Get metadata on the current model from Hugging Face."""
        return self._get_model_info(self.model_name, **kwargs)
    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            is_chat_model=self.is_chat_model,
            is_function_calling_model=self.is_function_calling_model,
            model_name=self.model_name,
        )
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        # default to conversational task as that was the previous functionality
        if self.task == "conversational" or self.task is None:
            output: "ConversationalOutput" = self._sync_client.conversational(
                **{**chat_messages_to_conversational_kwargs(messages), **kwargs}
            )
            return ChatResponse(
                message=ChatMessage(
                    role=MessageRole.ASSISTANT, content=output["generated_text"]
                )
            )
        else:
            # try and use text generation
            prompt = self.messages_to_prompt(messages)
            completion = self.complete(prompt)
            return ChatResponse(
                message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text)
            )

    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return CompletionResponse(
            text=self._sync_client.text_generation(
                prompt, **{**{"max_new_tokens": self.num_output}, **kwargs}
            )
        )
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise NotImplementedError

    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise NotImplementedError

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        raise NotImplementedError

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._async_client.text_generation(
            prompt, **{**{"max_new_tokens": self.num_output}, **kwargs}
        )
        return CompletionResponse(text=response)
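    # Illustrative async usage (not part of the original source): `acomplete`
    # wraps AsyncInferenceClient.text_generation and can be driven with asyncio,
    # e.g.
    #
    #     import asyncio
    #     response = asyncio.run(llm.acomplete("Write a haiku about the sea"))
    #     print(response.text)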
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError