import asyncio
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
)

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.generic_utils import (
    completion_response_to_chat_response,
)
from llama_index.legacy.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.types import PydanticProgramMode

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from typing import TypeVar

    M = TypeVar("M")
    T = TypeVar("T")
    Metadata = Any


class OpenLLM(LLM):
    """OpenLLM LLM.
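
    A minimal usage sketch (the model ID and prompt below are illustrative,
    not prescribed by this module):

        llm = OpenLLM("HuggingFaceH4/zephyr-7b-beta")
        print(llm.complete("What is OpenLLM?").text)
    """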

    model_id: str = Field(
        description="Given Model ID from the HuggingFace Hub. This can be either a pretrained ID or a local path. This is synonymous with HuggingFace's '.from_pretrained' first argument."
    )
    model_version: Optional[str] = Field(
        description="Optional model version to save the model as."
    )
    model_tag: Optional[str] = Field(
        description="Optional tag to save to the BentoML store."
    )
    prompt_template: Optional[str] = Field(
        description="Optional prompt template to pass for this LLM."
    )
    backend: Optional[Literal["vllm", "pt"]] = Field(
        description="Optional backend to pass for this LLM. By default, vLLM is used if it is available on the local system; otherwise it falls back to PyTorch."
    )
    quantize: Optional[Literal["awq", "gptq", "int8", "int4", "squeezellm"]] = Field(
        description="Optional quantization method to use with this LLM. See OpenLLM's --quantize options from `openllm start` for more information."
    )
    serialization: Literal["safetensors", "legacy"] = Field(
        description="Optional serialization method for this LLM to be saved as. Defaults to 'safetensors', but falls back to PyTorch pickle `.bin` on some models."
    )
    trust_remote_code: bool = Field(
        description="Optional flag to trust remote code. This is synonymous with Transformers' `trust_remote_code`. Defaults to False."
    )

    if TYPE_CHECKING:
        from typing import Generic

        try:
            import openllm

            _llm: openllm.LLM[Any, Any]
        except ImportError:
            _llm: Any  # type: ignore[no-redef]
    else:
        _llm: Any = PrivateAttr()

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        model_tag: Optional[str] = None,
        prompt_template: Optional[str] = None,
        backend: Optional[Literal["vllm", "pt"]] = None,
        *args: Any,
        quantize: Optional[Literal["awq", "gptq", "int8", "int4", "squeezellm"]] = None,
        serialization: Literal["safetensors", "legacy"] = "safetensors",
        trust_remote_code: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        **attrs: Any,
    ):
        try:
            import openllm
        except ImportError:
            raise ImportError(
                "OpenLLM is not installed. Please install OpenLLM via `pip install openllm`"
            )
        # `embedded=True` runs the model in-process rather than against a
        # separate OpenLLM server. Note that openllm spells its keyword
        # argument `serialisation`.
        self._llm = openllm.LLM[Any, Any](
            model_id,
            model_version=model_version,
            model_tag=model_tag,
            prompt_template=prompt_template,
            system_message=system_prompt,
            backend=backend,
            quantize=quantize,
            serialisation=serialization,
            trust_remote_code=trust_remote_code,
            embedded=True,
            **attrs,
        )
        if messages_to_prompt is None:
            messages_to_prompt = self._tokenizer_messages_to_prompt

        # NOTE: We need to do this here to ensure the model is saved and the
        # revision is set correctly.
        assert self._llm.bentomodel

        super().__init__(
            model_id=model_id,
            model_version=self._llm.revision,
            model_tag=str(self._llm.tag),
            prompt_template=prompt_template,
            backend=self._llm.__llm_backend__,
            quantize=self._llm.quantise,
            serialization=self._llm._serialisation,
            trust_remote_code=self._llm.trust_remote_code,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
        )

    @classmethod
    def class_name(cls) -> str:
        return "OpenLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            num_output=self._llm.config["max_new_tokens"],
            model_name=self.model_id,
        )

    def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Use the tokenizer to convert messages to a prompt. Falls back to the generic converter."""
        if hasattr(self._llm.tokenizer, "apply_chat_template"):
            return self._llm.tokenizer.apply_chat_template(
                [message.dict() for message in messages],
                tokenize=False,
                add_generation_prompt=True,
            )
        return generic_messages_to_prompt(messages)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        # Synchronous wrapper over the async implementation.
        return asyncio.run(self.acomplete(prompt, **kwargs))

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        return asyncio.run(self.achat(messages, **kwargs))

    @property
    def _loop(self) -> asyncio.AbstractEventLoop:
        # Prefer the currently running loop if one exists; otherwise fall
        # back to the thread's default event loop.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.get_event_loop()
        return loop

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        generator = self.astream_complete(prompt, **kwargs)
        # Drain the async generator synchronously, one item at a time.
        while True:
            try:
                yield self._loop.run_until_complete(generator.__anext__())
            except StopAsyncIteration:
                break

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        generator = self.astream_chat(messages, **kwargs)
        # Drain the async generator synchronously, one item at a time.
        while True:
            try:
                yield self._loop.run_until_complete(generator.__anext__())
            except StopAsyncIteration:
                break

    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        response = await self.acomplete(self.messages_to_prompt(messages), **kwargs)
        return completion_response_to_chat_response(response)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._llm.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        async for response_chunk in self.astream_complete(
            self.messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        config = self._llm.config.model_construct_env(**kwargs)
        if config["n"] > 1:
            logger.warning("Currently only n=1 is supported.")

        # One text buffer per requested sequence. Use a comprehension rather
        # than `[[]] * n`, which would alias a single list n times.
        texts: List[List[str]] = [[] for _ in range(config["n"])]

        async for response_chunk in self._llm.generate_iterator(prompt, **kwargs):
            for output in response_chunk.outputs:
                texts[output.index].append(output.text)
            yield CompletionResponse(
                text=response_chunk.outputs[0].text,
                delta=response_chunk.outputs[0].text,
                raw=response_chunk.model_dump(),
                additional_kwargs={
                    "prompt_token_ids": response_chunk.prompt_token_ids,
                    "prompt_logprobs": response_chunk.prompt_logprobs,
                    "finished": response_chunk.finished,
                    "outputs": {
                        "text": response_chunk.outputs[0].text,
                        "token_ids": response_chunk.outputs[0].token_ids,
                        "cumulative_logprob": response_chunk.outputs[
                            0
                        ].cumulative_logprob,
                        "logprobs": response_chunk.outputs[0].logprobs,
                        "finish_reason": response_chunk.outputs[0].finish_reason,
                    },
                },
            )
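
# A streaming usage sketch for the embedded `OpenLLM` class above (the model
# ID and prompt are illustrative):
#
#   llm = OpenLLM("HuggingFaceH4/zephyr-7b-beta")
#   for chunk in llm.stream_complete("Tell me a joke."):
#       print(chunk.delta, end="", flush=True)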


class OpenLLMAPI(LLM):
    """OpenLLM Client interface. This is useful when interacting with a remote OpenLLM server.
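
    A minimal usage sketch (the server address below is illustrative; it can
    also be supplied via the OPENLLM_ENDPOINT environment variable):

        client = OpenLLMAPI(address="http://localhost:3000")
        print(client.complete("What is OpenLLM?").text)
    """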

    address: Optional[str] = Field(
        description="OpenLLM server address. This can either be set here or via the OPENLLM_ENDPOINT environment variable."
    )
    timeout: int = Field(description="Timeout for sending requests.")
    max_retries: int = Field(description="Maximum number of retries.")
    api_version: Literal["v1"] = Field(description="OpenLLM Server API version.")

    if TYPE_CHECKING:
        try:
            from openllm_client import AsyncHTTPClient, HTTPClient

            _sync_client: HTTPClient
            _async_client: AsyncHTTPClient
        except ImportError:
            _sync_client: Any  # type: ignore[no-redef]
            _async_client: Any  # type: ignore[no-redef]
    else:
        _sync_client: Any = PrivateAttr()
        _async_client: Any = PrivateAttr()

    def __init__(
        self,
        address: Optional[str] = None,
        timeout: int = 30,
        max_retries: int = 2,
        api_version: Literal["v1"] = "v1",
        **kwargs: Any,
    ):
        try:
            from openllm_client import AsyncHTTPClient, HTTPClient
        except ImportError:
            raise ImportError(
                f'"{type(self).__name__}" requires "openllm-client". Make sure to install with `pip install openllm-client`'
            )
        super().__init__(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
            **kwargs,
        )
        self._sync_client = HTTPClient(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
        )
        self._async_client = AsyncHTTPClient(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
        )

    @classmethod
    def class_name(cls) -> str:
        return "OpenLLM_Client"

    @property
    def _server_metadata(self) -> "Metadata":
        return self._sync_client._metadata

    @property
    def _server_config(self) -> Dict[str, Any]:
        return self._sync_client._config

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            num_output=self._server_config["max_new_tokens"],
            model_name=self._server_metadata.model_id.replace("/", "--"),
        )

    def _convert_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return self._sync_client.helpers.messages(
            messages=[
                {"role": message.role, "content": message.content}
                for message in messages
            ],
            add_generation_prompt=True,
        )

    async def _async_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return await self._async_client.helpers.messages(
            messages=[
                {"role": message.role, "content": message.content}
                for message in messages
            ],
            add_generation_prompt=True,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = self._sync_client.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        for response_chunk in self._sync_client.generate_stream(prompt, **kwargs):
            yield CompletionResponse(
                text=response_chunk.text,
                delta=response_chunk.text,
                raw=response_chunk.model_dump(),
                additional_kwargs={"token_ids": response_chunk.token_ids},
            )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        return completion_response_to_chat_response(
            self.complete(self._convert_messages_to_prompt(messages), **kwargs)
        )

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        for response_chunk in self.stream_complete(
            self._convert_messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._async_client.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        async for response_chunk in self._async_client.generate_stream(
            prompt, **kwargs
        ):
            yield CompletionResponse(
                text=response_chunk.text,
                delta=response_chunk.text,
                raw=response_chunk.model_dump(),
                additional_kwargs={"token_ids": response_chunk.token_ids},
            )

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        return completion_response_to_chat_response(
            await self.acomplete(
                await self._async_messages_to_prompt(messages), **kwargs
            )
        )

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        async for response_chunk in self.astream_complete(
            await self._async_messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)
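
# A remote-streaming sketch against `OpenLLMAPI` (the address is illustrative;
# it can also come from the OPENLLM_ENDPOINT environment variable):
#
#   client = OpenLLMAPI(address="http://localhost:3000")
#   for chunk in client.stream_complete("Tell me a joke."):
#       print(chunk.delta, end="", flush=True)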