llama-index
258 lines · 8.6 KB
1from typing import Any, Callable, Dict, Optional, Sequence2
3from llama_index.legacy.bridge.pydantic import Field, PrivateAttr4from llama_index.legacy.callbacks import CallbackManager5from llama_index.legacy.constants import DEFAULT_TEMPERATURE6from llama_index.legacy.core.llms.types import (7ChatMessage,8ChatResponse,9ChatResponseAsyncGen,10ChatResponseGen,11CompletionResponse,12CompletionResponseAsyncGen,13CompletionResponseGen,14LLMMetadata,15MessageRole,16)
17from llama_index.legacy.llms.anthropic_utils import (18anthropic_modelname_to_contextsize,19messages_to_anthropic_prompt,20)
21from llama_index.legacy.llms.base import (22llm_chat_callback,23llm_completion_callback,24)
25from llama_index.legacy.llms.generic_utils import (26achat_to_completion_decorator,27astream_chat_to_completion_decorator,28chat_to_completion_decorator,29stream_chat_to_completion_decorator,30)
31from llama_index.legacy.llms.llm import LLM32from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode33
# Default model and generation budget for the Anthropic completions API.
# "claude-2" targets the legacy text-completions endpoint used below.
DEFAULT_ANTHROPIC_MODEL = "claude-2"
DEFAULT_ANTHROPIC_MAX_TOKENS = 512
37
class Anthropic(LLM):
    """Anthropic LLM built on the (legacy) text-completions API.

    Chat messages are flattened into a single Anthropic-style prompt via
    ``messages_to_anthropic_prompt``; the ``complete``/``stream_complete``
    variants are implemented on top of the chat endpoints through the
    generic chat-to-completion decorators.
    """

    model: str = Field(
        default=DEFAULT_ANTHROPIC_MODEL, description="The anthropic model to use."
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use for sampling.",
        # BUG FIX: pydantic only enforces bounds spelled `ge`/`le`; the
        # original `gte=`/`lte=` keywords were silently ignored.
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_ANTHROPIC_MAX_TOKENS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )

    base_url: Optional[str] = Field(default=None, description="The base URL to use.")
    timeout: Optional[float] = Field(
        default=None, description="The timeout to use in seconds.", ge=0
    )
    max_retries: int = Field(
        default=10, description="The maximum number of API retries.", ge=0
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the anthropic API."
    )

    # SDK clients (sync and async), created in __init__.
    _client: Any = PrivateAttr()
    _aclient: Any = PrivateAttr()

    def __init__(
        self,
        model: str = DEFAULT_ANTHROPIC_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_ANTHROPIC_MAX_TOKENS,
        base_url: Optional[str] = None,
        timeout: Optional[float] = None,
        max_retries: int = 10,
        api_key: Optional[str] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        # Import lazily so the package is an optional dependency.
        try:
            import anthropic
        except ImportError as e:
            # BUG FIX: the two fragments previously concatenated with no
            # separator, rendering "...to use Anthropic.Please `pip...`".
            raise ImportError(
                "You must install the `anthropic` package to use Anthropic. "
                "Please `pip install anthropic`"
            ) from e

        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])

        # Assign private attrs before super().__init__ so they exist on the
        # instance once pydantic validation runs.
        self._client = anthropic.Anthropic(
            api_key=api_key, base_url=base_url, timeout=timeout, max_retries=max_retries
        )
        self._aclient = anthropic.AsyncAnthropic(
            api_key=api_key, base_url=base_url, timeout=timeout, max_retries=max_retries
        )

        super().__init__(
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            model=model,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "Anthropic_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Model metadata: context window, output budget, chat capability."""
        return LLMMetadata(
            context_window=anthropic_modelname_to_contextsize(self.model),
            num_output=self.max_tokens,
            is_chat_model=True,
            model_name=self.model,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Base API kwargs; ``additional_kwargs`` intentionally wins."""
        base_kwargs = {
            "model": self.model,
            "temperature": self.temperature,
            # Legacy completions API uses `max_tokens_to_sample`.
            "max_tokens_to_sample": self.max_tokens,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge per-call kwargs over the instance-level model kwargs."""
        return {
            **self._model_kwargs,
            **kwargs,
        }

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Synchronous, non-streaming chat call."""
        prompt = messages_to_anthropic_prompt(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = self._client.completions.create(
            prompt=prompt, stream=False, **all_kwargs
        )
        return ChatResponse(
            message=ChatMessage(
                role=MessageRole.ASSISTANT, content=response.completion
            ),
            raw=dict(response),
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Completion implemented on top of :meth:`chat`."""
        complete_fn = chat_to_completion_decorator(self.chat)
        return complete_fn(prompt, **kwargs)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Synchronous streaming chat call."""
        prompt = messages_to_anthropic_prompt(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = self._client.completions.create(
            prompt=prompt, stream=True, **all_kwargs
        )

        def gen() -> ChatResponseGen:
            # Accumulate deltas so each yielded message carries the full
            # text produced so far, while `delta` carries just the new part.
            content = ""
            role = MessageRole.ASSISTANT
            for r in response:
                content_delta = r.completion
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=r,
                )

        return gen()

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming completion implemented on top of :meth:`stream_chat`."""
        stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat)
        return stream_complete_fn(prompt, **kwargs)

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Async, non-streaming chat call."""
        prompt = messages_to_anthropic_prompt(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = await self._aclient.completions.create(
            prompt=prompt, stream=False, **all_kwargs
        )
        return ChatResponse(
            message=ChatMessage(
                role=MessageRole.ASSISTANT, content=response.completion
            ),
            raw=dict(response),
        )

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async completion implemented on top of :meth:`achat`."""
        acomplete_fn = achat_to_completion_decorator(self.achat)
        return await acomplete_fn(prompt, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Async streaming chat call."""
        prompt = messages_to_anthropic_prompt(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = await self._aclient.completions.create(
            prompt=prompt, stream=True, **all_kwargs
        )

        async def gen() -> ChatResponseAsyncGen:
            # Mirror of the sync generator in stream_chat.
            content = ""
            role = MessageRole.ASSISTANT
            async for r in response:
                content_delta = r.completion
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=r,
                )

        return gen()

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async streaming completion on top of :meth:`astream_chat`."""
        astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat)
        return await astream_complete_fn(prompt, **kwargs)