llama-index

Форк
0
656 строк · 23.0 Кб
1
from typing import (
2
    Any,
3
    Awaitable,
4
    Callable,
5
    Dict,
6
    List,
7
    Optional,
8
    Protocol,
9
    Sequence,
10
    cast,
11
    runtime_checkable,
12
)
13

14
import httpx
15
import tiktoken
16
from openai import AsyncOpenAI, AzureOpenAI
17
from openai import OpenAI as SyncOpenAI
18
from openai.types.chat.chat_completion_chunk import (
19
    ChatCompletionChunk,
20
    ChoiceDelta,
21
    ChoiceDeltaToolCall,
22
)
23

24
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
25
from llama_index.legacy.callbacks import CallbackManager
26
from llama_index.legacy.constants import (
27
    DEFAULT_TEMPERATURE,
28
)
29
from llama_index.legacy.core.llms.types import (
30
    ChatMessage,
31
    ChatResponse,
32
    ChatResponseAsyncGen,
33
    ChatResponseGen,
34
    CompletionResponse,
35
    CompletionResponseAsyncGen,
36
    CompletionResponseGen,
37
    LLMMetadata,
38
    MessageRole,
39
)
40
from llama_index.legacy.llms.base import (
41
    llm_chat_callback,
42
    llm_completion_callback,
43
)
44
from llama_index.legacy.llms.generic_utils import (
45
    achat_to_completion_decorator,
46
    acompletion_to_chat_decorator,
47
    astream_chat_to_completion_decorator,
48
    astream_completion_to_chat_decorator,
49
    chat_to_completion_decorator,
50
    completion_to_chat_decorator,
51
    stream_chat_to_completion_decorator,
52
    stream_completion_to_chat_decorator,
53
)
54
from llama_index.legacy.llms.llm import LLM
55
from llama_index.legacy.llms.openai_utils import (
56
    from_openai_message,
57
    is_chat_model,
58
    is_function_calling_model,
59
    openai_modelname_to_contextsize,
60
    resolve_openai_credentials,
61
    to_openai_message_dicts,
62
)
63
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
64

65
# Model name used when the caller does not pass ``model`` explicitly.
DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo"
66

67

68
@runtime_checkable
class Tokenizer(Protocol):
    """Structural type for any object exposing ``encode(text) -> List[int]``."""

    def encode(self, text: str) -> List[int]:
        """Turn ``text`` into a list of integer token ids."""
        ...
74

75

76
class OpenAI(LLM):
    """LLM implementation that talks to the OpenAI chat/completion endpoints.

    The fields below are the user-facing configuration. The private
    attributes hold the sync/async client objects and an optional
    caller-supplied ``httpx`` client.
    """

    model: str = Field(
        default=DEFAULT_OPENAI_MODEL, description="The OpenAI model to use."
    )
    # NOTE: Pydantic spells its numeric bounds ``ge``/``le``/``gt``/``lt``.
    # The previous ``gte``/``lte`` keywords were stored as schema extras and
    # never enforced. The upper bound is 2.0 because the OpenAI API accepts
    # sampling temperatures in the range [0, 2].
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use during generation.",
        ge=0.0,
        le=2.0,
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the OpenAI API."
    )
    max_retries: int = Field(
        default=3,
        description="The maximum number of API retries.",
        ge=0,
    )
    timeout: float = Field(
        default=60.0,
        description="The timeout, in seconds, for API requests.",
        ge=0,
    )
    # The default is None, so the annotation says so explicitly.
    default_headers: Optional[Dict[str, str]] = Field(
        default=None, description="The default headers for API requests."
    )
    reuse_client: bool = Field(
        default=True,
        description=(
            "Reuse the OpenAI client between requests. When doing anything with large "
            "volumes of async API calls, setting this to false can improve stability."
        ),
    )

    # ``exclude=True`` keeps the key out of serialized model output.
    api_key: str = Field(default=None, description="The OpenAI API key.", exclude=True)
    api_base: str = Field(description="The base URL for OpenAI API.")
    api_version: str = Field(description="The API version for OpenAI API.")

    # Cached client instances; populated on first use when ``reuse_client`` is set.
    _client: Optional[SyncOpenAI] = PrivateAttr()
    _aclient: Optional[AsyncOpenAI] = PrivateAttr()
    # Optional caller-supplied transport handed to the OpenAI clients.
    _http_client: Optional[httpx.Client] = PrivateAttr()
