llama-index

Форк
0
168 строк · 6.0 Кб
1
from functools import lru_cache
from typing import Any, Optional, Sequence, Union
2

3
from llama_index.legacy.bridge.pydantic import Field
4
from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW
5
from llama_index.legacy.llms.generic_utils import (
6
    async_stream_completion_response_to_chat_response,
7
    completion_response_to_chat_response,
8
    stream_completion_response_to_chat_response,
9
)
10
from llama_index.legacy.llms.openai import OpenAI, Tokenizer
11
from llama_index.legacy.llms.types import (
12
    ChatMessage,
13
    ChatResponse,
14
    ChatResponseAsyncGen,
15
    ChatResponseGen,
16
    CompletionResponse,
17
    CompletionResponseAsyncGen,
18
    CompletionResponseGen,
19
    LLMMetadata,
20
)
21

22

23
class OpenAILike(OpenAI):
    """
    OpenAILike is a thin wrapper around the OpenAI model that makes it compatible with
    3rd party tools that provide an openai-compatible api.

    Currently, llama_index prevents using custom models with their OpenAI class
    because they need to be able to infer some metadata from the model name.
    This class instead takes that metadata (context window, chat/function-calling
    support) from explicit constructor fields.

    NOTE: You still need to set the OPENAI_API_BASE and OPENAI_API_KEY environment
    variables or the api_key and api_base constructor arguments.
    OPENAI_API_KEY/api_key can normally be set to anything in this case,
    but will depend on the tool you're using.
    """

    # Explicit metadata; reported verbatim by `metadata` instead of being
    # inferred from the model name.
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=LLMMetadata.__fields__["context_window"].field_info.description,
    )
    # When False, chat calls are routed through the completion API after
    # rendering the messages with `messages_to_prompt`.
    is_chat_model: bool = Field(
        default=False,
        description=LLMMetadata.__fields__["is_chat_model"].field_info.description,
    )
    is_function_calling_model: bool = Field(
        default=False,
        description=LLMMetadata.__fields__[
            "is_function_calling_model"
        ].field_info.description,
    )
    # Either an object exposing `encode`, a Hugging Face model name, or None.
    tokenizer: Union[Tokenizer, str, None] = Field(
        default=None,
        description=(
            "An instance of a tokenizer object that has an encode method, or the name"
            " of a tokenizer model from Hugging Face. If left as None, then this"
            " disables inference of max_tokens."
        ),
    )

    @property
    def metadata(self) -> LLMMetadata:
        """Return LLM metadata built from this instance's explicit fields."""
        return LLMMetadata(
            context_window=self.context_window,
            # -1 signals "unset" when max_tokens is None (or otherwise falsy).
            num_output=self.max_tokens or -1,
            is_chat_model=self.is_chat_model,
            is_function_calling_model=self.is_function_calling_model,
            model_name=self.model,
        )

    @property
    def _tokenizer(self) -> Optional[Tokenizer]:
        """Return the tokenizer object, or None if none is configured.

        A string ``tokenizer`` is treated as a Hugging Face model name. It is
        loaded through ``_load_hf_tokenizer``, which memoises the result, so
        repeated reads of this property do not reload the tokenizer each time.
        """
        if isinstance(self.tokenizer, str):
            return self._load_hf_tokenizer(self.tokenizer)
        return self.tokenizer

    @staticmethod
    @lru_cache(maxsize=8)
    def _load_hf_tokenizer(tokenizer_name: str) -> Tokenizer:
        """Load a Hugging Face tokenizer by name, caching it per name.

        The cache is bounded and shared by all instances. A failed load raises
        and is not cached, so a later call retries.

        Raises:
            ImportError: if the ``transformers`` package is not installed.
        """
        try:
            from transformers import AutoTokenizer
        except ImportError as exc:
            raise ImportError(
                "Please install transformers (pip install transformers) to use "
                "huggingface tokenizers with OpenAILike."
            ) from exc

        return AutoTokenizer.from_pretrained(tokenizer_name)

    @classmethod
    def class_name(cls) -> str:
        return "OpenAILike"

    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Complete the prompt.

        Unless ``formatted`` is True, the prompt is first passed through
        ``completion_to_prompt``.
        """
        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        return super().complete(prompt, **kwargs)

    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream complete the prompt.

        Unless ``formatted`` is True, the prompt is first passed through
        ``completion_to_prompt``.
        """
        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        return super().stream_complete(prompt, **kwargs)

    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Chat with the model.

        For non-chat models, the messages are rendered to a single prompt and
        sent through ``complete``; the result is wrapped as a chat response.
        """
        if not self.metadata.is_chat_model:
            prompt = self.messages_to_prompt(messages)
            # formatted=True: messages_to_prompt already produced the final text.
            completion_response = self.complete(prompt, formatted=True, **kwargs)
            return completion_response_to_chat_response(completion_response)

        return super().chat(messages, **kwargs)

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat response.

        For non-chat models, the messages are rendered to a single prompt and
        streamed through ``stream_complete``; each chunk is converted to a chat
        response.
        """
        if not self.metadata.is_chat_model:
            prompt = self.messages_to_prompt(messages)
            completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
            return stream_completion_response_to_chat_response(completion_response)

        return super().stream_chat(messages, **kwargs)

    # -- Async methods --

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Asynchronously complete the prompt.

        Unless ``formatted`` is True, the prompt is first passed through
        ``completion_to_prompt``.
        """
        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        return await super().acomplete(prompt, **kwargs)

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Asynchronously stream complete the prompt.

        Unless ``formatted`` is True, the prompt is first passed through
        ``completion_to_prompt``.
        """
        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        return await super().astream_complete(prompt, **kwargs)

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Asynchronously chat with the model.

        For non-chat models, the messages are rendered to a single prompt and
        sent through ``acomplete``; the result is wrapped as a chat response.
        """
        if not self.metadata.is_chat_model:
            prompt = self.messages_to_prompt(messages)
            completion_response = await self.acomplete(prompt, formatted=True, **kwargs)
            return completion_response_to_chat_response(completion_response)

        return await super().achat(messages, **kwargs)

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Asynchronously stream a chat response.

        For non-chat models, the messages are rendered to a single prompt and
        streamed through ``astream_complete``; each chunk is converted to a chat
        response.
        """
        if not self.metadata.is_chat_model:
            prompt = self.messages_to_prompt(messages)
            completion_response = await self.astream_complete(
                prompt, formatted=True, **kwargs
            )
            return async_stream_completion_response_to_chat_response(
                completion_response
            )

        return await super().astream_chat(messages, **kwargs)
169

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.