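"""Local TensorRT-LLM integration for LlamaIndex (legacy).

Defines `LocalTensorRTLLM`, a `CustomLLM` that loads a locally built NVIDIA
TensorRT-LLM engine and its tokenizer from disk and runs completions on a CUDA GPU.
"""
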
import gc
import json
import os
import time
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.legacy.llms.base import (
    ChatMessage,
    ChatResponse,
    CompletionResponse,
    LLMMetadata,
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.llms.generic_utils import completion_response_to_chat_response
from llama_index.legacy.llms.nvidia_tensorrt_utils import (
    generate_completion_dict,
    get_output,
    parse_input,
)

EOS_TOKEN = 2
PAD_TOKEN = 2


class LocalTensorRTLLM(CustomLLM):
    """An LLM that runs a locally built NVIDIA TensorRT-LLM engine.

    Loads a serialized TensorRT engine and its tokenizer from disk and runs
    generation on a CUDA-enabled GPU through the `tensorrt_llm` Python runtime.
    """

    model_path: Optional[str] = Field(description="The path to the trt engine.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
    context_window: int = Field(
        description="The maximum number of context tokens for the model."
    )
    messages_to_prompt: Callable = Field(
        description="The function to convert messages to a prompt.", exclude=True
    )
    completion_to_prompt: Callable = Field(
        description="The function to convert a completion to a prompt.", exclude=True
    )
    generate_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for generation."
    )
    model_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for model initialization."
    )
    verbose: bool = Field(description="Whether to print verbose output.")

    _model: Any = PrivateAttr()
    _model_config: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _max_new_tokens: int = PrivateAttr()
    _sampling_config: Any = PrivateAttr()
    _verbose: bool = PrivateAttr()

    def __init__(
        self,
        model_path: Optional[str] = None,
        engine_name: Optional[str] = None,
        tokenizer_dir: Optional[str] = None,
        temperature: float = 0.1,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        messages_to_prompt: Optional[Callable] = None,
        completion_to_prompt: Optional[Callable] = None,
        callback_manager: Optional[CallbackManager] = None,
        generate_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        verbose: bool = False,
    ) -> None:
        try:
            import torch
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "nvidia_tensorrt requires `pip install torch` and `pip install transformers`."
            )

        try:
            import tensorrt_llm
            from tensorrt_llm.runtime import ModelConfig, SamplingConfig
        except ImportError:
            raise ImportError(
                "Unable to import `tensorrt_llm` module. Please ensure you have "
                "`tensorrt_llm` installed in your environment. You can run "
                "`pip3 install tensorrt_llm -U --extra-index-url https://pypi.nvidia.com` "
                "to install."
            )

        model_kwargs = model_kwargs or {}
        model_kwargs.update({"n_ctx": context_window, "verbose": verbose})
        self._max_new_tokens = max_new_tokens
        self._verbose = verbose
        # Validate the provided engine directory and load the engine from it.
        if model_path is not None:
            if not os.path.exists(model_path):
                raise ValueError(
                    "Provided model path does not exist. "
                    "Please check the path to the TensorRT engine directory."
                )
            else:
                engine_dir = model_path
                engine_dir_path = Path(engine_dir)
                config_path = engine_dir_path / "config.json"

                # Load the engine build configuration.
                with open(config_path) as f:
                    config = json.load(f)
                use_gpt_attention_plugin = config["plugin_config"][
                    "gpt_attention_plugin"
                ]
                remove_input_padding = config["plugin_config"]["remove_input_padding"]
                tp_size = config["builder_config"]["tensor_parallel"]
                pp_size = config["builder_config"]["pipeline_parallel"]
                world_size = tp_size * pp_size
                assert (
                    world_size == tensorrt_llm.mpi_world_size()
                ), f"Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})"
                num_heads = config["builder_config"]["num_heads"] // tp_size
                hidden_size = config["builder_config"]["hidden_size"] // tp_size
                vocab_size = config["builder_config"]["vocab_size"]
                num_layers = config["builder_config"]["num_layers"]
                num_kv_heads = config["builder_config"].get("num_kv_heads", num_heads)
                paged_kv_cache = config["plugin_config"]["paged_kv_cache"]
                if config["builder_config"].get("multi_query_mode", False):
                    tensorrt_llm.logger.warning(
                        "`multi_query_mode` config is deprecated. Please rebuild the engine."
                    )
                    num_kv_heads = 1
                num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size

                self._model_config = ModelConfig(
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    hidden_size=hidden_size,
                    vocab_size=vocab_size,
                    num_layers=num_layers,
                    gpt_attention_plugin=use_gpt_attention_plugin,
                    paged_kv_cache=paged_kv_cache,
                    remove_input_padding=remove_input_padding,
                )

                assert (
                    pp_size == 1
                ), "Python runtime does not support pipeline parallelism"
                world_size = tp_size * pp_size

                runtime_rank = tensorrt_llm.mpi_rank()
                runtime_mapping = tensorrt_llm.Mapping(
                    world_size, runtime_rank, tp_size=tp_size, pp_size=pp_size
                )

                # TensorRT-LLM must run on a GPU.
                assert (
                    torch.cuda.is_available()
                ), "LocalTensorRTLLM requires an NVIDIA CUDA-enabled GPU to operate"
                torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
                self._tokenizer = AutoTokenizer.from_pretrained(
                    tokenizer_dir, legacy=False
                )
                self._sampling_config = SamplingConfig(
                    end_id=EOS_TOKEN,
                    pad_id=PAD_TOKEN,
                    num_beams=1,
                    temperature=temperature,
                )

                serialize_path = engine_dir_path / (engine_name if engine_name else "")
                with open(serialize_path, "rb") as f:
                    engine_buffer = f.read()
                decoder = tensorrt_llm.runtime.GenerationSession(
                    self._model_config, engine_buffer, runtime_mapping, debug_mode=False
                )
                self._model = decoder

        generate_kwargs = generate_kwargs or {}
        generate_kwargs.update(
            {"temperature": temperature, "max_tokens": max_new_tokens}
        )

        super().__init__(
            model_path=model_path,
            temperature=temperature,
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            callback_manager=callback_manager,
            generate_kwargs=generate_kwargs,
            model_kwargs=model_kwargs,
            verbose=verbose,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "LocalTensorRTLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_path,
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        try:
            import torch
        except ImportError:
            raise ImportError("nvidia_tensorrt requires `pip install torch`.")

        self.generate_kwargs.update({"stream": False})

        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        input_text = prompt
        input_ids, input_lengths = parse_input(
            input_text, self._tokenizer, EOS_TOKEN, self._model_config
        )

        max_input_length = torch.max(input_lengths).item()
        self._model.setup(
            input_lengths.size(0), max_input_length, self._max_new_tokens, 1
        )  # beam size is set to 1
        if self._verbose:
            start_time = time.time()

        output_ids = self._model.decode(input_ids, input_lengths, self._sampling_config)
        torch.cuda.synchronize()

        elapsed_time = -1.0
        if self._verbose:
            end_time = time.time()
            elapsed_time = end_time - start_time

        output_txt, output_token_ids = get_output(
            output_ids, input_lengths, self._max_new_tokens, self._tokenizer
        )

        if self._verbose:
            print(f"Input context length  : {input_ids.shape[1]}")
            print(f"Inference time        : {elapsed_time:.2f} seconds")
            print(f"Output context length : {len(output_token_ids)}")
            print(
                f"Inference token/sec   : {(len(output_token_ids) / elapsed_time):.2f}"
            )

        # Free cached CUDA memory and run garbage collection after inference.
        torch.cuda.empty_cache()
        gc.collect()

        return CompletionResponse(
            text=output_txt,
            raw=generate_completion_dict(output_txt, self._model, self.model_path),
        )

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        raise NotImplementedError(
            "Nvidia TensorRT-LLM does not currently support streaming completion."
        )
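
# Example usage (a minimal sketch, not part of the original module). It assumes this
# file is importable as `llama_index.legacy.llms.nvidia_tensorrt` and that a
# TensorRT-LLM engine has already been built locally; the paths and engine name
# below are placeholders:
#
#     from llama_index.legacy.llms.nvidia_tensorrt import LocalTensorRTLLM
#
#     llm = LocalTensorRTLLM(
#         model_path="./model",                          # directory with config.json + engine
#         engine_name="llama_float16_tp1_rank0.engine",  # placeholder engine file name
#         tokenizer_dir="./model/tokenizer",             # placeholder tokenizer directory
#         max_new_tokens=256,
#     )
#     response = llm.complete("Who is Paul Graham?")
#     print(response.text)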