121

122
    def __init__(
        self,
        model: str = DEFAULT_OPENAI_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 3,
        timeout: float = 60.0,
        reuse_client: bool = True,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        default_headers: Optional[Dict[str, str]] = None,
        http_client: Optional[httpx.Client] = None,
        # base class
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> None:
        """Resolve credentials, initialise the model fields and reset clients.

        The ``api_key`` / ``api_base`` / ``api_version`` arguments are passed
        through ``resolve_openai_credentials`` before being stored. No client
        object is created here; ``http_client`` is only remembered for later.
        """
        if not additional_kwargs:
            additional_kwargs = {}

        resolved_key, resolved_base, resolved_version = resolve_openai_credentials(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
        )

        super().__init__(
            # generation settings
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            # transport settings
            max_retries=max_retries,
            callback_manager=callback_manager,
            api_key=resolved_key,
            api_version=resolved_version,
            api_base=resolved_base,
            timeout=timeout,
            reuse_client=reuse_client,
            default_headers=default_headers,
            # base-class settings
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            **kwargs,
        )

        # Clients start out unset; the accessor methods build them on demand.
        self._client = None
        self._aclient = None
        self._http_client = http_client
177

178
    def _get_client(self) -> SyncOpenAI:
        """Return a synchronous OpenAI client.

        With ``reuse_client`` enabled the client is created once and cached
        on the instance; otherwise a new client is built on every call.
        """
        if self.reuse_client:
            if self._client is None:
                self._client = SyncOpenAI(**self._get_credential_kwargs())
            return self._client

        return SyncOpenAI(**self._get_credential_kwargs())
185

186
    def _get_aclient(self) -> AsyncOpenAI:
        """Return an asynchronous OpenAI client.

        With ``reuse_client`` enabled the client is created once and cached
        on the instance; otherwise a new client is built on every call.
        """
        if self.reuse_client:
            if self._aclient is None:
                self._aclient = AsyncOpenAI(**self._get_credential_kwargs())
            return self._aclient

        return AsyncOpenAI(**self._get_credential_kwargs())
193

194
    def _get_model_name(self) -> str:
195
        model_name = self.model
196
        if "ft-" in model_name:  # legacy fine-tuning
197
            model_name = model_name.split(":")[0]
198
        elif model_name.startswith("ft:"):
199
            model_name = model_name.split(":")[1]
200
        return model_name
201

202
    def _is_azure_client(self) -> bool:
        """Tell whether the client from ``_get_client()`` is an ``AzureOpenAI``."""
        client = self._get_client()
        return isinstance(client, AzureOpenAI)
204

205
    @classmethod
206
    def class_name(cls) -> str:
207
        return "openai_llm"
208

209
    @property
    def _tokenizer(self) -> Optional[Tokenizer]:
        """
        Tokenizer matching this model, looked up through ``tiktoken``.

        The return type is optional so that subclasses without a known
        tokenizing method can return None; this implementation always
        delegates to ``tiktoken.encoding_for_model`` using the base model name.
        """
        base_model = self._get_model_name()
        return tiktoken.encoding_for_model(base_model)
218

219
    @property
    def metadata(self) -> LLMMetadata:
        """Describe the configured model.

        Context size and the chat / function-calling flags are derived from
        the base model name, while ``model_name`` reports ``self.model``
        verbatim. ``num_output`` is ``-1`` when ``max_tokens`` is unset.
        """
        base_model = self._get_model_name()
        context_window = openai_modelname_to_contextsize(base_model)
        chat_capable = is_chat_model(model=base_model)
        function_capable = is_function_calling_model(model=base_model)
        return LLMMetadata(
            context_window=context_window,
            num_output=self.max_tokens or -1,
            is_chat_model=chat_capable,
            is_function_calling_model=function_capable,
            model_name=self.model,
        )
230

231
    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Answer a chat conversation.

        Uses the chat endpoint when ``_use_chat_completions`` says so,
        otherwise adapts the completion endpoint to the chat interface.
        """
        use_chat = self._use_chat_completions(kwargs)
        chat_fn = (
            self._chat if use_chat else completion_to_chat_decorator(self._complete)
        )
        return chat_fn(messages, **kwargs)
238

239
    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream the answer to a chat conversation.

        Uses the streaming chat endpoint when ``_use_chat_completions`` says
        so, otherwise adapts the streaming completion endpoint.
        """
        use_chat = self._use_chat_completions(kwargs)
        stream_chat_fn = (
            self._stream_chat
            if use_chat
            else stream_completion_to_chat_decorator(self._stream_complete)
        )
        return stream_chat_fn(messages, **kwargs)
248

249
    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Complete a text prompt.

        Adapts the chat endpoint to the completion interface when
        ``_use_chat_completions`` says so, otherwise calls the completion
        endpoint directly. ``formatted`` is accepted but not used here.
        """
        use_chat = self._use_chat_completions(kwargs)
        complete_fn = (
            chat_to_completion_decorator(self._chat) if use_chat else self._complete
        )
        return complete_fn(prompt, **kwargs)
258

259
    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream the completion of a text prompt.

        Adapts the streaming chat endpoint when ``_use_chat_completions``
        says so, otherwise streams from the completion endpoint directly.
        ``formatted`` is accepted but not used here.
        """
        use_chat = self._use_chat_completions(kwargs)
        stream_complete_fn = (
            stream_chat_to_completion_decorator(self._stream_chat)
            if use_chat
            else self._stream_complete
        )
        return stream_complete_fn(prompt, **kwargs)
268

269
    def _use_chat_completions(self, kwargs: Dict[str, Any]) -> bool:
270
        if "use_chat_completions" in kwargs:
271
            return kwargs["use_chat_completions"]
272
        return self.metadata.is_chat_model
273

274
    def _get_credential_kwargs(self) -> Dict[str, Any]:
275
        return {
276
            "api_key": self.api_key,
277
            "base_url": self.api_base,
278
            "max_retries": self.max_retries,
279
            "timeout": self.timeout,
280
            "default_headers": self.default_headers,
281
            "http_client": self._http_client,
282
        }
283

284
    def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
285
        base_kwargs = {"model": self.model, "temperature": self.temperature, **kwargs}
286
        if self.max_tokens is not None:
287
            # If max_tokens is None, don't include in the payload:
288
            # https://platform.openai.com/docs/api-reference/chat
289
            # https://platform.openai.com/docs/api-reference/completions
290
            base_kwargs["max_tokens"] = self.max_tokens
291
        return {**base_kwargs, **self.additional_kwargs}
292

293
    def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Send one non-streaming chat request and wrap the first choice."""
        request_kwargs = self._get_model_kwargs(**kwargs)
        response = self._get_client().chat.completions.create(
            messages=to_openai_message_dicts(messages),
            stream=False,
            **request_kwargs,
        )
        # Only the first returned choice is used.
        message = from_openai_message(response.choices[0].message)
        token_counts = self._get_response_token_counts(response)
        return ChatResponse(
            message=message,
            raw=response,
            additional_kwargs=token_counts,
        )
309

310
    def _update_tool_calls(
311
        self,
312
        tool_calls: List[ChoiceDeltaToolCall],
313
        tool_calls_delta: Optional[List[ChoiceDeltaToolCall]],
314
    ) -> List[ChoiceDeltaToolCall]:
315
        """Use the tool_calls_delta objects received from openai stream chunks
316
        to update the running tool_calls object.
317

318
        Args:
319
            tool_calls (List[ChoiceDeltaToolCall]): the list of tool calls
320
            tool_calls_delta (ChoiceDeltaToolCall): the delta to update tool_calls
321

322
        Returns:
323
            List[ChoiceDeltaToolCall]: the updated tool calls
324
        """
325
        # openai provides chunks consisting of tool_call deltas one tool at a time
326
        if tool_calls_delta is None:
327
            return tool_calls
328

329
        tc_delta = tool_calls_delta[0]
330

331
        if len(tool_calls) == 0:
332
            tool_calls.append(tc_delta)
333
        else:
334
            # we need to either update latest tool_call or start a
335
            # new tool_call (i.e., multiple tools in this turn) and
336
            # accumulate that new tool_call with future delta chunks
337
            t = tool_calls[-1]
338
            if t.index != tc_delta.index:
339
                # the start of a new tool call, so append to our running tool_calls list
340
                tool_calls.append(tc_delta)
341
            else:
342
                # not the start of a new tool call, so update last item of tool_calls
343

344
                # validations to get passed by mypy
345
                assert t.function is not None
346
                assert tc_delta.function is not None
347
                assert t.function.arguments is not None
348
                assert t.function.name is not None
349
                assert t.id is not None
350

351
                t.function.arguments += tc_delta.function.arguments or ""
352
                t.function.name += tc_delta.function.name or ""
353
                t.id += tc_delta.id or ""
354
        return tool_calls
355

356
    def _stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat response, yielding the accumulated message per chunk.

        The request itself is issued lazily, when the returned generator is
        first iterated. Each yielded ``ChatResponse`` carries the full content
        so far, the newest content fragment as ``delta``, and the accumulated
        tool calls once a tool-call delta has been seen.
        """
        client = self._get_client()
        message_dicts = to_openai_message_dicts(messages)

        def gen() -> ChatResponseGen:
            accumulated = ""
            tool_calls: List[ChoiceDeltaToolCall] = []
            saw_tool_call = False

            stream = client.chat.completions.create(
                messages=message_dicts,
                stream=True,
                **self._get_model_kwargs(**kwargs),
            )
            for chunk in stream:
                chunk = cast(ChatCompletionChunk, chunk)
                if len(chunk.choices) > 0:
                    delta = chunk.choices[0].delta
                elif self._is_azure_client():
                    # chunks without choices are skipped for Azure clients
                    continue
                else:
                    delta = ChoiceDelta()

                # once a tool-call delta shows up, the flag stays set
                if delta.tool_calls:
                    saw_tool_call = True

                # fold this chunk's delta into the running state
                role = delta.role or MessageRole.ASSISTANT
                piece = delta.content or ""
                accumulated += piece

                message_kwargs = {}
                if saw_tool_call:
                    tool_calls = self._update_tool_calls(tool_calls, delta.tool_calls)
                    message_kwargs["tool_calls"] = tool_calls

                yield ChatResponse(
                    message=ChatMessage(
                        role=role,
                        content=accumulated,
                        additional_kwargs=message_kwargs,
                    ),
                    delta=piece,
                    raw=chunk,
                    additional_kwargs=self._get_response_token_counts(chunk),
                )

        return gen()
407

408
    def _complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Send one non-streaming completion request and wrap the first choice."""
        request_kwargs = self._get_model_kwargs(**kwargs)
        self._update_max_tokens(request_kwargs, prompt)

        response = self._get_client().completions.create(
            prompt=prompt,
            stream=False,
            **request_kwargs,
        )
        # Only the first returned choice is used.
        first_choice = response.choices[0]
        return CompletionResponse(
            text=first_choice.text,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )
424

425
    def _stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        """Stream a completion, yielding the accumulated text per chunk.

        The payload (including any inferred ``max_tokens``) is prepared
        eagerly; the request is issued when the generator is first iterated.
        """
        client = self._get_client()
        request_kwargs = self._get_model_kwargs(**kwargs)
        self._update_max_tokens(request_kwargs, prompt)

        def gen() -> CompletionResponseGen:
            accumulated = ""
            stream = client.completions.create(
                prompt=prompt,
                stream=True,
                **request_kwargs,
            )
            for chunk in stream:
                # a chunk without choices contributes an empty fragment
                piece = chunk.choices[0].text if len(chunk.choices) > 0 else ""
                accumulated += piece
                yield CompletionResponse(
                    delta=piece,
                    text=accumulated,
                    raw=chunk,
                    additional_kwargs=self._get_response_token_counts(chunk),
                )

        return gen()
450

451
    def _update_max_tokens(self, all_kwargs: Dict[str, Any], prompt: str) -> None:
452
        """Infer max_tokens for the payload, if possible."""
453
        if self.max_tokens is not None or self._tokenizer is None:
454
            return
455
        # NOTE: non-chat completion endpoint requires max_tokens to be set
456
        num_tokens = len(self._tokenizer.encode(prompt))
457
        max_tokens = self.metadata.context_window - num_tokens
458
        if max_tokens <= 0:
459
            raise ValueError(
460
                f"The prompt has {num_tokens} tokens, which is too long for"
461
                " the model. Please use a prompt that fits within"
462
                f" {self.metadata.context_window} tokens."
463
            )
464
        all_kwargs["max_tokens"] = max_tokens
465

466
    def _get_response_token_counts(self, raw_response: Any) -> dict:
467
        """Get the token usage reported by the response."""
468
        if not isinstance(raw_response, dict):
469
            return {}
470

471
        usage = raw_response.get("usage", {})
472
        # NOTE: other model providers that use the OpenAI client may not report usage
473
        if usage is None:
474
            return {}
475

476
        return {
477
            "prompt_tokens": usage.get("prompt_tokens", 0),
478
            "completion_tokens": usage.get("completion_tokens", 0),
479
            "total_tokens": usage.get("total_tokens", 0),
480
        }
481

482
    # ===== Async Endpoints =====
483
    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Async variant of ``chat``: same endpoint choice, awaited result."""
        use_chat = self._use_chat_completions(kwargs)
        achat_fn: Callable[..., Awaitable[ChatResponse]] = (
            self._achat if use_chat else acompletion_to_chat_decorator(self._acomplete)
        )
        return await achat_fn(messages, **kwargs)
495

496
    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        """Async variant of ``stream_chat``; awaiting it yields an async generator."""
        use_chat = self._use_chat_completions(kwargs)
        astream_chat_fn: Callable[..., Awaitable[ChatResponseAsyncGen]] = (
            self._astream_chat
            if use_chat
            else astream_completion_to_chat_decorator(self._astream_complete)
        )
        return await astream_chat_fn(messages, **kwargs)
510

511
    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async variant of ``complete``. ``formatted`` is accepted but unused here."""
        use_chat = self._use_chat_completions(kwargs)
        acomplete_fn = (
            achat_to_completion_decorator(self._achat) if use_chat else self._acomplete
        )
        return await acomplete_fn(prompt, **kwargs)
520

521
    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async variant of ``stream_complete``. ``formatted`` is accepted but unused here."""
        use_chat = self._use_chat_completions(kwargs)
        astream_complete_fn = (
            astream_chat_to_completion_decorator(self._astream_chat)
            if use_chat
            else self._astream_complete
        )
        return await astream_complete_fn(prompt, **kwargs)
532

533
    async def _achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Send one non-streaming chat request on the async client."""
        request_kwargs = self._get_model_kwargs(**kwargs)
        response = await self._get_aclient().chat.completions.create(
            messages=to_openai_message_dicts(messages),
            stream=False,
            **request_kwargs,
        )
        # Only the first returned choice is used.
        message = from_openai_message(response.choices[0].message)
        token_counts = self._get_response_token_counts(response)
        return ChatResponse(
            message=message,
            raw=response,
            additional_kwargs=token_counts,
        )
549

550
    async def _astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Stream a chat response on the async client.

        Returns an async generator; the request is issued when it is first
        iterated. Each yielded ``ChatResponse`` carries the full content so
        far, the newest fragment as ``delta``, and the accumulated tool calls
        once a tool-call delta has been seen. A leading chunk that has
        neither content nor tool calls is dropped.
        """
        aclient = self._get_aclient()
        message_dicts = to_openai_message_dicts(messages)

        async def gen() -> ChatResponseAsyncGen:
            accumulated = ""
            tool_calls: List[ChoiceDeltaToolCall] = []
            saw_tool_call = False
            first_chat_chunk = True

            stream = await aclient.chat.completions.create(
                messages=message_dicts,
                stream=True,
                **self._get_model_kwargs(**kwargs),
            )
            async for chunk in stream:
                chunk = cast(ChatCompletionChunk, chunk)
                if len(chunk.choices) > 0:
                    delta = chunk.choices[0].delta
                    # check if the first chunk has neither content nor tool_calls
                    # this happens when 1106 models end up calling multiple tools
                    if (
                        first_chat_chunk
                        and delta.content is None
                        and delta.tool_calls is None
                    ):
                        first_chat_chunk = False
                        continue
                elif self._is_azure_client():
                    # chunks without choices are skipped for Azure clients
                    continue
                else:
                    delta = ChoiceDelta()
                first_chat_chunk = False

                # once a tool-call delta shows up, the flag stays set
                if delta.tool_calls:
                    saw_tool_call = True

                # fold this chunk's delta into the running state
                role = delta.role or MessageRole.ASSISTANT
                piece = delta.content or ""
                accumulated += piece

                message_kwargs = {}
                if saw_tool_call:
                    tool_calls = self._update_tool_calls(tool_calls, delta.tool_calls)
                    message_kwargs["tool_calls"] = tool_calls

                yield ChatResponse(
                    message=ChatMessage(
                        role=role,
                        content=accumulated,
                        additional_kwargs=message_kwargs,
                    ),
                    delta=piece,
                    raw=chunk,
                    additional_kwargs=self._get_response_token_counts(chunk),
                )

        return gen()
612

613
    async def _acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Send one non-streaming completion request on the async client."""
        request_kwargs = self._get_model_kwargs(**kwargs)
        self._update_max_tokens(request_kwargs, prompt)

        response = await self._get_aclient().completions.create(
            prompt=prompt,
            stream=False,
            **request_kwargs,
        )
        # Only the first returned choice is used.
        first_choice = response.choices[0]
        return CompletionResponse(
            text=first_choice.text,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )
629

630
    async def _astream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Stream a completion on the async client.

        The payload (including any inferred ``max_tokens``) is prepared
        eagerly; the request is issued when the returned async generator is
        first iterated.
        """
        aclient = self._get_aclient()
        request_kwargs = self._get_model_kwargs(**kwargs)
        self._update_max_tokens(request_kwargs, prompt)

        async def gen() -> CompletionResponseAsyncGen:
            accumulated = ""
            stream = await aclient.completions.create(
                prompt=prompt,
                stream=True,
                **request_kwargs,
            )
            async for chunk in stream:
                # a chunk without choices contributes an empty fragment
                piece = chunk.choices[0].text if len(chunk.choices) > 0 else ""
                accumulated += piece
                yield CompletionResponse(
                    delta=piece,
                    text=accumulated,
                    raw=chunk,
                    additional_kwargs=self._get_response_token_counts(chunk),
                )

        return gen()
657

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.