llama-index

Форк
0
258 строк · 8.6 Кб
1
from typing import Any, Callable, Dict, Optional, Sequence
2

3
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
4
from llama_index.legacy.callbacks import CallbackManager
5
from llama_index.legacy.constants import DEFAULT_TEMPERATURE
6
from llama_index.legacy.core.llms.types import (
7
    ChatMessage,
8
    ChatResponse,
9
    ChatResponseAsyncGen,
10
    ChatResponseGen,
11
    CompletionResponse,
12
    CompletionResponseAsyncGen,
13
    CompletionResponseGen,
14
    LLMMetadata,
15
    MessageRole,
16
)
17
from llama_index.legacy.llms.anthropic_utils import (
18
    anthropic_modelname_to_contextsize,
19
    messages_to_anthropic_prompt,
20
)
21
from llama_index.legacy.llms.base import (
22
    llm_chat_callback,
23
    llm_completion_callback,
24
)
25
from llama_index.legacy.llms.generic_utils import (
26
    achat_to_completion_decorator,
27
    astream_chat_to_completion_decorator,
28
    chat_to_completion_decorator,
29
    stream_chat_to_completion_decorator,
30
)
31
from llama_index.legacy.llms.llm import LLM
32
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
33

34
# Defaults for the Anthropic legacy text-completions API.
DEFAULT_ANTHROPIC_MODEL = "claude-2"
DEFAULT_ANTHROPIC_MAX_TOKENS = 512
38
class Anthropic(LLM):
    """Anthropic LLM over the legacy text-completions endpoint.

    Chat messages are flattened to a single prompt string via
    ``messages_to_anthropic_prompt`` and sent to the SDK's
    ``completions.create`` endpoint. Sync/async and streaming/non-streaming
    variants are provided; the ``*complete`` methods are thin adapters over
    the corresponding chat methods via the generic decorators.
    """

    model: str = Field(
        default=DEFAULT_ANTHROPIC_MODEL, description="The anthropic model to use."
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use for sampling.",
        # BUGFIX: pydantic's constraint keywords are `ge`/`le`; the original
        # `gte`/`lte` spellings were unknown kwargs, so the 0..1 bounds were
        # silently never enforced.
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_ANTHROPIC_MAX_TOKENS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )

    base_url: Optional[str] = Field(default=None, description="The base URL to use.")
    timeout: Optional[float] = Field(
        default=None, description="The timeout to use in seconds.", ge=0
    )
    max_retries: int = Field(
        default=10, description="The maximum number of API retries.", ge=0
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the anthropic API."
    )

    # Underlying anthropic SDK clients (sync and async), created in __init__
    # and excluded from the pydantic model's fields.
    _client: Any = PrivateAttr()
    _aclient: Any = PrivateAttr()

    def __init__(
        self,
        model: str = DEFAULT_ANTHROPIC_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_ANTHROPIC_MAX_TOKENS,
        base_url: Optional[str] = None,
        timeout: Optional[float] = None,
        max_retries: int = 10,
        api_key: Optional[str] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Create the sync and async anthropic SDK clients and init the base LLM.

        Raises:
            ImportError: if the optional `anthropic` package is not installed.
        """
        try:
            import anthropic
        except ImportError as e:
            raise ImportError(
                # BUGFIX: trailing space added so the two concatenated
                # fragments don't run together ("Anthropic.Please ...").
                "You must install the `anthropic` package to use Anthropic. "
                "Please `pip install anthropic`"
            ) from e

        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])

        # Clients are built before super().__init__ so PrivateAttr slots are
        # populated once pydantic finishes model construction.
        self._client = anthropic.Anthropic(
            api_key=api_key, base_url=base_url, timeout=timeout, max_retries=max_retries
        )
        self._aclient = anthropic.AsyncAnthropic(
            api_key=api_key, base_url=base_url, timeout=timeout, max_retries=max_retries
        )

        super().__init__(
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            model=model,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization class name."""
        return "Anthropic_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Model metadata: context window, output budget, chat capability."""
        return LLMMetadata(
            context_window=anthropic_modelname_to_contextsize(self.model),
            num_output=self.max_tokens,
            is_chat_model=True,
            model_name=self.model,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Base API kwargs; `additional_kwargs` can override any of them."""
        base_kwargs = {
            "model": self.model,
            "temperature": self.temperature,
            # Legacy completions API parameter name (not `max_tokens`).
            "max_tokens_to_sample": self.max_tokens,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge model kwargs with per-call overrides (call-site wins)."""
        return {
            **self._model_kwargs,
            **kwargs,
        }

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Synchronous, non-streaming chat completion."""
        prompt = messages_to_anthropic_prompt(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = self._client.completions.create(
            prompt=prompt, stream=False, **all_kwargs
        )
        return ChatResponse(
            message=ChatMessage(
                role=MessageRole.ASSISTANT, content=response.completion
            ),
            raw=dict(response),
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Synchronous text completion, adapted from `chat`."""
        complete_fn = chat_to_completion_decorator(self.chat)
        return complete_fn(prompt, **kwargs)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Synchronous streaming chat; yields cumulative-content responses."""
        prompt = messages_to_anthropic_prompt(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = self._client.completions.create(
            prompt=prompt, stream=True, **all_kwargs
        )

        def gen() -> ChatResponseGen:
            content = ""
            role = MessageRole.ASSISTANT
            for r in response:
                content_delta = r.completion
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=r,
                )

        return gen()

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Synchronous streaming completion, adapted from `stream_chat`."""
        stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat)
        return stream_complete_fn(prompt, **kwargs)

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Async, non-streaming chat completion."""
        prompt = messages_to_anthropic_prompt(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = await self._aclient.completions.create(
            prompt=prompt, stream=False, **all_kwargs
        )
        return ChatResponse(
            message=ChatMessage(
                role=MessageRole.ASSISTANT, content=response.completion
            ),
            raw=dict(response),
        )

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async text completion, adapted from `achat`."""
        acomplete_fn = achat_to_completion_decorator(self.achat)
        return await acomplete_fn(prompt, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Async streaming chat; yields cumulative-content responses."""
        prompt = messages_to_anthropic_prompt(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = await self._aclient.completions.create(
            prompt=prompt, stream=True, **all_kwargs
        )

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            role = MessageRole.ASSISTANT
            async for r in response:
                content_delta = r.completion
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=r,
                )

        return gen()

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async streaming completion, adapted from `astream_chat`."""
        astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat)
        return await astream_complete_fn(prompt, **kwargs)
259

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.