from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.legacy.llms.generic_utils import (
    achat_to_completion_decorator,
    acompletion_to_chat_decorator,
    astream_chat_to_completion_decorator,
    astream_completion_to_chat_decorator,
    chat_to_completion_decorator,
    completion_to_chat_decorator,
    stream_chat_to_completion_decorator,
    stream_completion_to_chat_decorator,
)
from llama_index.legacy.llms.konko_utils import (
    acompletion_with_retry,
    completion_with_retry,
    from_openai_message_dict,
    import_konko,
    is_openai_v1,
    resolve_konko_credentials,
    to_openai_message_dicts,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode

DEFAULT_KONKO_MODEL = "meta-llama/llama-2-13b-chat"


@dataclass
class ModelInfo:
    name: str
    max_context_length: int
    is_chat_model: bool


class Konko(LLM):
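    """Konko LLM.

    Routes chat and completion calls through the Konko API, supporting both
    the OpenAI v1 and pre-v1 client interfaces.

    A minimal usage sketch (assumes KONKO_API_KEY is set in the environment
    and that the model name below is available in the Konko catalog):

        from llama_index.legacy.core.llms.types import ChatMessage

        llm = Konko(model="meta-llama/llama-2-13b-chat", max_tokens=256)
        response = llm.chat([ChatMessage(role="user", content="Hello!")])
        print(response.message.content)
    """
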
    model: str = Field(
        default=DEFAULT_KONKO_MODEL, description="The konko model to use."
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use during generation.",
        ge=0.0,
        le=1.0,
    )
    max_tokens: Optional[int] = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the konko API."
    )
    max_retries: int = Field(
        default=10, description="The maximum number of API retries.", ge=0
    )

    konko_api_key: Optional[str] = Field(default=None, description="The konko API key.")
    openai_api_key: Optional[str] = Field(default=None, description="The OpenAI API key.")
    api_type: Optional[str] = Field(default=None, description="The konko API type.")
    model_info_dict: Dict[str, ModelInfo]

    def __init__(
        self,
        model: str = DEFAULT_KONKO_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = DEFAULT_NUM_OUTPUTS,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 10,
        konko_api_key: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        api_type: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        model_info_dict: Optional[Dict[str, ModelInfo]] = None,
        **kwargs: Any,
    ) -> None:
        additional_kwargs = additional_kwargs or {}
        (
            konko_api_key,
            openai_api_key,
            api_type,
            api_base,
            api_version,
        ) = resolve_konko_credentials(
            konko_api_key=konko_api_key,
            openai_api_key=openai_api_key,
            api_type=api_type,
            api_base=api_base,
            api_version=api_version,
        )
        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            max_retries=max_retries,
            callback_manager=callback_manager,
            konko_api_key=konko_api_key,
            openai_api_key=openai_api_key,
            api_type=api_type,
            api_version=api_version,
            api_base=api_base,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            # Honor an explicitly provided model_info_dict; otherwise fetch the
            # model catalog from the Konko API.
            model_info_dict=model_info_dict or self._create_model_info_dict(),
            **kwargs,
        )

    def _get_model_name(self) -> str:
        return self.model

    @classmethod
    def class_name(cls) -> str:
        return "Konko_LLM"

    def _create_model_info_dict(self) -> Dict[str, ModelInfo]:
        konko = import_konko()

        models_info_dict = {}
        if is_openai_v1():
            models = konko.models.list().data
            for model in models:
                model_info = ModelInfo(
                    name=model.name,
                    max_context_length=model.max_context_length,
                    is_chat_model=model.is_chat,
                )
                models_info_dict[model.name] = model_info
        else:
            models = konko.Model.list().data
            for model in models:
                model_info = ModelInfo(
                    name=model["name"],
                    max_context_length=model["max_context_length"],
                    is_chat_model=model["is_chat"],
                )
                models_info_dict[model["name"]] = model_info

        return models_info_dict

    def _get_model_info(self) -> ModelInfo:
        model_name = self._get_model_name()
        model_info = self.model_info_dict.get(model_name)
        if model_info is None:
            raise ValueError(
                f"Unknown model: {model_name}. Please provide a valid Konko model name. "
                "Known models are: " + ", ".join(self.model_info_dict.keys())
            )
        return model_info

    def _is_chat_model(self) -> bool:
        """
        Check whether the configured model is a chat model.

        Returns:
            bool: True if the model is a chat model, False otherwise.

        Raises:
            ValueError: If the configured model is not found in the list of
                known Konko models.
        """
        model_info = self._get_model_info()
        return model_info.is_chat_model

    @property
    def metadata(self) -> LLMMetadata:
        model_info = self._get_model_info()
        return LLMMetadata(
            context_window=model_info.max_context_length,
            num_output=self.max_tokens,
            is_chat_model=model_info.is_chat_model,
            model_name=self.model,
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        if self._is_chat_model():
            chat_fn = self._chat
        else:
            chat_fn = completion_to_chat_decorator(self._complete)
        return chat_fn(messages, **kwargs)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        if self._is_chat_model():
            stream_chat_fn = self._stream_chat
        else:
            stream_chat_fn = stream_completion_to_chat_decorator(self._stream_complete)
        return stream_chat_fn(messages, **kwargs)
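
    # Usage sketch for the two public chat entry points above (messages are
    # illustrative; assumes `llm` is a configured Konko instance):
    #
    #     messages = [ChatMessage(role="user", content="Tell me a joke.")]
    #     print(llm.chat(messages).message.content)   # blocking
    #     for chunk in llm.stream_chat(messages):     # streaming
    #         print(chunk.delta, end="")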

    @property
    def _credential_kwargs(self) -> Dict[str, Any]:
        return {
            "konko_api_key": self.konko_api_key,
            "api_type": self.api_type,
            "openai_api_key": self.openai_api_key,
        }

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        base_kwargs = {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        return {
            **self._model_kwargs,
            **kwargs,
        }
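
    # Precedence sketch (illustrative values): per-call kwargs override
    # additional_kwargs, which in turn override the base model settings.
    #
    #     llm = Konko(temperature=0.1, additional_kwargs={"top_p": 0.9})
    #     llm._get_all_kwargs(temperature=0.7)
    #     # -> {"model": ..., "temperature": 0.7, "max_tokens": ..., "top_p": 0.9}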

    def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        if not self._is_chat_model():
            raise ValueError("This model is not a chat model.")

        message_dicts = to_openai_message_dicts(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)
        response = completion_with_retry(
            is_chat_model=self._is_chat_model(),
            max_retries=self.max_retries,
            messages=message_dicts,
            stream=False,
            **all_kwargs,
        )
        if is_openai_v1():
            message_dict = response.choices[0].message
        else:
            message_dict = response["choices"][0]["message"]
        message = from_openai_message_dict(message_dict)

        return ChatResponse(
            message=message,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    def _stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        if not self._is_chat_model():
            raise ValueError("This model is not a chat model.")

        message_dicts = to_openai_message_dicts(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        def gen() -> ChatResponseGen:
            content = ""
            for response in completion_with_retry(
                is_chat_model=self._is_chat_model(),
                max_retries=self.max_retries,
                messages=message_dicts,
                stream=True,
                **all_kwargs,
            ):
                if is_openai_v1():
                    # Skip chunks that carry no choices (e.g. prompt annotations).
                    if len(response.choices) == 0:
                        continue
                    delta = response.choices[0].delta
                    role_value = delta.role
                    content_delta = delta.content or ""
                else:
                    if "choices" not in response or len(response["choices"]) == 0:
                        continue
                    delta = response["choices"][0].get("delta", {})
                    # The role is only present on the first chunk of a stream.
                    role_value = delta.get("role")
                    content_delta = delta.get("content") or ""

                role = role_value if role_value is not None else "assistant"
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(
                        role=role,
                        content=content,
                    ),
                    delta=content_delta,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        if self._is_chat_model():
            complete_fn = chat_to_completion_decorator(self._chat)
        else:
            complete_fn = self._complete
        return complete_fn(prompt, **kwargs)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        if self._is_chat_model():
            stream_complete_fn = stream_chat_to_completion_decorator(self._stream_chat)
        else:
            stream_complete_fn = self._stream_complete
        return stream_complete_fn(prompt, **kwargs)

    def _get_response_token_counts(self, raw_response: Any) -> dict:
        """Get the token usage reported by the response."""
        if not isinstance(raw_response, dict):
            return {}

        usage = raw_response.get("usage", {})
        # NOTE: other model providers that use the OpenAI client may not report usage
        if usage is None:
            return {}

        return {
            "prompt_tokens": usage.get("prompt_tokens", 0),
            "completion_tokens": usage.get("completion_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        }
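
    # Returned shape when usage is reported (illustrative values):
    #     {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}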

    def _complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        if self._is_chat_model():
            raise ValueError("This model is a chat model.")

        all_kwargs = self._get_all_kwargs(**kwargs)
        if self.max_tokens is None:
            # NOTE: non-chat completion endpoint requires max_tokens to be set
            max_tokens = self._get_max_token_for_prompt(prompt)
            all_kwargs["max_tokens"] = max_tokens

        response = completion_with_retry(
            is_chat_model=self._is_chat_model(),
            max_retries=self.max_retries,
            prompt=prompt,
            stream=False,
            **all_kwargs,
        )
        if is_openai_v1():
            text = response.choices[0].text
        else:
            text = response["choices"][0]["text"]

        return CompletionResponse(
            text=text,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    def _stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        if self._is_chat_model():
            raise ValueError("This model is a chat model.")

        all_kwargs = self._get_all_kwargs(**kwargs)
        if self.max_tokens is None:
            # NOTE: non-chat completion endpoint requires max_tokens to be set
            max_tokens = self._get_max_token_for_prompt(prompt)
            all_kwargs["max_tokens"] = max_tokens

        def gen() -> CompletionResponseGen:
            text = ""
            for response in completion_with_retry(
                is_chat_model=self._is_chat_model(),
                max_retries=self.max_retries,
                prompt=prompt,
                stream=True,
                **all_kwargs,
            ):
                if is_openai_v1():
                    if len(response.choices) > 0:
                        delta = response.choices[0].text
                    else:
                        delta = ""
                else:
                    if len(response["choices"]) > 0:
                        delta = response["choices"][0]["text"]
                    else:
                        delta = ""
                text += delta
                yield CompletionResponse(
                    delta=delta,
                    text=text,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()

    def _get_max_token_for_prompt(self, prompt: str) -> int:
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Please install tiktoken to use the max_tokens=None feature."
            )
        context_window = self.metadata.context_window
        encoding = tiktoken.encoding_for_model(self._get_model_name())
        tokens = encoding.encode(prompt)
        max_token = context_window - len(tokens)
        if max_token <= 0:
            raise ValueError(
                f"The prompt is too long for the model. "
                f"Please use a prompt that is less than {context_window} tokens."
            )
        return max_token
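
    # Budget arithmetic sketch (illustrative numbers): with a 4096-token
    # context window and a prompt that encodes to 1000 tokens, max_tokens
    # resolves to 4096 - 1000 = 3096.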

    # ===== Async Endpoints =====
    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        achat_fn: Callable[..., Awaitable[ChatResponse]]
        if self._is_chat_model():
            achat_fn = self._achat
        else:
            achat_fn = acompletion_to_chat_decorator(self._acomplete)
        return await achat_fn(messages, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        astream_chat_fn: Callable[..., Awaitable[ChatResponseAsyncGen]]
        if self._is_chat_model():
            astream_chat_fn = self._astream_chat
        else:
            astream_chat_fn = astream_completion_to_chat_decorator(
                self._astream_complete
            )
        return await astream_chat_fn(messages, **kwargs)
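
    # Async usage sketch (hypothetical driver, run with asyncio.run(main())):
    #
    #     async def main() -> None:
    #         messages = [ChatMessage(role="user", content="Hi")]
    #         print((await llm.achat(messages)).message.content)
    #         async for chunk in await llm.astream_chat(messages):
    #             print(chunk.delta, end="")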

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        if self._is_chat_model():
            acomplete_fn = achat_to_completion_decorator(self._achat)
        else:
            acomplete_fn = self._acomplete
        return await acomplete_fn(prompt, **kwargs)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        if self._is_chat_model():
            astream_complete_fn = astream_chat_to_completion_decorator(
                self._astream_chat
            )
        else:
            astream_complete_fn = self._astream_complete
        return await astream_complete_fn(prompt, **kwargs)

    async def _achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        if not self._is_chat_model():
            raise ValueError("This model is not a chat model.")

        message_dicts = to_openai_message_dicts(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)
        response = await acompletion_with_retry(
            is_chat_model=self._is_chat_model(),
            max_retries=self.max_retries,
            messages=message_dicts,
            stream=False,
            **all_kwargs,
        )
        # NOTE: the version check must be called; the bare function reference
        # was always truthy.
        if is_openai_v1():
            message_dict = response.choices[0].message
        else:
            message_dict = response["choices"][0]["message"]
        message = from_openai_message_dict(message_dict)

        return ChatResponse(
            message=message,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    async def _astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        if not self._is_chat_model():
            raise ValueError("This model is not a chat model.")

        message_dicts = to_openai_message_dicts(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            async for response in await acompletion_with_retry(
                is_chat_model=self._is_chat_model(),
                max_retries=self.max_retries,
                messages=message_dicts,
                stream=True,
                **all_kwargs,
            ):
                if is_openai_v1():
                    # Skip chunks that carry no choices (e.g. prompt annotations).
                    if len(response.choices) == 0:
                        continue
                    delta = response.choices[0].delta
                    role_value = delta.role
                    content_delta = delta.content or ""
                else:
                    if len(response["choices"]) == 0:
                        continue
                    delta = response["choices"][0].get("delta", {})
                    # The role is only present on the first chunk of a stream.
                    role_value = delta.get("role")
                    content_delta = delta.get("content") or ""

                role = role_value if role_value is not None else "assistant"
                content += content_delta

                yield ChatResponse(
                    message=ChatMessage(
                        role=role,
                        content=content,
                    ),
                    delta=content_delta,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()

    async def _acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        if self._is_chat_model():
            raise ValueError("This model is a chat model.")

        all_kwargs = self._get_all_kwargs(**kwargs)
        if self.max_tokens is None:
            # NOTE: non-chat completion endpoint requires max_tokens to be set
            max_tokens = self._get_max_token_for_prompt(prompt)
            all_kwargs["max_tokens"] = max_tokens

        response = await acompletion_with_retry(
            is_chat_model=self._is_chat_model(),
            max_retries=self.max_retries,
            prompt=prompt,
            stream=False,
            **all_kwargs,
        )
        if is_openai_v1():
            text = response.choices[0].text
        else:
            text = response["choices"][0]["text"]
        return CompletionResponse(
            text=text,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    async def _astream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        if self._is_chat_model():
            raise ValueError("This model is a chat model.")

        all_kwargs = self._get_all_kwargs(**kwargs)
        if self.max_tokens is None:
            # NOTE: non-chat completion endpoint requires max_tokens to be set
            max_tokens = self._get_max_token_for_prompt(prompt)
            all_kwargs["max_tokens"] = max_tokens

        async def gen() -> CompletionResponseAsyncGen:
            text = ""
            async for response in await acompletion_with_retry(
                is_chat_model=self._is_chat_model(),
                max_retries=self.max_retries,
                prompt=prompt,
                stream=True,
                **all_kwargs,
            ):
                if is_openai_v1():
                    if len(response.choices) > 0:
                        delta = response.choices[0].text
                    else:
                        delta = ""
                else:
                    if len(response["choices"]) > 0:
                        delta = response["choices"][0]["text"]
                    else:
                        delta = ""
                text += delta
                yield CompletionResponse(
                    delta=delta,
                    text=text,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()