llama-index

import random
from typing import (
    Any,
    Dict,
    Optional,
    Sequence,
)

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.llms.base import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    llm_chat_callback,
)
from llama_index.legacy.llms.generic_utils import (
    completion_to_chat_decorator,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.llms.nvidia_triton_utils import GrpcTritonClient

DEFAULT_SERVER_URL = "localhost:8001"
DEFAULT_MAX_RETRIES = 3
DEFAULT_TIMEOUT = 60.0
DEFAULT_MODEL = "ensemble"
DEFAULT_TEMPERATURE = 1.0
DEFAULT_TOP_P = 0
DEFAULT_TOP_K = 1.0
DEFAULT_MAX_TOKENS = 100
DEFAULT_BEAM_WIDTH = 1
DEFAULT_REPTITION_PENALTY = 1.0
DEFAULT_LENGTH_PENALTY = 1.0
DEFAULT_REUSE_CLIENT = True
DEFAULT_TRITON_LOAD_MODEL = True


class NvidiaTriton(LLM):
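    """NVIDIA Triton inference server LLM.

    Talks to a Triton server over gRPC via ``GrpcTritonClient`` and generates
    completions from a hosted model (the default model name, ``ensemble``,
    matches the ensemble commonly deployed for TensorRT-LLM backends). Only
    the synchronous ``complete`` and ``chat`` paths are implemented; the
    streaming and async methods raise ``NotImplementedError``.
    """
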
    server_url: str = Field(
        default=DEFAULT_SERVER_URL,
        description="The URL of the Triton inference server to use.",
    )
    model_name: str = Field(
        default=DEFAULT_MODEL,
        description="The name of the Triton hosted model this client should use",
    )
    temperature: Optional[float] = Field(
        default=DEFAULT_TEMPERATURE, description="Temperature to use for sampling"
    )
    top_p: Optional[float] = Field(
        default=DEFAULT_TOP_P, description="The top-p value to use for sampling"
    )
    top_k: Optional[float] = Field(
        default=DEFAULT_TOP_K, description="The top-k value to use for sampling"
    )
    tokens: Optional[int] = Field(
        default=DEFAULT_MAX_TOKENS,
        description="The maximum number of tokens to generate.",
    )
    beam_width: Optional[int] = Field(
        default=DEFAULT_BEAM_WIDTH, description="The beam width to use for beam search"
    )
    repetition_penalty: Optional[float] = Field(
        default=DEFAULT_REPTITION_PENALTY,
        description="The penalty to apply to repeated tokens",
    )
    length_penalty: Optional[float] = Field(
        default=DEFAULT_LENGTH_PENALTY,
        description="The penalty to apply to the length of the output",
    )
    max_retries: Optional[int] = Field(
        default=DEFAULT_MAX_RETRIES,
        description="Maximum number of attempts to retry Triton client invocation before erroring",
    )
    timeout: Optional[float] = Field(
        default=DEFAULT_TIMEOUT,
        description="Maximum time (seconds) allowed for a Triton client call before erroring",
    )
    reuse_client: Optional[bool] = Field(
        default=DEFAULT_REUSE_CLIENT,
        description="True for reusing the same client instance between invocations",
    )
    triton_load_model_call: Optional[bool] = Field(
        default=DEFAULT_TRITON_LOAD_MODEL,
        description="True if a Triton load model API call should be made before using the client",
    )

    _client: Optional[GrpcTritonClient] = PrivateAttr()

    def __init__(
        self,
        server_url: str = DEFAULT_SERVER_URL,
        model: str = DEFAULT_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: float = DEFAULT_TOP_K,
        tokens: Optional[int] = DEFAULT_MAX_TOKENS,
        beam_width: int = DEFAULT_BEAM_WIDTH,
        repetition_penalty: float = DEFAULT_REPTITION_PENALTY,
        length_penalty: float = DEFAULT_LENGTH_PENALTY,
        max_retries: int = DEFAULT_MAX_RETRIES,
        timeout: float = DEFAULT_TIMEOUT,
        reuse_client: bool = DEFAULT_REUSE_CLIENT,
        triton_load_model_call: bool = DEFAULT_TRITON_LOAD_MODEL,
        callback_manager: Optional[CallbackManager] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        additional_kwargs = additional_kwargs or {}

        super().__init__(
            server_url=server_url,
            model_name=model,  # populate the model_name field from the constructor's model argument
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            tokens=tokens,
            beam_width=beam_width,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            max_retries=max_retries,
            timeout=timeout,
            reuse_client=reuse_client,
            triton_load_model_call=triton_load_model_call,
            callback_manager=callback_manager,
            additional_kwargs=additional_kwargs,
            **kwargs,
        )

        try:
            self._client = GrpcTritonClient(server_url)
        except ImportError as err:
            raise ImportError(
                "Could not import triton client python package. "
                "Please install it with `pip install tritonclient`."
            ) from err

    @property
    def _get_model_default_parameters(self) -> Dict[str, Any]:
        return {
            "tokens": self.tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "repetition_penalty": self.repetition_penalty,
            "length_penalty": self.length_penalty,
            "beam_width": self.beam_width,
        }

    def _invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
        return {**self._get_model_default_parameters, **kwargs}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get all the identifying parameters."""
        return {
            "server_url": self.server_url,
            "model_name": self.model_name,
        }

    def _get_client(self) -> Any:
        """Create or reuse a Triton client connection."""
        if not self.reuse_client:
            return GrpcTritonClient(self.server_url)

        if self._client is None:
            self._client = GrpcTritonClient(self.server_url)
        return self._client

    @property
    def metadata(self) -> LLMMetadata:
        """Gather and return metadata about the user-configured Triton LLM."""
        return LLMMetadata(
            model_name=self.model_name,
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        chat_fn = completion_to_chat_decorator(self.complete)
        return chat_fn(messages, **kwargs)

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise NotImplementedError

    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        from tritonclient.utils import InferenceServerException

        client = self._get_client()

        # Merge the model's default sampling parameters with any per-call
        # overrides, then attach the prompt in the nested-list form the gRPC
        # client helper expects.
        invocation_params = self._get_model_default_parameters
        invocation_params.update(kwargs)
        invocation_params["prompt"] = [[prompt]]
        model_params = self._identifying_params
        model_params.update(kwargs)
        request_id = str(random.randint(1, 9999999))  # nosec

        # Optionally ask Triton to load the model before issuing the request.
        if self.triton_load_model_call:
            client.load_model(model_params["model_name"])

        result_queue = client.request_streaming(
            model_params["model_name"], request_id, **invocation_params
        )

        response = ""
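        # Drain the streaming result queue, concatenating tokens into the final
        # completion text; if the server reports an error, stop the stream and
        # re-raise it.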
        for token in result_queue:
            if isinstance(token, InferenceServerException):
                client.stop_stream(model_params["model_name"], request_id)
                raise token
            response = response + token

        return CompletionResponse(
            text=response,
        )

    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise NotImplementedError

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        raise NotImplementedError

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        raise NotImplementedError

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError
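

# A minimal, hypothetical usage sketch (the server URL and model name below are
# assumptions; point them at your own Triton deployment). Note that ``chat``
# simply wraps ``complete`` through ``completion_to_chat_decorator``:
#
#     llm = NvidiaTriton(server_url="localhost:8001", model="ensemble")
#     print(llm.complete("What is the capital of France?").text)
#
#     message = ChatMessage(role="user", content="What is the capital of France?")
#     print(llm.chat([message]))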