llama-index

from typing import Any, Awaitable, Callable, Dict, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_TEMPERATURE
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.legacy.llms.generic_utils import (
    achat_to_completion_decorator,
    acompletion_to_chat_decorator,
    astream_chat_to_completion_decorator,
    astream_completion_to_chat_decorator,
    chat_to_completion_decorator,
    completion_to_chat_decorator,
    stream_chat_to_completion_decorator,
    stream_completion_to_chat_decorator,
)
from llama_index.legacy.llms.litellm_utils import (
    acompletion_with_retry,
    completion_with_retry,
    from_litellm_message,
    is_function_calling_model,
    openai_modelname_to_contextsize,
    to_openai_message_dicts,
    validate_litellm_api_key,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode

DEFAULT_LITELLM_MODEL = "gpt-3.5-turbo"


class LiteLLM(LLM):
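    """Wrapper over the litellm library, exposing many hosted and local
    providers behind a single OpenAI-style chat interface.

    Minimal usage sketch (the import path and environment variable below are
    assumptions; adjust them to your llama-index version and provider)::

        from llama_index.legacy.llms import LiteLLM
        from llama_index.legacy.core.llms.types import ChatMessage

        # assumes the provider's API key (e.g. OPENAI_API_KEY) is set
        llm = LiteLLM(model="gpt-3.5-turbo")
        response = llm.chat([ChatMessage(role="user", content="Hello!")])
        print(response.message.content)
    """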
    model: str = Field(
        default=DEFAULT_LITELLM_MODEL,
        description=(
            "The LiteLLM model to use. "
            "For a complete list of providers, see https://docs.litellm.ai/docs/providers"
        ),
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use during generation.",
        ge=0.0,
        le=1.0,
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional kwargs for the LLM API.",
        # for all supported inputs, see https://docs.litellm.ai/docs/completion/input
    )
    max_retries: int = Field(
        default=10, description="The maximum number of API retries."
    )

    def __init__(
        self,
        model: str = DEFAULT_LITELLM_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 10,
        api_key: Optional[str] = None,
        api_type: Optional[str] = None,
        api_base: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> None:
        if "custom_llm_provider" in kwargs:
            if (
                kwargs["custom_llm_provider"] != "ollama"
                and kwargs["custom_llm_provider"] != "vllm"
            ):  # don't check keys for local models
                validate_litellm_api_key(api_key, api_type)
        else:  # by default assume it's a hosted endpoint
            validate_litellm_api_key(api_key, api_type)

        additional_kwargs = additional_kwargs or {}
        if api_key is not None:
            additional_kwargs["api_key"] = api_key
        if api_type is not None:
            additional_kwargs["api_type"] = api_type
        if api_base is not None:
            additional_kwargs["api_base"] = api_base

        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            max_retries=max_retries,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            **kwargs,
        )

    def _get_model_name(self) -> str:
        model_name = self.model
        if "ft-" in model_name:  # legacy fine-tuning
            model_name = model_name.split(":")[0]
        elif model_name.startswith("ft:"):
            model_name = model_name.split(":")[1]

        return model_name

    @classmethod
    def class_name(cls) -> str:
        return "litellm_llm"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=openai_modelname_to_contextsize(self._get_model_name()),
            num_output=self.max_tokens or -1,
            is_chat_model=True,
            is_function_calling_model=is_function_calling_model(self._get_model_name()),
            model_name=self.model,
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        if self._is_chat_model:
            chat_fn = self._chat
        else:
            chat_fn = completion_to_chat_decorator(self._complete)
        return chat_fn(messages, **kwargs)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        if self._is_chat_model:
            stream_chat_fn = self._stream_chat
        else:
            stream_chat_fn = stream_completion_to_chat_decorator(self._stream_complete)
        return stream_chat_fn(messages, **kwargs)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        # litellm assumes all llms are chat llms
        if self._is_chat_model:
            complete_fn = chat_to_completion_decorator(self._chat)
        else:
            complete_fn = self._complete

        return complete_fn(prompt, **kwargs)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        if self._is_chat_model:
            stream_complete_fn = stream_chat_to_completion_decorator(self._stream_chat)
        else:
            stream_complete_fn = self._stream_complete
        return stream_complete_fn(prompt, **kwargs)

    @property
    def _is_chat_model(self) -> bool:
        # litellm assumes all llms are chat llms
        return True

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        base_kwargs = {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        return {
            **self._model_kwargs,
            **kwargs,
        }

    def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        if not self._is_chat_model:
            raise ValueError("This model is not a chat model.")

        message_dicts = to_openai_message_dicts(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)
        if "max_tokens" in all_kwargs and all_kwargs["max_tokens"] is None:
212
            all_kwargs.pop(
213
                "max_tokens"
214
            )  # don't send max_tokens == None, this throws errors for Non OpenAI providers

        response = completion_with_retry(
            is_chat_model=self._is_chat_model,
            max_retries=self.max_retries,
            messages=message_dicts,
            stream=False,
            **all_kwargs,
        )
        message_dict = response["choices"][0]["message"]
        message = from_litellm_message(message_dict)

        return ChatResponse(
            message=message,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    def _stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        if not self._is_chat_model:
            raise ValueError("This model is not a chat model.")

        message_dicts = to_openai_message_dicts(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)
        if "max_tokens" in all_kwargs and all_kwargs["max_tokens"] is None:
241
            all_kwargs.pop(
242
                "max_tokens"
243
            )  # don't send max_tokens == None, this throws errors for Non OpenAI providers

        def gen() -> ChatResponseGen:
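            # Accumulate streamed deltas: `content` grows with every chunk and
            # function-call fragments are merged, so each yielded ChatResponse
            # carries the full message seen so far plus the latest delta.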
            content = ""
            function_call: Optional[dict] = None
            for response in completion_with_retry(
                is_chat_model=self._is_chat_model,
                max_retries=self.max_retries,
                messages=message_dicts,
                stream=True,
                **all_kwargs,
            ):
                delta = response["choices"][0]["delta"]
                role = delta.get("role", "assistant")
                content_delta = delta.get("content", "") or ""
                content += content_delta

                function_call_delta = delta.get("function_call", None)
                if function_call_delta is not None:
                    if function_call is None:
                        function_call = function_call_delta

                        ## ensure we do not add a blank function call
                        if function_call.get("function_name", "") is None:
                            del function_call["function_name"]
                    else:
                        function_call["arguments"] += function_call_delta["arguments"]

                additional_kwargs = {}
                if function_call is not None:
                    additional_kwargs["function_call"] = function_call

                yield ChatResponse(
                    message=ChatMessage(
                        role=role,
                        content=content,
                        additional_kwargs=additional_kwargs,
                    ),
                    delta=content_delta,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()

    def _complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        raise NotImplementedError("litellm assumes all llms are chat llms.")

    def _stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError("litellm assumes all llms are chat llms.")

    def _get_max_token_for_prompt(self, prompt: str) -> int:
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Please install tiktoken to use the max_tokens=None feature."
            )
        context_window = self.metadata.context_window
        try:
            encoding = tiktoken.encoding_for_model(self._get_model_name())
        except KeyError:
            # fall back to the cl100k_base encoding if the model is unknown
            encoding = tiktoken.get_encoding("cl100k_base")
        tokens = encoding.encode(prompt)
        max_token = context_window - len(tokens)
        if max_token <= 0:
            raise ValueError(
                f"The prompt is too long for the model. "
                f"Please use a prompt that is less than {context_window} tokens."
            )
        return max_token

    def _get_response_token_counts(self, raw_response: Any) -> dict:
        """Get the token usage reported by the response."""
        if not isinstance(raw_response, dict):
            return {}

        usage = raw_response.get("usage", {})
        return {
            "prompt_tokens": usage.get("prompt_tokens", 0),
            "completion_tokens": usage.get("completion_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        }

    # ===== Async Endpoints =====
    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        achat_fn: Callable[..., Awaitable[ChatResponse]]
        if self._is_chat_model:
            achat_fn = self._achat
        else:
            achat_fn = acompletion_to_chat_decorator(self._acomplete)
        return await achat_fn(messages, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        astream_chat_fn: Callable[..., Awaitable[ChatResponseAsyncGen]]
        if self._is_chat_model:
            astream_chat_fn = self._astream_chat
        else:
            astream_chat_fn = astream_completion_to_chat_decorator(
                self._astream_complete
            )
        return await astream_chat_fn(messages, **kwargs)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        if self._is_chat_model:
            acomplete_fn = achat_to_completion_decorator(self._achat)
        else:
            acomplete_fn = self._acomplete
        return await acomplete_fn(prompt, **kwargs)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        if self._is_chat_model:
            astream_complete_fn = astream_chat_to_completion_decorator(
                self._astream_chat
            )
        else:
            astream_complete_fn = self._astream_complete
        return await astream_complete_fn(prompt, **kwargs)

    async def _achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        if not self._is_chat_model:
            raise ValueError("This model is not a chat model.")

        message_dicts = to_openai_message_dicts(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)
        response = await acompletion_with_retry(
            is_chat_model=self._is_chat_model,
            max_retries=self.max_retries,
            messages=message_dicts,
            stream=False,
            **all_kwargs,
        )
        message_dict = response["choices"][0]["message"]
        message = from_litellm_message(message_dict)

        return ChatResponse(
            message=message,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    async def _astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        if not self._is_chat_model:
            raise ValueError("This model is not a chat model.")

        message_dicts = to_openai_message_dicts(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            function_call: Optional[dict] = None
            async for response in await acompletion_with_retry(
                is_chat_model=self._is_chat_model,
                max_retries=self.max_retries,
                messages=message_dicts,
                stream=True,
                **all_kwargs,
            ):
                delta = response["choices"][0]["delta"]
                role = delta.get("role", "assistant")
                content_delta = delta.get("content", "") or ""
                content += content_delta

                function_call_delta = delta.get("function_call", None)
                if function_call_delta is not None:
                    if function_call is None:
                        function_call = function_call_delta

                        ## ensure we do not add a blank function call
                        if function_call.get("function_name", "") is None:
                            del function_call["function_name"]
                    else:
                        function_call["arguments"] += function_call_delta["arguments"]

                additional_kwargs = {}
                if function_call is not None:
                    additional_kwargs["function_call"] = function_call

                yield ChatResponse(
                    message=ChatMessage(
                        role=role,
                        content=content,
                        additional_kwargs=additional_kwargs,
                    ),
                    delta=content_delta,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()

    async def _acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        raise NotImplementedError("litellm assumes all llms are chat llms.")

    async def _astream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError("litellm assumes all llms are chat llms.")