llama-index

Форк
0
304 строки · 10.4 Кб
1
from typing import Any, Callable, Dict, Optional, Sequence
2

3
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
4
from llama_index.legacy.callbacks import CallbackManager
5
from llama_index.legacy.constants import DEFAULT_TEMPERATURE
6

7
# from mistralai.models.chat_completion import ChatMessage
8
from llama_index.legacy.core.llms.types import (
9
    ChatMessage,
10
    ChatResponse,
11
    ChatResponseAsyncGen,
12
    ChatResponseGen,
13
    CompletionResponse,
14
    CompletionResponseAsyncGen,
15
    CompletionResponseGen,
16
    LLMMetadata,
17
    MessageRole,
18
)
19
from llama_index.legacy.llms.base import (
20
    llm_chat_callback,
21
    llm_completion_callback,
22
)
23
from llama_index.legacy.llms.generic_utils import (
24
    achat_to_completion_decorator,
25
    astream_chat_to_completion_decorator,
26
    chat_to_completion_decorator,
27
    get_from_param_or_env,
28
    stream_chat_to_completion_decorator,
29
)
30
from llama_index.legacy.llms.llm import LLM
31
from llama_index.legacy.llms.mistralai_utils import (
32
    mistralai_modelname_to_contextsize,
33
)
34
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
35

36
# Model name used by MistralAI when the caller does not pass `model`.
DEFAULT_MISTRALAI_MODEL: str = "mistral-tiny"
# Endpoint handed to both the sync and async Mistral clients.
DEFAULT_MISTRALAI_ENDPOINT: str = "https://api.mistral.ai"
# Default cap on the number of generated tokens.
DEFAULT_MISTRALAI_MAX_TOKENS: int = 512
39

40

