llama-index

import abc
import json
import random
import time
from functools import partial
from queue import Queue
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Type,
    Union,
)

import numpy as np

if TYPE_CHECKING:
    import tritonclient.grpc as grpcclient
    import tritonclient.http as httpclient

STOP_WORDS = ["</s>"]
RANDOM_SEED = 0


class StreamingResponseGenerator(Queue):
    """A Generator that provides the inference results from an LLM."""

    def __init__(
        self, client: "GrpcTritonClient", request_id: str, force_batch: bool
    ) -> None:
        """Instantiate the generator class."""
        super().__init__()
        self._client = client
        self.request_id = request_id
        self._batch = force_batch

    def __iter__(self) -> "StreamingResponseGenerator":
        """Return self as a generator."""
        return self

    def __next__(self) -> str:
        """Return the next retrieved token."""
        val = self.get()
        if val is None or val in STOP_WORDS:
            self._stop_stream()
            raise StopIteration
        return val

    def _stop_stream(self) -> None:
        """Drain and shutdown the Triton stream."""
        self._client.stop_stream(
            "tensorrt_llm", self.request_id, signal=not self._batch
        )
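    # Note on data flow: GrpcTritonClient._stream_callback (defined below) puts
    # each decoded chunk onto this queue from Triton's gRPC callback thread,
    # while __next__ blocks on Queue.get() and ends iteration (closing the
    # stream) once it sees None or one of STOP_WORDS.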


class _BaseTritonClient(abc.ABC):
    """An abstraction of the connection to a triton inference server."""

    def __init__(self, server_url: str) -> None:
        """Initialize the client."""
        self._server_url = server_url
        self._client = self._inference_server_client(server_url)

    @property
    @abc.abstractmethod
    def _inference_server_client(
        self,
    ) -> Union[
        Type["grpcclient.InferenceServerClient"],
        Type["httpclient.InferenceServerClient"],
    ]:
        """Return the preferred InferenceServerClient class."""

    @property
    @abc.abstractmethod
    def _infer_input(
        self,
    ) -> Union[Type["grpcclient.InferInput"], Type["httpclient.InferInput"]]:
        """Return the preferred InferInput."""

    @property
    @abc.abstractmethod
    def _infer_output(
        self,
    ) -> Union[
        Type["grpcclient.InferRequestedOutput"], Type["httpclient.InferRequestedOutput"]
    ]:
        """Return the preferred InferRequestedOutput."""

    def load_model(self, model_name: str, timeout: int = 1000) -> None:
        """Load a model into the server."""
        if self._client.is_model_ready(model_name):
            return

        self._client.load_model(model_name)
        t0 = time.perf_counter()
        t1 = t0
        while not self._client.is_model_ready(model_name) and t1 - t0 < timeout:
            t1 = time.perf_counter()

        if not self._client.is_model_ready(model_name):
            raise RuntimeError(f"Failed to load {model_name} on Triton in {timeout}s")

    def get_model_list(self) -> List[str]:
        """Get a list of models loaded in the triton server."""
        res = self._client.get_model_repository_index(as_json=True)
        return [model["name"] for model in res["models"]]

    def get_model_concurrency(self, model_name: str, timeout: int = 1000) -> int:
        """Get the model concurrency."""
        self.load_model(model_name, timeout)
        instances = self._client.get_model_config(model_name, as_json=True)["config"][
            "instance_group"
        ]
        return sum(instance["count"] * len(instance["gpus"]) for instance in instances)
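    # Illustrative arithmetic: an instance_group of [{"count": 2, "gpus": [0, 1]}]
    # gives a concurrency of 2 * 2 = 4, i.e. two instances replicated on each of
    # two GPUs.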

    def _generate_stop_signals(
        self,
    ) -> List[Union["grpcclient.InferInput", "httpclient.InferInput"]]:
        """Generate the signal to stop the stream."""
        inputs = [
            self._infer_input("input_ids", [1, 1], "INT32"),
            self._infer_input("input_lengths", [1, 1], "INT32"),
            self._infer_input("request_output_len", [1, 1], "UINT32"),
            self._infer_input("stop", [1, 1], "BOOL"),
        ]
        inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32))
        inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32))
        inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32))
        inputs[3].set_data_from_numpy(np.array([[True]], dtype="bool"))
        return inputs

    def _generate_outputs(
        self,
    ) -> List[
        Union["grpcclient.InferRequestedOutput", "httpclient.InferRequestedOutput"]
    ]:
        """Generate the expected output structure."""
        return [self._infer_output("text_output")]

    def _prepare_tensor(
        self, name: str, input_data: Any
    ) -> Union["grpcclient.InferInput", "httpclient.InferInput"]:
        """Prepare an input data structure."""
        from tritonclient.utils import np_to_triton_dtype

        t = self._infer_input(
            name, input_data.shape, np_to_triton_dtype(input_data.dtype)
        )
        t.set_data_from_numpy(input_data)
        return t

    def _generate_inputs(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        prompt: str,
        tokens: int = 300,
        temperature: float = 1.0,
        top_k: float = 1,
        top_p: float = 0,
        beam_width: int = 1,
        repetition_penalty: float = 1,
        length_penalty: float = 1.0,
        stream: bool = True,
    ) -> List[Union["grpcclient.InferInput", "httpclient.InferInput"]]:
        """Create the input for the triton inference server."""
        query = np.array(prompt).astype(object)
        request_output_len = np.array([tokens]).astype(np.uint32).reshape((1, -1))
        runtime_top_k = np.array([top_k]).astype(np.uint32).reshape((1, -1))
        runtime_top_p = np.array([top_p]).astype(np.float32).reshape((1, -1))
        temperature_array = np.array([temperature]).astype(np.float32).reshape((1, -1))
        len_penalty = np.array([length_penalty]).astype(np.float32).reshape((1, -1))
        repetition_penalty_array = (
            np.array([repetition_penalty]).astype(np.float32).reshape((1, -1))
        )
        random_seed = np.array([RANDOM_SEED]).astype(np.uint64).reshape((1, -1))
        beam_width_array = np.array([beam_width]).astype(np.uint32).reshape((1, -1))
        streaming_data = np.array([[stream]], dtype=bool)

        return [
            self._prepare_tensor("text_input", query),
            self._prepare_tensor("max_tokens", request_output_len),
            self._prepare_tensor("top_k", runtime_top_k),
            self._prepare_tensor("top_p", runtime_top_p),
            self._prepare_tensor("temperature", temperature_array),
            self._prepare_tensor("length_penalty", len_penalty),
            self._prepare_tensor("repetition_penalty", repetition_penalty_array),
            self._prepare_tensor("random_seed", random_seed),
            self._prepare_tensor("beam_width", beam_width_array),
            self._prepare_tensor("stream", streaming_data),
        ]
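    # Shape example with the defaults above: "max_tokens" becomes a (1, 1)
    # uint32 tensor [[300]], "temperature" a (1, 1) float32 tensor [[1.0]], and
    # "text_input" a zero-dimensional object array holding the prompt string.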

    def _trim_batch_response(self, result_str: str) -> str:
        """Trim a batch response by removing the provided prompt and any text generated after the end token."""
        # keep only the generated text that follows the prompt
        split = result_str.split("[/INST]", 1)
        generated = split[-1]
        end_token = generated.find("</s>")
        if end_token == -1:
            return generated
        return generated[:end_token].strip()
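    # Illustrative example (hypothetical strings): given a Llama-2 style result
    # "<s>[INST] Hi [/INST] Hello there</s> extra text", this returns
    # "Hello there".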


class GrpcTritonClient(_BaseTritonClient):
    """GRPC connection to a triton inference server."""

    @property
    def _inference_server_client(
        self,
    ) -> Type["grpcclient.InferenceServerClient"]:
        """Return the preferred InferenceServerClient class."""
        import tritonclient.grpc as grpcclient

        return grpcclient.InferenceServerClient  # type: ignore

    @property
    def _infer_input(self) -> Type["grpcclient.InferInput"]:
        """Return the preferred InferInput."""
        import tritonclient.grpc as grpcclient

        return grpcclient.InferInput  # type: ignore

    @property
    def _infer_output(
        self,
    ) -> Type["grpcclient.InferRequestedOutput"]:
        """Return the preferred InferRequestedOutput."""
        import tritonclient.grpc as grpcclient

        return grpcclient.InferRequestedOutput  # type: ignore

    def _send_stop_signals(self, model_name: str, request_id: str) -> None:
        """Send the stop signal to the Triton Inference server."""
        stop_inputs = self._generate_stop_signals()
        self._client.async_stream_infer(
            model_name,
            stop_inputs,
            request_id=request_id,
            parameters={"Streaming": True},
        )

    @staticmethod
    def _process_result(result: Dict[str, str]) -> str:
        """Post-process the result from the server."""
        import google.protobuf.json_format
        import tritonclient.grpc as grpcclient
        from tritonclient.grpc.service_pb2 import ModelInferResponse

        message = ModelInferResponse()
        generated_text: str = ""
        google.protobuf.json_format.Parse(json.dumps(result), message)
        infer_result = grpcclient.InferResult(message)
        np_res = infer_result.as_numpy("text_output")

        if np_res is not None:
            generated_text = "".join([token.decode() for token in np_res])

        return generated_text
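    # For illustration: if "text_output" decodes to [b"Hello", b" world"], the
    # joined result is "Hello world" (the byte values here are hypothetical).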

    def _stream_callback(
        self,
        result_queue: Queue,
        force_batch: bool,
        result: Any,
        error: str,
    ) -> None:
        """Add streamed result to queue."""
        if error:
            result_queue.put(error)
        else:
            response_raw = result.get_response(as_json=True)
            if "outputs" in response_raw:
                # the very last response might have no output, just the final flag
                response = self._process_result(response_raw)
                if force_batch:
                    response = self._trim_batch_response(response)

                if response in STOP_WORDS:
                    result_queue.put(None)
                else:
                    result_queue.put(response)

            if response_raw["parameters"]["triton_final_response"]["bool_param"]:
                # end of the generation
                result_queue.put(None)

    # pylint: disable-next=too-many-arguments
    def _send_prompt_streaming(
        self,
        model_name: str,
        request_inputs: Any,
        request_outputs: Optional[Any],
        request_id: str,
        result_queue: StreamingResponseGenerator,
        force_batch: bool = False,
    ) -> None:
        """Send the prompt and start streaming the result."""
        self._client.start_stream(
            callback=partial(self._stream_callback, result_queue, force_batch)
        )
        self._client.async_stream_infer(
            model_name=model_name,
            inputs=request_inputs,
            outputs=request_outputs,
            request_id=request_id,
        )

    def request_streaming(
        self,
        model_name: str,
        request_id: Optional[str] = None,
        force_batch: bool = False,
        **params: Any,
    ) -> StreamingResponseGenerator:
        """Request a streaming connection."""
        if not self._client.is_model_ready(model_name):
            raise RuntimeError("Cannot request streaming, model is not loaded")

        if not request_id:
            request_id = str(random.randint(1, 9999999))  # nosec

        result_queue = StreamingResponseGenerator(self, request_id, force_batch)
        inputs = self._generate_inputs(stream=not force_batch, **params)
        outputs = self._generate_outputs()
        self._send_prompt_streaming(
            model_name,
            inputs,
            outputs,
            request_id,
            result_queue,
            force_batch,
        )
        return result_queue

    def stop_stream(
        self, model_name: str, request_id: str, signal: bool = True
    ) -> None:
        """Close the streaming connection."""
        if signal:
            self._send_stop_signals(model_name, request_id)
        self._client.stop_stream()
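

if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming a Triton server reachable on the
    # default gRPC port (8001) that serves a TensorRT-LLM deployment. The URL
    # and the "ensemble" model name below are placeholders, not values defined
    # by this module.
    client = GrpcTritonClient("localhost:8001")
    print("Models in repository:", client.get_model_list())

    model_name = "ensemble"  # hypothetical model name for your deployment
    # load_model returns immediately if the model is already ready; otherwise
    # it requires the server to allow explicit model control.
    client.load_model(model_name)

    generator = client.request_streaming(
        model_name,
        prompt="What does Triton Inference Server do?",
        tokens=128,
        temperature=1.0,
    )
    # Tokens arrive through the StreamingResponseGenerator queue; iteration
    # ends when the server signals the final response or a stop word appears.
    for token in generator:
        print(token, end="", flush=True)
    print()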