llama-index

Форк
0
298 строк · 10.9 Кб
1
import json
2
from typing import Any, Callable, Dict, Optional, Sequence
3

4
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
5
from llama_index.legacy.callbacks import CallbackManager
6
from llama_index.legacy.constants import (
7
    DEFAULT_TEMPERATURE,
8
)
9
from llama_index.legacy.core.llms.types import (
10
    ChatMessage,
11
    ChatResponse,
12
    ChatResponseAsyncGen,
13
    ChatResponseGen,
14
    CompletionResponse,
15
    CompletionResponseAsyncGen,
16
    CompletionResponseGen,
17
    LLMMetadata,
18
)
19
from llama_index.legacy.llms.base import (
20
    llm_chat_callback,
21
    llm_completion_callback,
22
)
23
from llama_index.legacy.llms.bedrock_utils import (
24
    BEDROCK_FOUNDATION_LLMS,
25
    CHAT_ONLY_MODELS,
26
    STREAMING_MODELS,
27
    Provider,
28
    completion_with_retry,
29
    get_provider,
30
)
31
from llama_index.legacy.llms.generic_utils import (
32
    completion_response_to_chat_response,
33
    stream_completion_response_to_chat_response,
34
)
35
from llama_index.legacy.llms.llm import LLM
36
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
37

38

