llama-index
import abc
import json
import random
import time
from functools import partial
from queue import Queue
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Type,
    Union,
)

import numpy as np

if TYPE_CHECKING:
    import tritonclient.grpc as grpcclient
    import tritonclient.http as httpclient

STOP_WORDS = ["</s>"]
RANDOM_SEED = 0
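
# Note: "</s>" is the end-of-sequence token emitted by Llama-style models;
# seeing a stop word (or an explicit None sentinel) on the result queue ends a
# stream. RANDOM_SEED is sent with every request so sampling is reproducible.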


class StreamingResponseGenerator(Queue):
    """A generator that provides the inference results from an LLM."""

    def __init__(
        self, client: "GrpcTritonClient", request_id: str, force_batch: bool
    ) -> None:
        """Instantiate the generator class."""
        super().__init__()
        self._client = client
        self.request_id = request_id
        self._batch = force_batch

    def __iter__(self) -> "StreamingResponseGenerator":
        """Return self as a generator."""
        return self

    def __next__(self) -> str:
        """Return the next retrieved token."""
        val = self.get()
        if val is None or val in STOP_WORDS:
            self._stop_stream()
            raise StopIteration
        return val

    def _stop_stream(self) -> None:
        """Drain and shut down the Triton stream."""
        self._client.stop_stream(
            "tensorrt_llm", self.request_id, signal=not self._batch
        )
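
# StreamingResponseGenerator is the consumer half of a producer/consumer pair:
# GrpcTritonClient._stream_callback, invoked on the gRPC client's stream
# thread, puts each decoded chunk on the queue; a None sentinel or a stop word
# ends the iteration.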


class _BaseTritonClient(abc.ABC):
    """An abstraction of the connection to a Triton inference server."""

    def __init__(self, server_url: str) -> None:
        """Initialize the client."""
        self._server_url = server_url
        self._client = self._inference_server_client(server_url)

    @property
    @abc.abstractmethod
    def _inference_server_client(
        self,
    ) -> Union[
        Type["grpcclient.InferenceServerClient"],
        Type["httpclient.InferenceServerClient"],
    ]:
        """Return the preferred InferenceServerClient class."""

    @property
    @abc.abstractmethod
    def _infer_input(
        self,
    ) -> Union[Type["grpcclient.InferInput"], Type["httpclient.InferInput"]]:
        """Return the preferred InferInput."""

    @property
    @abc.abstractmethod
    def _infer_output(
        self,
    ) -> Union[
        Type["grpcclient.InferRequestedOutput"],
        Type["httpclient.InferRequestedOutput"],
    ]:
        """Return the preferred InferRequestedOutput."""

    def load_model(self, model_name: str, timeout: int = 1000) -> None:
        """Load a model into the server, waiting up to ``timeout`` seconds."""
        if self._client.is_model_ready(model_name):
            return

        self._client.load_model(model_name)
        t0 = time.perf_counter()
        t1 = t0
        # Poll until the model reports ready or the timeout expires.
        while not self._client.is_model_ready(model_name) and t1 - t0 < timeout:
            t1 = time.perf_counter()

        if not self._client.is_model_ready(model_name):
            raise RuntimeError(f"Failed to load {model_name} on Triton in {timeout}s")

    def get_model_list(self) -> List[str]:
        """Get a list of models loaded in the Triton server."""
        res = self._client.get_model_repository_index(as_json=True)
        return [model["name"] for model in res["models"]]

    def get_model_concurrency(self, model_name: str, timeout: int = 1000) -> int:
        """Get the model concurrency (instance count times GPUs, summed over groups)."""
        self.load_model(model_name, timeout)
        instances = self._client.get_model_config(model_name, as_json=True)["config"][
            "instance_group"
        ]
        # e.g. one group with count=2 on gpus=[0, 1] yields a concurrency of 4
        return sum(instance["count"] * len(instance["gpus"]) for instance in instances)

    def _generate_stop_signals(
        self,
    ) -> List[Union["grpcclient.InferInput", "httpclient.InferInput"]]:
        """Generate the signal to stop the stream."""
        # A stop request carries empty input ids, a zero requested output
        # length, and the "stop" flag set to True.
        inputs = [
            self._infer_input("input_ids", [1, 1], "INT32"),
            self._infer_input("input_lengths", [1, 1], "INT32"),
            self._infer_input("request_output_len", [1, 1], "UINT32"),
            self._infer_input("stop", [1, 1], "BOOL"),
        ]
        inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32))
        inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32))
        inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32))
        inputs[3].set_data_from_numpy(np.array([[True]], dtype="bool"))
        return inputs

    def _generate_outputs(
        self,
    ) -> List[
        Union["grpcclient.InferRequestedOutput", "httpclient.InferRequestedOutput"]
    ]:
        """Generate the expected output structure."""
        return [self._infer_output("text_output")]

    def _prepare_tensor(
        self, name: str, input_data: Any
    ) -> Union["grpcclient.InferInput", "httpclient.InferInput"]:
        """Prepare an input data structure."""
        from tritonclient.utils import np_to_triton_dtype

        t = self._infer_input(
            name, input_data.shape, np_to_triton_dtype(input_data.dtype)
        )
        t.set_data_from_numpy(input_data)
        return t

    def _generate_inputs(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        prompt: str,
        tokens: int = 300,
        temperature: float = 1.0,
        top_k: int = 1,
        top_p: float = 0.0,
        beam_width: int = 1,
        repetition_penalty: float = 1.0,
        length_penalty: float = 1.0,
        stream: bool = True,
    ) -> List[Union["grpcclient.InferInput", "httpclient.InferInput"]]:
        """Create the input for the Triton inference server."""
        # Each scalar parameter is shipped as a (1, 1) tensor: one batch entry,
        # one value.
        query = np.array(prompt).astype(object)
        request_output_len = np.array([tokens]).astype(np.uint32).reshape((1, -1))
        runtime_top_k = np.array([top_k]).astype(np.uint32).reshape((1, -1))
        runtime_top_p = np.array([top_p]).astype(np.float32).reshape((1, -1))
        temperature_array = np.array([temperature]).astype(np.float32).reshape((1, -1))
        len_penalty = np.array([length_penalty]).astype(np.float32).reshape((1, -1))
        repetition_penalty_array = (
            np.array([repetition_penalty]).astype(np.float32).reshape((1, -1))
        )
        random_seed = np.array([RANDOM_SEED]).astype(np.uint64).reshape((1, -1))
        beam_width_array = np.array([beam_width]).astype(np.uint32).reshape((1, -1))
        streaming_data = np.array([[stream]], dtype=bool)

        return [
            self._prepare_tensor("text_input", query),
            self._prepare_tensor("max_tokens", request_output_len),
            self._prepare_tensor("top_k", runtime_top_k),
            self._prepare_tensor("top_p", runtime_top_p),
            self._prepare_tensor("temperature", temperature_array),
            self._prepare_tensor("length_penalty", len_penalty),
            self._prepare_tensor("repetition_penalty", repetition_penalty_array),
            self._prepare_tensor("random_seed", random_seed),
            self._prepare_tensor("beam_width", beam_width_array),
            self._prepare_tensor("stream", streaming_data),
        ]

    def _trim_batch_response(self, result_str: str) -> str:
        """Trim a batch response by removing the prompt and extra generated text."""
        # extract the generated part of the prompt
        split = result_str.split("[/INST]", 1)
        generated = split[-1]
        end_token = generated.find("</s>")
        if end_token == -1:
            return generated
        return generated[:end_token].strip()
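
# Note: _trim_batch_response assumes the Llama-2 chat template. For example,
# given "[INST] What is 2+2? [/INST] 4</s>" it returns "4"; a response without
# "[/INST]" passes through with only the text after "</s>" removed.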


class GrpcTritonClient(_BaseTritonClient):
    """gRPC connection to a Triton inference server."""

    @property
    def _inference_server_client(
        self,
    ) -> Type["grpcclient.InferenceServerClient"]:
        """Return the preferred InferenceServerClient class."""
        import tritonclient.grpc as grpcclient

        return grpcclient.InferenceServerClient  # type: ignore

    @property
    def _infer_input(self) -> Type["grpcclient.InferInput"]:
        """Return the preferred InferInput."""
        import tritonclient.grpc as grpcclient

        return grpcclient.InferInput  # type: ignore

    @property
    def _infer_output(
        self,
    ) -> Type["grpcclient.InferRequestedOutput"]:
        """Return the preferred InferRequestedOutput."""
        import tritonclient.grpc as grpcclient

        return grpcclient.InferRequestedOutput  # type: ignore

    def _send_stop_signals(self, model_name: str, request_id: str) -> None:
        """Send the stop signal to the Triton inference server."""
        stop_inputs = self._generate_stop_signals()
        self._client.async_stream_infer(
            model_name,
            stop_inputs,
            request_id=request_id,
            parameters={"Streaming": True},
        )

    @staticmethod
    def _process_result(result: Dict[str, str]) -> str:
        """Post-process the result from the server."""
        import google.protobuf.json_format
        import tritonclient.grpc as grpcclient
        from tritonclient.grpc.service_pb2 import ModelInferResponse

        # Rebuild the protobuf response from its JSON form, then decode the
        # "text_output" tensor into a string.
        message = ModelInferResponse()
        google.protobuf.json_format.Parse(json.dumps(result), message)
        infer_result = grpcclient.InferResult(message)
        np_res = infer_result.as_numpy("text_output")

        generated_text = ""
        if np_res is not None:
            generated_text = "".join([token.decode() for token in np_res])

        return generated_text

    def _stream_callback(
        self,
        result_queue: Queue,
        force_batch: bool,
        result: Any,
        error: str,
    ) -> None:
        """Add a streamed result to the queue."""
        if error:
            result_queue.put(error)
        else:
            response_raw = result.get_response(as_json=True)
            if "outputs" in response_raw:
                # the very last response might have no output, just the final flag
                response = self._process_result(response_raw)
                if force_batch:
                    response = self._trim_batch_response(response)

                if response in STOP_WORDS:
                    result_queue.put(None)
                else:
                    result_queue.put(response)

            if response_raw["parameters"]["triton_final_response"]["bool_param"]:
                # end of the generation
                result_queue.put(None)

    # pylint: disable-next=too-many-arguments
    def _send_prompt_streaming(
        self,
        model_name: str,
        request_inputs: Any,
        request_outputs: Optional[Any],
        request_id: str,
        result_queue: StreamingResponseGenerator,
        force_batch: bool = False,
    ) -> None:
        """Send the prompt and start streaming the result."""
        self._client.start_stream(
            callback=partial(self._stream_callback, result_queue, force_batch)
        )
        self._client.async_stream_infer(
            model_name=model_name,
            inputs=request_inputs,
            outputs=request_outputs,
            request_id=request_id,
        )

    def request_streaming(
        self,
        model_name: str,
        request_id: Optional[str] = None,
        force_batch: bool = False,
        **params: Any,
    ) -> StreamingResponseGenerator:
        """Request a streaming connection."""
        if not self._client.is_model_ready(model_name):
            raise RuntimeError("Cannot request streaming, model is not loaded")

        if not request_id:
            request_id = str(random.randint(1, 9999999))  # nosec

        result_queue = StreamingResponseGenerator(self, request_id, force_batch)
        inputs = self._generate_inputs(stream=not force_batch, **params)
        outputs = self._generate_outputs()
        self._send_prompt_streaming(
            model_name,
            inputs,
            outputs,
            request_id,
            result_queue,
            force_batch,
        )
        return result_queue

    def stop_stream(
        self, model_name: str, request_id: str, signal: bool = True
    ) -> None:
        """Close the streaming connection."""
        if signal:
            self._send_stop_signals(model_name, request_id)
        self._client.stop_stream()
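

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. It assumes a Triton
    # server with the TensorRT-LLM backend is reachable at localhost:8001,
    # runs in explicit model-control mode, and serves a model named
    # "tensorrt_llm"; the URL, model name, and prompt are all illustrative.
    client = GrpcTritonClient("localhost:8001")
    client.load_model("tensorrt_llm")
    for chunk in client.request_streaming("tensorrt_llm", prompt="What is Triton?"):
        print(chunk, end="", flush=True)
    print()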