llama-index
import random
from typing import (
    Any,
    Dict,
    Optional,
    Sequence,
)

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.llms.base import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    llm_chat_callback,
)
from llama_index.legacy.llms.generic_utils import (
    completion_to_chat_decorator,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.llms.nvidia_triton_utils import GrpcTritonClient

DEFAULT_SERVER_URL = "localhost:8001"
DEFAULT_MAX_RETRIES = 3
DEFAULT_TIMEOUT = 60.0
DEFAULT_MODEL = "ensemble"
DEFAULT_TEMPERATURE = 1.0
DEFAULT_TOP_P = 0
DEFAULT_TOP_K = 1.0
DEFAULT_MAX_TOKENS = 100
DEFAULT_BEAM_WIDTH = 1
DEFAULT_REPETITION_PENALTY = 1.0
DEFAULT_LENGTH_PENALTY = 1.0
DEFAULT_REUSE_CLIENT = True
DEFAULT_TRITON_LOAD_MODEL = True


class NvidiaTriton(LLM):
    """A client for LLMs hosted on an NVIDIA Triton inference server (gRPC)."""

    server_url: str = Field(
        default=DEFAULT_SERVER_URL,
        description="The URL of the Triton inference server to use.",
    )
    model_name: str = Field(
        default=DEFAULT_MODEL,
        description="The name of the Triton hosted model this client should use",
    )
    temperature: Optional[float] = Field(
        default=DEFAULT_TEMPERATURE, description="Temperature to use for sampling"
    )
    top_p: Optional[float] = Field(
        default=DEFAULT_TOP_P, description="The top-p value to use for sampling"
    )
    top_k: Optional[float] = Field(
        default=DEFAULT_TOP_K, description="The top-k value to use for sampling"
    )
    tokens: Optional[int] = Field(
        default=DEFAULT_MAX_TOKENS,
        description="The maximum number of tokens to generate.",
    )
    beam_width: Optional[int] = Field(
        default=DEFAULT_BEAM_WIDTH,
        description="The beam width used for beam search decoding",
    )
    repetition_penalty: Optional[float] = Field(
        default=DEFAULT_REPETITION_PENALTY,
        description="The penalty to apply to repeated tokens",
    )
    length_penalty: Optional[float] = Field(
        default=DEFAULT_LENGTH_PENALTY,
        description="The penalty to apply to the length of the generated output",
    )
    max_retries: Optional[int] = Field(
        default=DEFAULT_MAX_RETRIES,
        description="Maximum number of attempts to retry Triton client invocation before erroring",
    )
    timeout: Optional[float] = Field(
        default=DEFAULT_TIMEOUT,
        description="Maximum time (seconds) allowed for a Triton client call before erroring",
    )
    reuse_client: Optional[bool] = Field(
        default=DEFAULT_REUSE_CLIENT,
        description="True for reusing the same client instance between invocations",
    )
    triton_load_model_call: Optional[bool] = Field(
        default=DEFAULT_TRITON_LOAD_MODEL,
        description="True if a Triton load model API call should be made before using the client",
    )

    _client: Optional[GrpcTritonClient] = PrivateAttr()

    def __init__(
        self,
        server_url: str = DEFAULT_SERVER_URL,
        model: str = DEFAULT_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: float = DEFAULT_TOP_K,
        tokens: Optional[int] = DEFAULT_MAX_TOKENS,
        beam_width: int = DEFAULT_BEAM_WIDTH,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        length_penalty: float = DEFAULT_LENGTH_PENALTY,
        max_retries: int = DEFAULT_MAX_RETRIES,
        timeout: float = DEFAULT_TIMEOUT,
        reuse_client: bool = DEFAULT_REUSE_CLIENT,
        triton_load_model_call: bool = DEFAULT_TRITON_LOAD_MODEL,
        callback_manager: Optional[CallbackManager] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        additional_kwargs = additional_kwargs or {}

        super().__init__(
            server_url=server_url,
            # The pydantic field is `model_name`; passing `model=` here would be
            # silently dropped and the user's choice of model ignored.
            model_name=model,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            tokens=tokens,
            beam_width=beam_width,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            max_retries=max_retries,
            timeout=timeout,
            reuse_client=reuse_client,
            triton_load_model_call=triton_load_model_call,
            callback_manager=callback_manager,
            additional_kwargs=additional_kwargs,
            **kwargs,
        )

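        # Construct the gRPC client eagerly so that a missing tritonclient
        # dependency surfaces at construction time rather than on first call.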
        try:
            self._client = GrpcTritonClient(server_url)
        except ImportError as err:
            raise ImportError(
                "Could not import triton client python package. "
                "Please install it with `pip install tritonclient`."
            ) from err

    @property
    def _get_model_default_parameters(self) -> Dict[str, Any]:
        return {
            "tokens": self.tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "repetition_penalty": self.repetition_penalty,
            "length_penalty": self.length_penalty,
            "beam_width": self.beam_width,
        }

    def _invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
        # A plain method, not a property: a property getter can never
        # receive **kwargs.
        return {**self._get_model_default_parameters, **kwargs}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get all the identifying parameters."""
        return {
            "server_url": self.server_url,
            "model_name": self.model_name,
        }

    def _get_client(self) -> Any:
        """Create or reuse a Triton client connection."""
        if not self.reuse_client:
            return GrpcTritonClient(self.server_url)

        if self._client is None:
            self._client = GrpcTritonClient(self.server_url)
        return self._client

    @property
    def metadata(self) -> LLMMetadata:
        """Gather and return metadata about the user-configured Triton LLM model."""
        return LLMMetadata(
            model_name=self.model_name,
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        chat_fn = completion_to_chat_decorator(self.complete)
        return chat_fn(messages, **kwargs)

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise NotImplementedError

    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        from tritonclient.utils import InferenceServerException

        client = self._get_client()

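        # Merge the model's default sampling parameters with per-call overrides;
        # Triton expects the prompt wrapped as a nested list (a batch of inputs).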
        invocation_params = self._get_model_default_parameters
        invocation_params.update(kwargs)
        invocation_params["prompt"] = [[prompt]]
        model_params = self._identifying_params
        model_params.update(kwargs)
        request_id = str(random.randint(1, 9999999))  # nosec

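        # Optionally ask Triton to load the model before sending the request.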
        if self.triton_load_model_call:
            client.load_model(model_params["model_name"])

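        # Start a streaming inference request; tokens are delivered on a queue.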
        result_queue = client.request_streaming(
            model_params["model_name"], request_id, **invocation_params
        )

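        # Drain the queue, accumulating streamed tokens into the final text;
        # a server-side error arrives as an InferenceServerException object.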
        response = ""
        for token in result_queue:
            if isinstance(token, InferenceServerException):
                client.stop_stream(model_params["model_name"], request_id)
                raise token
            response = response + token

        return CompletionResponse(
            text=response,
        )

    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise NotImplementedError

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        raise NotImplementedError

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        raise NotImplementedError

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError
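
# Usage: a minimal sketch, assuming a Triton server hosting the default
# TensorRT-LLM "ensemble" model is reachable at localhost:8001 (both values
# below are illustrative; adjust them for your deployment):
#
#   llm = NvidiaTriton(server_url="localhost:8001", model="ensemble")
#   resp = llm.complete("What is the capital of France?")
#   print(resp.text)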