# llama-index
#
# Fork: 0
# 347 lines · 11.4 KB
import warnings
from typing import Any, Callable, Dict, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.cohere_utils import (
    CHAT_MODELS,
    acompletion_with_retry,
    cohere_modelname_to_contextsize,
    completion_with_retry,
    messages_to_cohere_history,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode


class Cohere(LLM):
    """LLM integration for the Cohere API.

    Synchronous methods go through ``cohere.Client`` and async methods through
    ``cohere.AsyncClient``. Chat-style methods send the last message as the
    prompt and the earlier messages as Cohere chat history, and require the
    resolved model to be listed in ``CHAT_MODELS``. Completion-style methods
    use the non-chat path of the retry helpers.
    """

    model: str = Field(description="The cohere model to use.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_retries: int = Field(
        default=10, description="The maximum number of API retries."
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Cohere API."
    )
    max_tokens: int = Field(description="The maximum number of tokens to generate.")

    # Cohere SDK clients, created in ``__init__``.
    _client: Any = PrivateAttr()
    _aclient: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "command",
        temperature: float = 0.5,
        max_tokens: int = 512,
        timeout: Optional[float] = None,
        max_retries: int = 10,
        api_key: Optional[str] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Create the Cohere clients and initialize the model fields.

        Args:
            model: Cohere model name used for every request.
            temperature: Sampling temperature used for every request.
            max_tokens: Maximum number of tokens to generate per request.
            timeout: Passed to the base ``LLM`` constructor only.
            max_retries: Retry budget handed to the retry helpers.
            api_key: Cohere API key given to both the sync and async client.
            additional_kwargs: Extra request kwargs merged into every call.
            callback_manager: Callback manager; an empty one is used if None.
            system_prompt: Forwarded to the base ``LLM``.
            messages_to_prompt: Forwarded to the base ``LLM``.
            completion_to_prompt: Forwarded to the base ``LLM``.
            pydantic_program_mode: Forwarded to the base ``LLM``.
            output_parser: Forwarded to the base ``LLM``.

        Raises:
            ImportError: If the ``cohere`` package is not installed.
        """
        try:
            import cohere
        except ImportError as e:
            raise ImportError(
                "You must install the `cohere` package to use Cohere. "
                "Please `pip install cohere`"
            ) from e
        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])

        self._client = cohere.Client(api_key, client_name="llama_index")
        self._aclient = cohere.AsyncClient(api_key, client_name="llama_index")

        # NOTE(review): ``timeout`` is not a declared field on this class and is
        # not passed to the Cohere clients; confirm whether it has any effect.
        super().__init__(
            temperature=temperature,
            additional_kwargs=additional_kwargs,
            timeout=timeout,
            max_retries=max_retries,
            model=model,
            callback_manager=callback_manager,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "Cohere_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Describe the model: context window, output budget, and chat role."""
        return LLMMetadata(
            context_window=cohere_modelname_to_contextsize(self.model),
            num_output=self.max_tokens,
            is_chat_model=True,
            model_name=self.model,
            system_role=MessageRole.CHATBOT,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Default request kwargs; entries in ``additional_kwargs`` win."""
        base_kwargs = {
            "model": self.model,
            "temperature": self.temperature,
            # Send the configured limit so generation honours the same budget
            # that ``metadata.num_output`` advertises.
            "max_tokens": self.max_tokens,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge per-call ``kwargs`` over the default request kwargs."""
        return {
            **self._model_kwargs,
            **kwargs,
        }

    def _discard_stream_kwarg(
        self, all_kwargs: Dict[str, Any], method_name: str, stream_method_name: str
    ) -> None:
        """Warn about and remove a ``stream`` kwarg passed to a non-streaming method.

        Forwarding ``stream`` would make the request stream, while the calling
        method parses a single non-streamed response. The kwarg is therefore
        dropped after warning. ``all_kwargs`` is a fresh dict built by
        ``_get_all_kwargs``, so this does not touch ``additional_kwargs``.
        """
        if "stream" in all_kwargs:
            warnings.warn(
                f"Parameter `stream` is not supported by the `{method_name}` method. "
                f"Use the `{stream_method_name}` method instead"
            )
            all_kwargs.pop("stream")

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Send one non-streaming chat request.

        Raises:
            ValueError: If the resolved model is not in ``CHAT_MODELS``.
        """
        history = messages_to_cohere_history(messages[:-1])
        prompt = messages[-1].content
        all_kwargs = self._get_all_kwargs(**kwargs)
        if all_kwargs["model"] not in CHAT_MODELS:
            raise ValueError(f"{all_kwargs['model']} not supported for chat")
        self._discard_stream_kwarg(all_kwargs, "chat", "stream_chat")

        response = completion_with_retry(
            client=self._client,
            max_retries=self.max_retries,
            chat=True,
            message=prompt,
            chat_history=history,
            **all_kwargs,
        )
        return ChatResponse(
            message=ChatMessage(role=MessageRole.ASSISTANT, content=response.text),
            raw=response.__dict__,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Send one non-streaming completion request.

        ``formatted`` is accepted for interface compatibility and is not used
        here. The text of the first generation is returned.
        """
        all_kwargs = self._get_all_kwargs(**kwargs)
        self._discard_stream_kwarg(all_kwargs, "complete", "stream_complete")

        response = completion_with_retry(
            client=self._client,
            max_retries=self.max_retries,
            chat=False,
            prompt=prompt,
            **all_kwargs,
        )

        return CompletionResponse(
            text=response.generations[0].text,
            raw=response.__dict__,
        )

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat response; each item carries the accumulated text.

        Raises:
            ValueError: If the resolved model is not in ``CHAT_MODELS``.
        """
        history = messages_to_cohere_history(messages[:-1])
        prompt = messages[-1].content
        all_kwargs = self._get_all_kwargs(**kwargs)
        all_kwargs["stream"] = True
        if all_kwargs["model"] not in CHAT_MODELS:
            raise ValueError(f"{all_kwargs['model']} not supported for chat")
        response = completion_with_retry(
            client=self._client,
            max_retries=self.max_retries,
            chat=True,
            message=prompt,
            chat_history=history,
            **all_kwargs,
        )

        def gen() -> ChatResponseGen:
            content = ""
            role = MessageRole.ASSISTANT
            for r in response:
                # Events without a ``text`` attribute contribute an empty delta.
                content_delta = r.text if "text" in r.__dict__ else ""
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=r.__dict__,
                )

        return gen()

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream a completion; each item carries the accumulated text.

        ``formatted`` is accepted for interface compatibility and is not used.
        """
        all_kwargs = self._get_all_kwargs(**kwargs)
        all_kwargs["stream"] = True

        response = completion_with_retry(
            client=self._client,
            max_retries=self.max_retries,
            chat=False,
            prompt=prompt,
            **all_kwargs,
        )

        def gen() -> CompletionResponseGen:
            content = ""
            for r in response:
                content_delta = r.text
                content += content_delta
                # Completion stream chunks are serialized via ``_asdict()``,
                # unlike chat events which use ``__dict__``.
                yield CompletionResponse(
                    text=content, delta=content_delta, raw=r._asdict()
                )

        return gen()

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Async counterpart of ``chat``.

        Raises:
            ValueError: If the resolved model is not in ``CHAT_MODELS``.
        """
        history = messages_to_cohere_history(messages[:-1])
        prompt = messages[-1].content
        all_kwargs = self._get_all_kwargs(**kwargs)
        if all_kwargs["model"] not in CHAT_MODELS:
            raise ValueError(f"{all_kwargs['model']} not supported for chat")
        self._discard_stream_kwarg(all_kwargs, "achat", "astream_chat")

        response = await acompletion_with_retry(
            aclient=self._aclient,
            max_retries=self.max_retries,
            chat=True,
            message=prompt,
            chat_history=history,
            **all_kwargs,
        )

        return ChatResponse(
            message=ChatMessage(role=MessageRole.ASSISTANT, content=response.text),
            raw=response.__dict__,
        )

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async counterpart of ``complete``.

        ``formatted`` is accepted for interface compatibility and is not used.
        """
        all_kwargs = self._get_all_kwargs(**kwargs)
        self._discard_stream_kwarg(all_kwargs, "acomplete", "astream_complete")

        response = await acompletion_with_retry(
            aclient=self._aclient,
            max_retries=self.max_retries,
            chat=False,
            prompt=prompt,
            **all_kwargs,
        )

        return CompletionResponse(
            text=response.generations[0].text,
            raw=response.__dict__,
        )

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Async counterpart of ``stream_chat``.

        Raises:
            ValueError: If the resolved model is not in ``CHAT_MODELS``.
        """
        history = messages_to_cohere_history(messages[:-1])
        prompt = messages[-1].content
        all_kwargs = self._get_all_kwargs(**kwargs)
        all_kwargs["stream"] = True
        if all_kwargs["model"] not in CHAT_MODELS:
            raise ValueError(f"{all_kwargs['model']} not supported for chat")
        response = await acompletion_with_retry(
            aclient=self._aclient,
            max_retries=self.max_retries,
            chat=True,
            message=prompt,
            chat_history=history,
            **all_kwargs,
        )

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            role = MessageRole.ASSISTANT
            async for r in response:
                # Events without a ``text`` attribute contribute an empty delta.
                content_delta = r.text if "text" in r.__dict__ else ""
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=r.__dict__,
                )

        return gen()

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async counterpart of ``stream_complete``.

        ``formatted`` is accepted for interface compatibility and is not used.
        """
        all_kwargs = self._get_all_kwargs(**kwargs)
        all_kwargs["stream"] = True

        response = await acompletion_with_retry(
            aclient=self._aclient,
            max_retries=self.max_retries,
            chat=False,
            prompt=prompt,
            **all_kwargs,
        )

        async def gen() -> CompletionResponseAsyncGen:
            content = ""
            async for r in response:
                content_delta = r.text
                content += content_delta
                yield CompletionResponse(
                    text=content, delta=content_delta, raw=r._asdict()
                )

        return gen()

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept" («Принимаю»), you give JSC "SberTech" (АО «СберТех»)
# consent to the processing of your personal data for the purpose of improving
# our website and the GitVerse Service, and of making them more convenient to use.
#
# You can disable the use of cookies yourself in your browser settings.