llama-index
347 lines · 11.4 KB
1import warnings2from typing import Any, Callable, Dict, Optional, Sequence3
4from llama_index.legacy.bridge.pydantic import Field, PrivateAttr5from llama_index.legacy.callbacks import CallbackManager6from llama_index.legacy.core.llms.types import (7ChatMessage,8ChatResponse,9ChatResponseAsyncGen,10ChatResponseGen,11CompletionResponse,12CompletionResponseAsyncGen,13CompletionResponseGen,14LLMMetadata,15MessageRole,16)
17from llama_index.legacy.llms.base import (18llm_chat_callback,19llm_completion_callback,20)
21from llama_index.legacy.llms.cohere_utils import (22CHAT_MODELS,23acompletion_with_retry,24cohere_modelname_to_contextsize,25completion_with_retry,26messages_to_cohere_history,27)
28from llama_index.legacy.llms.llm import LLM29from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode30
31
class Cohere(LLM):
    """Cohere LLM.

    Wraps Cohere's completion (``generate``) and ``chat`` endpoints behind the
    llama-index ``LLM`` interface, providing synchronous, asynchronous, and
    streaming variants of each. The last message of a chat sequence is sent as
    the prompt; all preceding messages are converted to Cohere chat history.
    """

    model: str = Field(description="The cohere model to use.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_retries: int = Field(
        default=10, description="The maximum number of API retries."
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Cohere API."
    )
    max_tokens: int = Field(description="The maximum number of tokens to generate.")

    # Cohere SDK clients (sync and async), constructed in __init__.
    _client: Any = PrivateAttr()
    _aclient: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "command",
        temperature: float = 0.5,
        max_tokens: int = 512,
        timeout: Optional[float] = None,
        max_retries: int = 10,
        api_key: Optional[str] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Initialize the Cohere LLM and its sync/async SDK clients.

        Raises:
            ImportError: If the ``cohere`` package is not installed.
        """
        try:
            import cohere
        except ImportError as e:
            raise ImportError(
                "You must install the `cohere` package to use Cohere. "
                "Please `pip install cohere`"
            ) from e
        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])

        # Clients must exist before super().__init__ so pydantic validation
        # of the model fields does not race against private-attr access.
        self._client = cohere.Client(api_key, client_name="llama_index")
        self._aclient = cohere.AsyncClient(api_key, client_name="llama_index")

        super().__init__(
            temperature=temperature,
            additional_kwargs=additional_kwargs,
            timeout=timeout,
            max_retries=max_retries,
            model=model,
            callback_manager=callback_manager,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "Cohere_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata: context window, output budget, and chat capability."""
        return LLMMetadata(
            context_window=cohere_modelname_to_contextsize(self.model),
            num_output=self.max_tokens,
            is_chat_model=True,
            model_name=self.model,
            system_role=MessageRole.CHATBOT,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Base API kwargs; `additional_kwargs` override the defaults."""
        base_kwargs = {
            "model": self.model,
            "temperature": self.temperature,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge model kwargs with per-call kwargs (per-call wins)."""
        return {
            **self._model_kwargs,
            **kwargs,
        }

    @staticmethod
    def _ensure_chat_model(all_kwargs: Dict[str, Any]) -> None:
        """Raise ValueError if the resolved model cannot use the chat endpoint."""
        if all_kwargs["model"] not in CHAT_MODELS:
            raise ValueError(f"{all_kwargs['model']} not supported for chat")

    @staticmethod
    def _warn_if_stream_kwarg(
        all_kwargs: Dict[str, Any], method: str, alternative: str
    ) -> None:
        """Warn when `stream` is passed to a non-streaming method."""
        if "stream" in all_kwargs:
            warnings.warn(
                f"Parameter `stream` is not supported by the `{method}` method. "
                f"Use the `{alternative}` method instead."
            )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Send a chat request; last message is the prompt, the rest are history."""
        history = messages_to_cohere_history(messages[:-1])
        prompt = messages[-1].content
        all_kwargs = self._get_all_kwargs(**kwargs)
        self._ensure_chat_model(all_kwargs)
        self._warn_if_stream_kwarg(all_kwargs, "chat", "stream_chat")
        response = completion_with_retry(
            client=self._client,
            max_retries=self.max_retries,
            chat=True,
            message=prompt,
            chat_history=history,
            **all_kwargs,
        )
        return ChatResponse(
            message=ChatMessage(role=MessageRole.ASSISTANT, content=response.text),
            raw=response.__dict__,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Complete a raw prompt via the Cohere generate endpoint."""
        all_kwargs = self._get_all_kwargs(**kwargs)
        self._warn_if_stream_kwarg(all_kwargs, "complete", "stream_complete")

        response = completion_with_retry(
            client=self._client,
            max_retries=self.max_retries,
            chat=False,
            prompt=prompt,
            **all_kwargs,
        )

        return CompletionResponse(
            text=response.generations[0].text,
            raw=response.__dict__,
        )

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat response, yielding cumulative content plus deltas."""
        history = messages_to_cohere_history(messages[:-1])
        prompt = messages[-1].content
        all_kwargs = self._get_all_kwargs(**kwargs)
        all_kwargs["stream"] = True
        self._ensure_chat_model(all_kwargs)
        response = completion_with_retry(
            client=self._client,
            max_retries=self.max_retries,
            chat=True,
            message=prompt,
            chat_history=history,
            **all_kwargs,
        )

        def gen() -> ChatResponseGen:
            content = ""
            role = MessageRole.ASSISTANT
            for r in response:
                # Some stream events (e.g. end-of-stream markers) carry no text.
                if "text" in r.__dict__:
                    content_delta = r.text
                else:
                    content_delta = ""
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=r.__dict__,
                )

        return gen()

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream a completion, yielding cumulative text plus deltas."""
        all_kwargs = self._get_all_kwargs(**kwargs)
        all_kwargs["stream"] = True

        response = completion_with_retry(
            client=self._client,
            max_retries=self.max_retries,
            chat=False,
            prompt=prompt,
            **all_kwargs,
        )

        def gen() -> CompletionResponseGen:
            content = ""
            for r in response:
                content_delta = r.text
                content += content_delta
                yield CompletionResponse(
                    text=content, delta=content_delta, raw=r._asdict()
                )

        return gen()

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Async variant of `chat`."""
        history = messages_to_cohere_history(messages[:-1])
        prompt = messages[-1].content
        all_kwargs = self._get_all_kwargs(**kwargs)
        self._ensure_chat_model(all_kwargs)
        self._warn_if_stream_kwarg(all_kwargs, "achat", "astream_chat")

        response = await acompletion_with_retry(
            aclient=self._aclient,
            max_retries=self.max_retries,
            chat=True,
            message=prompt,
            chat_history=history,
            **all_kwargs,
        )

        return ChatResponse(
            message=ChatMessage(role=MessageRole.ASSISTANT, content=response.text),
            raw=response.__dict__,
        )

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async variant of `complete`."""
        all_kwargs = self._get_all_kwargs(**kwargs)
        self._warn_if_stream_kwarg(all_kwargs, "acomplete", "astream_complete")

        response = await acompletion_with_retry(
            aclient=self._aclient,
            max_retries=self.max_retries,
            chat=False,
            prompt=prompt,
            **all_kwargs,
        )

        return CompletionResponse(
            text=response.generations[0].text,
            raw=response.__dict__,
        )

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Async variant of `stream_chat`."""
        history = messages_to_cohere_history(messages[:-1])
        prompt = messages[-1].content
        all_kwargs = self._get_all_kwargs(**kwargs)
        all_kwargs["stream"] = True
        self._ensure_chat_model(all_kwargs)
        response = await acompletion_with_retry(
            aclient=self._aclient,
            max_retries=self.max_retries,
            chat=True,
            message=prompt,
            chat_history=history,
            **all_kwargs,
        )

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            role = MessageRole.ASSISTANT
            async for r in response:
                # Some stream events (e.g. end-of-stream markers) carry no text.
                if "text" in r.__dict__:
                    content_delta = r.text
                else:
                    content_delta = ""
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=r.__dict__,
                )

        return gen()

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async variant of `stream_complete`."""
        all_kwargs = self._get_all_kwargs(**kwargs)
        all_kwargs["stream"] = True

        response = await acompletion_with_retry(
            aclient=self._aclient,
            max_retries=self.max_retries,
            chat=False,
            prompt=prompt,
            **all_kwargs,
        )

        async def gen() -> CompletionResponseAsyncGen:
            content = ""
            async for r in response:
                content_delta = r.text
                content += content_delta
                yield CompletionResponse(
                    text=content, delta=content_delta, raw=r._asdict()
                )

        return gen()