41
class MistralAI(LLM):
    """LLM wrapper around the `mistralai` chat clients.

    The class holds a synchronous and an asynchronous Mistral client and
    exposes chat / completion calls (blocking, streaming, and their async
    variants). Completion-style calls are implemented on top of the chat
    calls via the `*_to_completion_decorator` helpers.
    """

    model: str = Field(
        default=DEFAULT_MISTRALAI_MODEL, description="The mistralai model to use."
    )
    # NOTE: pydantic's bound keywords are `ge`/`le`; the previously used
    # `gte`/`lte` are not constraint names, so the bounds were never enforced.
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use for sampling.",
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_MISTRALAI_MAX_TOKENS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )

    timeout: float = Field(
        default=120, description="The timeout to use in seconds.", ge=0
    )
    max_retries: int = Field(
        default=5, description="The maximum number of API retries.", ge=0
    )
    safe_mode: bool = Field(
        default=False,
        description="The parameter to enforce guardrails in chat generations.",
    )
    # The constructor accepts an optional integer seed; the field type must
    # match so the value is not coerced to a string before being sent.
    random_seed: Optional[int] = Field(
        default=None, description="The random seed to use for sampling."
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the MistralAI API."
    )

    # Synchronous and asynchronous mistralai clients, created in __init__.
    _client: Any = PrivateAttr()
    _aclient: Any = PrivateAttr()

    def __init__(
        self,
        model: str = DEFAULT_MISTRALAI_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MISTRALAI_MAX_TOKENS,
        timeout: float = 120,
        max_retries: int = 5,
        safe_mode: bool = False,
        random_seed: Optional[int] = None,
        api_key: Optional[str] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Create the Mistral clients and initialise the model fields.

        Args:
            model: Name of the Mistral model to call.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            timeout: Timeout, in seconds, handed to both clients.
            max_retries: Maximum number of API retries handed to both clients.
            safe_mode: Forwarded to the API as `safe_mode`.
            random_seed: Optional seed forwarded to the API as `random_seed`.
            api_key: API key; when omitted, the `MISTRAL_API_KEY` environment
                variable is consulted via `get_from_param_or_env`.
            additional_kwargs: Extra keyword arguments merged into every
                request (they override the base model kwargs).
            callback_manager: Callback manager; a fresh empty one is created
                when omitted.
            system_prompt: Passed through to the base `LLM`.
            messages_to_prompt: Passed through to the base `LLM`.
            completion_to_prompt: Passed through to the base `LLM`.
            pydantic_program_mode: Passed through to the base `LLM`.
            output_parser: Passed through to the base `LLM`.

        Raises:
            ImportError: If the `mistralai` package is not installed.
            ValueError: If no API key is supplied or found in the environment.
        """
        try:
            from mistralai.async_client import MistralAsyncClient
            from mistralai.client import MistralClient
        except ImportError as e:
            raise ImportError(
                "You must install the `mistralai` package to use mistralai. "
                "Please `pip install mistralai`"
            ) from e

        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])

        api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "")

        if not api_key:
            raise ValueError(
                "You must provide an API key to use mistralai. "
                "You can either pass it in as an argument or set it `MISTRAL_API_KEY`."
            )

        # Both clients share the same connection settings.
        client_kwargs: Dict[str, Any] = {
            "api_key": api_key,
            "endpoint": DEFAULT_MISTRALAI_ENDPOINT,
            "timeout": timeout,
            "max_retries": max_retries,
        }
        self._client = MistralClient(**client_kwargs)
        self._aclient = MistralAsyncClient(**client_kwargs)

        super().__init__(
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            timeout=timeout,
            max_retries=max_retries,
            safe_mode=safe_mode,
            random_seed=random_seed,
            model=model,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this LLM class."""
        return "MistralAI_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Describe this model as an `LLMMetadata` record.

        The context window is looked up from the model name with
        `mistralai_modelname_to_contextsize`; the output size is `max_tokens`.
        """
        return LLMMetadata(
            context_window=mistralai_modelname_to_contextsize(self.model),
            num_output=self.max_tokens,
            is_chat_model=True,
            model_name=self.model,
            safe_mode=self.safe_mode,
            random_seed=self.random_seed,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Request kwargs derived from the fields, overridden by `additional_kwargs`."""
        base_kwargs = {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "random_seed": self.random_seed,
            "safe_mode": self.safe_mode,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge the model kwargs with per-call `kwargs` (per-call values win)."""
        return {
            **self._model_kwargs,
            **kwargs,
        }

    def _to_mistral_messages(self, messages: Sequence[ChatMessage]) -> list:
        """Convert llama-index chat messages into mistralai `ChatMessage` objects."""
        from mistralai.client import ChatMessage as mistral_chatmessage

        return [
            mistral_chatmessage(role=message.role, content=message.content)
            for message in messages
        ]

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Send `messages` to the sync client and return the first choice.

        The returned message always carries the assistant role; the client
        response, converted with `dict(...)`, is attached as `raw`.
        """
        all_kwargs = self._get_all_kwargs(**kwargs)
        response = self._client.chat(
            messages=self._to_mistral_messages(messages), **all_kwargs
        )
        return ChatResponse(
            message=ChatMessage(
                role=MessageRole.ASSISTANT, content=response.choices[0].message.content
            ),
            raw=dict(response),
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Complete `prompt` by routing it through `chat`.

        `formatted` is accepted for interface compatibility but is not used
        by this method.
        """
        complete_fn = chat_to_completion_decorator(self.chat)
        return complete_fn(prompt, **kwargs)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat response from the sync client.

        Each yielded `ChatResponse` contains the content accumulated so far,
        the latest `delta`, and the raw chunk. Chunks whose delta content is
        `None` are skipped.
        """
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = self._client.chat_stream(
            messages=self._to_mistral_messages(messages), **all_kwargs
        )

        def gen() -> ChatResponseGen:
            content = ""
            role = MessageRole.ASSISTANT
            for chunk in response:
                content_delta = chunk.choices[0].delta.content
                if content_delta is None:
                    continue
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=chunk,
                )

        return gen()

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream a completion for `prompt` by routing it through `stream_chat`.

        `formatted` is accepted for interface compatibility but is not used
        by this method.
        """
        stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat)
        return stream_complete_fn(prompt, **kwargs)

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Async counterpart of `chat`, using the async client."""
        all_kwargs = self._get_all_kwargs(**kwargs)
        response = await self._aclient.chat(
            messages=self._to_mistral_messages(messages), **all_kwargs
        )
        return ChatResponse(
            message=ChatMessage(
                role=MessageRole.ASSISTANT, content=response.choices[0].message.content
            ),
            raw=dict(response),
        )

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async counterpart of `complete`, routed through `achat`.

        `formatted` is accepted for interface compatibility but is not used
        by this method.
        """
        acomplete_fn = achat_to_completion_decorator(self.achat)
        return await acomplete_fn(prompt, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Async counterpart of `stream_chat`, using the async client.

        Returns an async generator that yields a `ChatResponse` per chunk with
        the accumulated content, the latest `delta`, and the raw chunk. Chunks
        whose delta content is `None` are skipped.
        """
        import inspect

        all_kwargs = self._get_all_kwargs(**kwargs)

        response = self._aclient.chat_stream(
            messages=self._to_mistral_messages(messages), **all_kwargs
        )
        # NOTE(review): this call used to be awaited unconditionally. If the
        # client's `chat_stream` is an async generator function (it appears to
        # be in the mistralai client - confirm), its result is not awaitable
        # and `await` raises TypeError. Await only when it is awaitable so
        # both client shapes work.
        if inspect.isawaitable(response):
            response = await response

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            role = MessageRole.ASSISTANT
            async for chunk in response:
                content_delta = chunk.choices[0].delta.content
                if content_delta is None:
                    continue
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=chunk,
                )

        return gen()

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async counterpart of `stream_complete`, routed through `astream_chat`.

        `formatted` is accepted for interface compatibility but is not used
        by this method.
        """
        astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat)
        return await astream_complete_fn(prompt, **kwargs)
305

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.