import json
from typing import Any, Callable, Dict, List, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.legacy.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.legacy.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.llms.vllm_utils import get_response, post_http_request
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode

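# ``Vllm`` runs the model in-process via ``vllm.LLM``, while ``VllmServer``
# (further down) reuses the same sampling fields but sends requests to a
# running vLLM API server over HTTP instead of loading weights locally.
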
class Vllm(LLM):
    model: Optional[str] = Field(description="The HuggingFace Model to use.")

    temperature: float = Field(description="The temperature to use for sampling.")

    tensor_parallel_size: Optional[int] = Field(
        default=1,
        description="The number of GPUs to use for distributed execution with tensor parallelism.",
    )

    trust_remote_code: Optional[bool] = Field(
        default=True,
        description="Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.",
    )

    n: int = Field(
        default=1,
        description="Number of output sequences to return for the given prompt.",
    )

    best_of: Optional[int] = Field(
        default=None,
        description="Number of output sequences that are generated from the prompt.",
    )

    presence_penalty: float = Field(
        default=0.0,
        description="Float that penalizes new tokens based on whether they appear in the generated text so far.",
    )

    frequency_penalty: float = Field(
        default=0.0,
        description="Float that penalizes new tokens based on their frequency in the generated text so far.",
    )

    top_p: float = Field(
        default=1.0,
        description="Float that controls the cumulative probability of the top tokens to consider.",
    )

    top_k: int = Field(
        default=-1,
        description="Integer that controls the number of top tokens to consider.",
    )

    use_beam_search: bool = Field(
        default=False, description="Whether to use beam search instead of sampling."
    )

    stop: Optional[List[str]] = Field(
        default=None,
        description="List of strings that stop the generation when they are generated.",
    )

    ignore_eos: bool = Field(
        default=False,
        description="Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.",
    )

    max_new_tokens: int = Field(
        default=512,
        description="Maximum number of tokens to generate per output sequence.",
    )

    logprobs: Optional[int] = Field(
        default=None,
        description="Number of log probabilities to return per output token.",
    )

    dtype: str = Field(
        default="auto",
        description="The data type for the model weights and activations.",
    )

    download_dir: Optional[str] = Field(
        default=None,
        description="Directory to download and load the weights. (Defaults to the default HuggingFace cache dir.)",
    )

    vllm_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Holds any model parameters valid for the `vllm.LLM` call not explicitly specified.",
    )

    api_url: str = Field(description="The API URL for the vLLM server.")

    _client: Any = PrivateAttr()
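    # Note: ``__init__`` only builds a local ``vllm.LLM`` client when ``model``
    # is a non-empty string; ``VllmServer`` passes ``model=""`` to skip loading
    # weights and route all calls through its HTTP endpoint instead.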
    def __init__(
        self,
        model: str = "facebook/opt-125m",
        temperature: float = 1.0,
        tensor_parallel_size: int = 1,
        trust_remote_code: bool = True,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop: Optional[List[str]] = None,
        ignore_eos: bool = False,
        max_new_tokens: int = 512,
        logprobs: Optional[int] = None,
        dtype: str = "auto",
        download_dir: Optional[str] = None,
        vllm_kwargs: Dict[str, Any] = {},
        api_url: Optional[str] = "",
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        try:
            from vllm import LLM as VLLModel
        except ImportError:
            raise ImportError(
                "Could not import vllm python package. "
                "Please install it with `pip install vllm`."
            )
        if model != "":
            self._client = VLLModel(
                model=model,
                tensor_parallel_size=tensor_parallel_size,
                trust_remote_code=trust_remote_code,
                dtype=dtype,
                download_dir=download_dir,
                **vllm_kwargs,
            )
        else:
            self._client = None
        callback_manager = callback_manager or CallbackManager([])
        super().__init__(
            model=model,
            temperature=temperature,
            n=n,
            best_of=best_of,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            top_p=top_p,
            top_k=top_k,
            use_beam_search=use_beam_search,
            stop=stop,
            ignore_eos=ignore_eos,
            max_new_tokens=max_new_tokens,
            logprobs=logprobs,
            dtype=dtype,
            download_dir=download_dir,
            vllm_kwargs=vllm_kwargs,
            api_url=api_url,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )
    @classmethod
    def class_name(cls) -> str:
        return "Vllm"
    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(model_name=self.model)
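    # The dict below mirrors fields of ``vllm.SamplingParams``; ``complete()``
    # forwards it verbatim via ``SamplingParams(**params)``, so every key here
    # must remain a valid SamplingParams argument for the installed vllm version.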
    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        base_kwargs = {
            "temperature": self.temperature,
            "max_tokens": self.max_new_tokens,
            "n": self.n,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "use_beam_search": self.use_beam_search,
            "best_of": self.best_of,
            "ignore_eos": self.ignore_eos,
            "stop": self.stop,
            "logprobs": self.logprobs,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }
        return {**base_kwargs}
    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        return {
            **self._model_kwargs,
            **kwargs,
        }
    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        kwargs = kwargs if kwargs else {}
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, **kwargs)
        return completion_response_to_chat_response(completion_response)
    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # build sampling parameters
        sampling_params = SamplingParams(**params)
        outputs = self._client.generate([prompt], sampling_params)
        return CompletionResponse(text=outputs[0].outputs[0].text)
    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise ValueError("Not Implemented")

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise ValueError("Not Implemented")

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        kwargs = kwargs if kwargs else {}
        return self.chat(messages, **kwargs)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        raise ValueError("Not Implemented")

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise ValueError("Not Implemented")

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise ValueError("Not Implemented")
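
# A minimal usage sketch for the in-process client above (assumes `vllm` is
# installed, a GPU is available, and the default facebook/opt-125m checkpoint
# can be downloaded):
#
#     llm = Vllm(model="facebook/opt-125m", temperature=0.8, max_new_tokens=64)
#     print(llm.complete("San Francisco is a").text)
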
class VllmServer(Vllm):
    def __init__(
        self,
        model: str = "facebook/opt-125m",
        api_url: str = "http://localhost:8000",
        temperature: float = 1.0,
        tensor_parallel_size: Optional[int] = 1,
        trust_remote_code: Optional[bool] = True,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop: Optional[List[str]] = None,
        ignore_eos: bool = False,
        max_new_tokens: int = 512,
        logprobs: Optional[int] = None,
        dtype: str = "auto",
        download_dir: Optional[str] = None,
        messages_to_prompt: Optional[Callable] = None,
        completion_to_prompt: Optional[Callable] = None,
        vllm_kwargs: Dict[str, Any] = {},
        callback_manager: Optional[CallbackManager] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        self._client = None
        messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
        completion_to_prompt = completion_to_prompt or (lambda x: x)
        callback_manager = callback_manager or CallbackManager([])

        model = ""
        super().__init__(
            model=model,
            temperature=temperature,
            n=n,
            best_of=best_of,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            top_p=top_p,
            top_k=top_k,
            use_beam_search=use_beam_search,
            stop=stop,
            ignore_eos=ignore_eos,
            max_new_tokens=max_new_tokens,
            logprobs=logprobs,
            dtype=dtype,
            download_dir=download_dir,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            vllm_kwargs=vllm_kwargs,
            api_url=api_url,
            callback_manager=callback_manager,
            output_parser=output_parser,
        )
    @classmethod
    def class_name(cls) -> str:
        return "VllmServer"
    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # build sampling parameters
        sampling_params = SamplingParams(**params).__dict__
        sampling_params["prompt"] = prompt
        response = post_http_request(self.api_url, sampling_params, stream=False)
        output = get_response(response)

        return CompletionResponse(text=output[0])
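    # Both complete() and stream_complete() target the legacy
    # ``vllm.entrypoints.api_server`` /generate wire format (as assumed by
    # ``post_http_request``/``get_response`` above): the request body carries
    # the SamplingParams fields plus a "prompt" key, and the response JSON
    # holds the generated sequences under "text"; the streaming variant emits
    # NUL-delimited JSON chunks, which gen() below splits on b"\0".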
    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # build sampling parameters
        sampling_params = SamplingParams(**params).__dict__
        sampling_params["prompt"] = prompt
        response = post_http_request(self.api_url, sampling_params, stream=True)

        def gen() -> CompletionResponseGen:
            for chunk in response.iter_lines(
                chunk_size=8192, decode_unicode=False, delimiter=b"\0"
            ):
                if chunk:
                    data = json.loads(chunk.decode("utf-8"))

                    yield CompletionResponse(text=data["text"][0])

        return gen()
    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        return self.complete(prompt, **kwargs)
    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        kwargs = kwargs if kwargs else {}

        async def gen() -> CompletionResponseAsyncGen:
            for message in self.stream_complete(prompt, **kwargs):
                yield message

        return gen()
    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)
    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        async def gen() -> ChatResponseAsyncGen:
            for message in self.stream_chat(messages, **kwargs):
                yield message

        return gen()
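
# A minimal usage sketch for the HTTP client above (assumes a vLLM API server
# was started separately, e.g. ``python -m vllm.entrypoints.api_server --model
# facebook/opt-125m --port 8000``, and that ``api_url`` must point at its
# /generate endpoint if ``post_http_request`` does not append the path itself):
#
#     llm = VllmServer(api_url="http://localhost:8000/generate", max_new_tokens=64)
#     print(llm.complete("San Francisco is a").text)
#
#     # Streaming: each chunk carries the text generated so far.
#     for r in llm.stream_complete("San Francisco is a"):
#         print(r.text)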