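"""LocalTensorRTLLM: legacy llama-index ``CustomLLM`` wrapper that runs text
completions against a locally built NVIDIA TensorRT-LLM engine."""
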
import gc
import json
import os
import time
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.legacy.llms.base import (
    ChatMessage,
    ChatResponse,
    CompletionResponse,
    LLMMetadata,
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.llms.generic_utils import completion_response_to_chat_response
from llama_index.legacy.llms.nvidia_tensorrt_utils import (
    generate_completion_dict,
    get_output,
    parse_input,
)

# Token ids used as the end-of-sequence and padding markers when parsing input
# and building the sampling config.
EOS_TOKEN = 2
PAD_TOKEN = 2


class LocalTensorRTLLM(CustomLLM):
    model_path: Optional[str] = Field(description="The path to the trt engine.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
    context_window: int = Field(
        description="The maximum number of context tokens for the model."
    )
    messages_to_prompt: Callable = Field(
        description="The function to convert messages to a prompt.", exclude=True
    )
    completion_to_prompt: Callable = Field(
        description="The function to convert a completion to a prompt.", exclude=True
    )
    generate_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for generation."
    )
    model_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for model initialization."
    )
    verbose: bool = Field(description="Whether to print verbose output.")

    _model: Any = PrivateAttr()
    _model_config: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _max_new_tokens = PrivateAttr()
    _sampling_config = PrivateAttr()
    _verbose = PrivateAttr()

    def __init__(
        self,
        model_path: Optional[str] = None,
        engine_name: Optional[str] = None,
        tokenizer_dir: Optional[str] = None,
        temperature: float = 0.1,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        messages_to_prompt: Optional[Callable] = None,
        completion_to_prompt: Optional[Callable] = None,
        callback_manager: Optional[CallbackManager] = None,
        generate_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        verbose: bool = False,
    ) -> None:
        try:
            import torch
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "nvidia_tensorrt requires `pip install torch` and `pip install transformers`."
            )

        try:
            import tensorrt_llm
            from tensorrt_llm.runtime import ModelConfig, SamplingConfig
        except ImportError:
            # Raise instead of only printing, so a missing `tensorrt_llm` fails
            # here rather than with a NameError later in __init__.
            raise ImportError(
                "Unable to import `tensorrt_llm` module. Please ensure you have "
                "`tensorrt_llm` installed in your environment. You can run "
                "`pip3 install tensorrt_llm -U --extra-index-url https://pypi.nvidia.com` to install."
            )

        model_kwargs = model_kwargs or {}
        model_kwargs.update({"n_ctx": context_window, "verbose": verbose})
        self._max_new_tokens = max_new_tokens
        self._verbose = verbose
        # check if model is cached
        if model_path is not None:
            if not os.path.exists(model_path):
                raise ValueError(
                    "Provided model path does not exist. "
                    "Please check the path or provide a model_url to download."
                )
            else:
                engine_dir = model_path
                engine_dir_path = Path(engine_dir)
                config_path = engine_dir_path / "config.json"

                # read the engine build configuration
                with open(config_path) as f:
                    config = json.load(f)
                use_gpt_attention_plugin = config["plugin_config"][
                    "gpt_attention_plugin"
                ]
                remove_input_padding = config["plugin_config"]["remove_input_padding"]
                tp_size = config["builder_config"]["tensor_parallel"]
                pp_size = config["builder_config"]["pipeline_parallel"]
                world_size = tp_size * pp_size
                assert (
                    world_size == tensorrt_llm.mpi_world_size()
                ), f"Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})"
                # per-rank sizes under tensor parallelism
                num_heads = config["builder_config"]["num_heads"] // tp_size
                hidden_size = config["builder_config"]["hidden_size"] // tp_size
                vocab_size = config["builder_config"]["vocab_size"]
                num_layers = config["builder_config"]["num_layers"]
                num_kv_heads = config["builder_config"].get("num_kv_heads", num_heads)
                paged_kv_cache = config["plugin_config"]["paged_kv_cache"]
                if config["builder_config"].get("multi_query_mode", False):
                    tensorrt_llm.logger.warning(
                        "`multi_query_mode` config is deprecated. Please rebuild the engine."
                    )
                    num_kv_heads = 1
                num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size

                self._model_config = ModelConfig(
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    hidden_size=hidden_size,
                    vocab_size=vocab_size,
                    num_layers=num_layers,
                    gpt_attention_plugin=use_gpt_attention_plugin,
                    paged_kv_cache=paged_kv_cache,
                    remove_input_padding=remove_input_padding,
                )

                assert (
                    pp_size == 1
                ), "Python runtime does not support pipeline parallelism"
                world_size = tp_size * pp_size

                runtime_rank = tensorrt_llm.mpi_rank()
                runtime_mapping = tensorrt_llm.Mapping(
                    world_size, runtime_rank, tp_size=tp_size, pp_size=pp_size
                )

                # TensorRT-LLM must run on a GPU.
                assert (
                    torch.cuda.is_available()
                ), "LocalTensorRTLLM requires an Nvidia CUDA-enabled GPU to operate"
                torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
                self._tokenizer = AutoTokenizer.from_pretrained(
                    tokenizer_dir, legacy=False
                )
                self._sampling_config = SamplingConfig(
                    end_id=EOS_TOKEN,
                    pad_id=PAD_TOKEN,
                    num_beams=1,
                    temperature=temperature,
                )

                # Deserialize the prebuilt engine and create a generation session.
                serialize_path = engine_dir_path / (engine_name if engine_name else "")
                with open(serialize_path, "rb") as f:
                    engine_buffer = f.read()
                decoder = tensorrt_llm.runtime.GenerationSession(
                    self._model_config, engine_buffer, runtime_mapping, debug_mode=False
                )
                self._model = decoder

        generate_kwargs = generate_kwargs or {}
        generate_kwargs.update(
            {"temperature": temperature, "max_tokens": max_new_tokens}
        )

        super().__init__(
            model_path=model_path,
            temperature=temperature,
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            callback_manager=callback_manager,
            generate_kwargs=generate_kwargs,
            model_kwargs=model_kwargs,
            verbose=verbose,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "LocalTensorRTLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_path,
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        try:
            import torch
        except ImportError:
            raise ImportError("nvidia_tensorrt requires `pip install torch`.")

        self.generate_kwargs.update({"stream": False})

        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        input_text = prompt
        input_ids, input_lengths = parse_input(
            input_text, self._tokenizer, EOS_TOKEN, self._model_config
        )

        max_input_length = torch.max(input_lengths).item()
        self._model.setup(
            input_lengths.size(0), max_input_length, self._max_new_tokens, 1
        )  # beam size is set to 1
        if self._verbose:
            start_time = time.time()

        output_ids = self._model.decode(input_ids, input_lengths, self._sampling_config)
        torch.cuda.synchronize()

        elapsed_time = -1.0
        if self._verbose:
            end_time = time.time()
            elapsed_time = end_time - start_time

        output_txt, output_token_ids = get_output(
            output_ids, input_lengths, self._max_new_tokens, self._tokenizer
        )

        if self._verbose:
            print(f"Input context length  : {input_ids.shape[1]}")
            print(f"Inference time        : {elapsed_time:.2f} seconds")
            print(f"Output context length : {len(output_token_ids)}")
            print(
                f"Inference token/sec   : {(len(output_token_ids) / elapsed_time):.2f}"
            )

        # free GPU memory and run garbage collection after inference
        torch.cuda.empty_cache()
        gc.collect()

        return CompletionResponse(
            text=output_txt,
            raw=generate_completion_dict(output_txt, self._model, self.model_path),
        )

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        raise NotImplementedError(
            "Nvidia TensorRT-LLM does not currently support streaming completion."
        )
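

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The engine directory, engine file
# name, and tokenizer directory below are placeholders and must point at a
# TensorRT-LLM engine you have built locally plus its matching tokenizer.
# Running it requires an NVIDIA CUDA-capable GPU with `tensorrt_llm`, `torch`,
# and `transformers` installed; the prompt template assumes a Llama-2 chat
# style model and should be adapted to whatever model the engine was built
# from.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    def _completion_to_prompt(completion: str) -> str:
        # Llama-2 chat style instruction wrapper (adjust for other models).
        return f"<s> [INST] {completion} [/INST] "

    llm = LocalTensorRTLLM(
        model_path="./model",  # placeholder: directory with config.json + engine
        engine_name="llama_float16_tp1_rank0.engine",  # placeholder engine file name
        tokenizer_dir="./model/tokenizer",  # placeholder tokenizer location
        completion_to_prompt=_completion_to_prompt,
        max_new_tokens=256,
        verbose=True,
    )
    response = llm.complete("Who is Paul Graham?")
    print(response.text)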