39
class Bedrock(LLM):
    """AWS Bedrock LLM.

    Wraps the Bedrock ``invoke_model`` runtime API behind the LlamaIndex
    ``LLM`` interface. Prompt formatting and request/response parsing are
    delegated to a ``Provider`` object resolved from the model id, so a
    single class serves every supported Bedrock model family.
    """

    model: str = Field(description="The modelId of the Bedrock model to use.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_tokens: int = Field(description="The maximum number of tokens to generate.")
    # Fix: the description string was previously passed positionally, which
    # made it the field's *default value* instead of its description.
    context_size: int = Field(
        description="The maximum number of tokens available for input."
    )
    profile_name: Optional[str] = Field(
        description="The name of aws profile to use. If not given, then the default profile is used."
    )
    aws_access_key_id: Optional[str] = Field(
        description="AWS Access Key ID to use", exclude=True
    )
    aws_secret_access_key: Optional[str] = Field(
        description="AWS Secret Access Key to use", exclude=True
    )
    aws_session_token: Optional[str] = Field(
        description="AWS Session Token to use", exclude=True
    )
    region_name: Optional[str] = Field(
        description="AWS region name to use. Uses region configured in AWS CLI if not passed",
        exclude=True,
    )
    botocore_session: Optional[Any] = Field(
        description="Use this Botocore session instead of creating a new default one.",
        exclude=True,
    )
    botocore_config: Optional[Any] = Field(
        description="Custom configuration object to use instead of the default generated one.",
        exclude=True,
    )
    max_retries: int = Field(
        default=10, description="The maximum number of API retries.", gt=0
    )
    timeout: float = Field(
        default=60.0,
        description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional kwargs for the bedrock invokeModel request.",
    )

    # boto3 client for the "bedrock-runtime" (or legacy "bedrock") service.
    _client: Any = PrivateAttr()
    _aclient: Any = PrivateAttr()
    # Model-family-specific adapter for prompts, request bodies and responses.
    _provider: Provider = PrivateAttr()

    def __init__(
        self,
        model: str,
        temperature: Optional[float] = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = 512,
        context_size: Optional[int] = None,
        profile_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region_name: Optional[str] = None,
        botocore_session: Optional[Any] = None,
        client: Optional[Any] = None,
        timeout: Optional[float] = 60.0,
        max_retries: Optional[int] = 10,
        botocore_config: Optional[Any] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> None:
        """Create a Bedrock LLM.

        Builds (or reuses) a boto3 Bedrock client from the given AWS
        credentials/session settings and resolves the provider adapter for
        ``model``.

        Raises:
            ValueError: if ``model`` is not a known foundation model and no
                ``context_size`` is supplied.
            ImportError: if boto3/botocore are not installed.
        """
        # Known foundation models ship with a default context size; custom
        # models must declare theirs explicitly.
        if context_size is None and model not in BEDROCK_FOUNDATION_LLMS:
            raise ValueError(
                "`context_size` argument not provided and "
                "model provided refers to a non-foundation model. "
                "Please specify the context_size"
            )

        session_kwargs = {
            "profile_name": profile_name,
            "region_name": region_name,
            "aws_access_key_id": aws_access_key_id,
            "aws_secret_access_key": aws_secret_access_key,
            "aws_session_token": aws_session_token,
            "botocore_session": botocore_session,
        }
        config = None
        try:
            import boto3
            from botocore.config import Config

            # A caller-supplied botocore config wins; otherwise derive one
            # from the retry/timeout arguments.
            config = (
                Config(
                    retries={"max_attempts": max_retries, "mode": "standard"},
                    connect_timeout=timeout,
                    read_timeout=timeout,
                )
                if botocore_config is None
                else botocore_config
            )
            session = boto3.Session(**session_kwargs)
        except ImportError:
            raise ImportError(
                "boto3 package not found, install with 'pip install boto3'"
            )

        # Prior to general availability, custom boto3 wheel files were
        # distributed that used the bedrock service to invokeModel.
        # This check prevents any services still using those wheel files
        # from breaking
        if client is not None:
            self._client = client
        elif "bedrock-runtime" in session.get_available_services():
            self._client = session.client("bedrock-runtime", config=config)
        else:
            self._client = session.client("bedrock", config=config)

        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])
        context_size = context_size or BEDROCK_FOUNDATION_LLMS[model]
        self._provider = get_provider(model)
        # Fall back to the provider's prompt formatting when none is given.
        messages_to_prompt = messages_to_prompt or self._provider.messages_to_prompt
        completion_to_prompt = (
            completion_to_prompt or self._provider.completion_to_prompt
        )
        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            context_size=context_size,
            profile_name=profile_name,
            timeout=timeout,
            max_retries=max_retries,
            botocore_config=config,
            additional_kwargs=additional_kwargs,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "Bedrock_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata derived from the configured model."""
        return LLMMetadata(
            context_window=self.context_size,
            num_output=self.max_tokens,
            is_chat_model=self.model in CHAT_ONLY_MODELS,
            model_name=self.model,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Base request kwargs; the max-tokens key name is provider-specific."""
        base_kwargs = {
            "temperature": self.temperature,
            self._provider.max_tokens_key: self.max_tokens,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge model kwargs with per-call overrides (per-call wins)."""
        return {
            **self._model_kwargs,
            **kwargs,
        }

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Invoke the model once and return the full completion."""
        if not formatted:
            prompt = self.completion_to_prompt(prompt)
        all_kwargs = self._get_all_kwargs(**kwargs)
        request_body = self._provider.get_request_body(prompt, all_kwargs)
        request_body_str = json.dumps(request_body)
        response = completion_with_retry(
            client=self._client,
            model=self.model,
            request_body=request_body_str,
            max_retries=self.max_retries,
            **all_kwargs,
        )["body"].read()
        response = json.loads(response)
        return CompletionResponse(
            text=self._provider.get_text_from_response(response), raw=response
        )

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Invoke the model with response streaming, yielding partial text.

        Raises:
            ValueError: if the model is a known foundation model that does
                not support the streaming API.
        """
        if self.model in BEDROCK_FOUNDATION_LLMS and self.model not in STREAMING_MODELS:
            raise ValueError(f"Model {self.model} does not support streaming")

        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        all_kwargs = self._get_all_kwargs(**kwargs)
        request_body = self._provider.get_request_body(prompt, all_kwargs)
        request_body_str = json.dumps(request_body)
        response = completion_with_retry(
            client=self._client,
            model=self.model,
            request_body=request_body_str,
            max_retries=self.max_retries,
            stream=True,
            **all_kwargs,
        )["body"]

        def gen() -> CompletionResponseGen:
            # Accumulate deltas so each yielded response carries full text.
            content = ""
            for r in response:
                r = json.loads(r["chunk"]["bytes"])
                content_delta = self._provider.get_text_from_stream_response(r)
                content += content_delta
                yield CompletionResponse(text=content, delta=content_delta, raw=r)

        return gen()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Chat by formatting messages into a single prompt and completing."""
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    # Consistency fix: instrument streaming chat with the same callback
    # decorator used by `chat`.
    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Streaming chat via the streaming completion endpoint."""
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Chat asynchronously."""
        # TODO: do synchronous chat for now
        return self.chat(messages, **kwargs)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async completion is not implemented for Bedrock."""
        raise NotImplementedError

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Async streaming chat is not implemented for Bedrock."""
        raise NotImplementedError

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async streaming completion is not implemented for Bedrock."""
        raise NotImplementedError
299

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.