llama-index

Форк
0
391 строка · 13.8 Кб
1
import json
2
import os
3
import warnings
4
from enum import Enum
5
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence
6

7
from deprecated import deprecated
8

9
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
10
from llama_index.legacy.callbacks.base import CallbackManager
11
from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE
12
from llama_index.legacy.core.embeddings.base import BaseEmbedding, Embedding
13
from llama_index.legacy.core.llms.types import ChatMessage
14
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
15

16

17
class PROVIDERS(str, Enum):
    """Embedding providers reachable through Bedrock.

    Each value matches the prefix of a Bedrock model id, i.e. the text before
    the first ``.`` (``"amazon"`` in ``"amazon.titan-embed-text-v1"``).
    """

    AMAZON = "amazon"
    COHERE = "cohere"
20

21

22
class Models(str, Enum):
    """Bedrock embedding model ids known to this integration.

    The text before the first ``.`` of each id is the provider name.
    """

    # Amazon Titan text-embedding models.
    TITAN_EMBEDDING = "amazon.titan-embed-text-v1"
    TITAN_EMBEDDING_G1_TEXT_02 = "amazon.titan-embed-g1-text-02"
    # Cohere embedding models (English-only and multilingual).
    COHERE_EMBED_ENGLISH_V3 = "cohere.embed-english-v3"
    COHERE_EMBED_MULTILINGUAL_V3 = "cohere.embed-multilingual-v3"
27

28

29
def _amazon_embedding_from_response(response: Dict[str, Any]) -> Any:
    """Return the ``embedding`` entry of a decoded response (``None`` if absent)."""
    return response.get("embedding")


def _cohere_embedding_from_response(response: Dict[str, Any]) -> Any:
    """Return the first vector of the ``embeddings`` list of a decoded response."""
    vectors = response.get("embeddings")
    return vectors[0]


# Maps a provider name (model-id prefix) to the callable that pulls the
# embedding vector out of the JSON-decoded ``invoke_model`` response body.
PROVIDER_SPECIFIC_IDENTIFIERS = {
    PROVIDERS.AMAZON.value: {"get_embeddings_func": _amazon_embedding_from_response},
    PROVIDERS.COHERE.value: {"get_embeddings_func": _cohere_embedding_from_response},
}
37

38

39
class BedrockEmbedding(BaseEmbedding):
    """Embeddings served by Amazon Bedrock through ``invoke_model``.

    The provider is taken from the model id prefix (the text before the first
    ``.``). It selects both the request body format (``_get_request_body``)
    and how the vector is read from the JSON response
    (``PROVIDER_SPECIFIC_IDENTIFIERS``).
    """

    model: str = Field(description="The modelId of the Bedrock model to use.")
    profile_name: Optional[str] = Field(
        description="The name of aws profile to use. If not given, then the default profile is used.",
        exclude=True,
    )
    aws_access_key_id: Optional[str] = Field(
        description="AWS Access Key ID to use", exclude=True
    )
    aws_secret_access_key: Optional[str] = Field(
        description="AWS Secret Access Key to use", exclude=True
    )
    aws_session_token: Optional[str] = Field(
        description="AWS Session Token to use", exclude=True
    )
    region_name: Optional[str] = Field(
        description="AWS region name to use. Uses region configured in AWS CLI if not passed",
        exclude=True,
    )
    botocore_session: Optional[Any] = Field(
        description="Use this Botocore session instead of creating a new default one.",
        exclude=True,
    )
    botocore_config: Optional[Any] = Field(
        description="Custom configuration object to use instead of the default generated one.",
        exclude=True,
    )
    max_retries: int = Field(
        default=10, description="The maximum number of API retries.", gt=0
    )
    timeout: float = Field(
        default=60.0,
        description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the bedrock client."
    )

    # Bedrock runtime client used by ``_get_embedding``.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = Models.TITAN_EMBEDDING,
        profile_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region_name: Optional[str] = None,
        client: Optional[Any] = None,
        botocore_session: Optional[Any] = None,
        botocore_config: Optional[Any] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        callback_manager: Optional[CallbackManager] = None,
        # base class
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ):
        """Create the embedding model and, unless ``client`` is given, a Bedrock client.

        Profile, region, credentials and ``botocore_session`` are forwarded to
        ``boto3.Session``. When ``botocore_config`` is ``None``, a botocore
        ``Config`` is built from ``max_retries`` (standard retry mode) and
        ``timeout`` (used for both the connect and read timeouts). The
        resulting config is stored in the ``botocore_config`` field. A
        caller-supplied ``client`` is used as-is, and no session is created.

        Raises:
            ImportError: if boto3/botocore are needed but not installed.
        """
        additional_kwargs = additional_kwargs or {}

        config = self._build_botocore_config(botocore_config, max_retries, timeout)

        if client is None:
            session = self._import_boto3().Session(
                profile_name=profile_name,
                region_name=region_name,
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                botocore_session=botocore_session,
            )
            client = self._client_from_session(session, config)
        # Private attributes may be set before ``super().__init__`` on this
        # model base; the original implementation relies on the same pattern.
        self._client = client

        super().__init__(
            model=model,
            max_retries=max_retries,
            timeout=timeout,
            botocore_config=config,
            profile_name=profile_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            region_name=region_name,
            botocore_session=botocore_session,
            additional_kwargs=additional_kwargs,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            **kwargs,
        )

    @staticmethod
    def _import_boto3() -> Any:
        """Import and return the ``boto3`` module, with an install hint on failure."""
        try:
            import boto3
        except ImportError as e:
            raise ImportError(
                "boto3 package not found, install with 'pip install boto3'"
            ) from e
        return boto3

    @staticmethod
    def _build_botocore_config(
        botocore_config: Optional[Any], max_retries: int, timeout: float
    ) -> Any:
        """Return ``botocore_config`` if given, else a retry/timeout ``Config``."""
        if botocore_config is not None:
            return botocore_config
        try:
            from botocore.config import Config
        except ImportError as e:
            raise ImportError(
                "boto3 package not found, install with 'pip install boto3'"
            ) from e
        return Config(
            retries={"max_attempts": max_retries, "mode": "standard"},
            connect_timeout=timeout,
            read_timeout=timeout,
        )

    @staticmethod
    def _client_from_session(session: Any, config: Optional[Any] = None) -> Any:
        """Create the Bedrock client for ``session``, using ``config`` if given.

        Prior to general availability, custom boto3 wheel files were
        distributed that used the ``bedrock`` service to invokeModel. Falling
        back to it keeps services still using those wheel files from breaking.
        """
        if "bedrock-runtime" in session.get_available_services():
            return session.client("bedrock-runtime", config=config)
        return session.client("bedrock", config=config)

    @staticmethod
    def list_supported_models() -> Dict[str, List[str]]:
        """Return the known model ids grouped by provider.

        A model belongs to a provider when its id starts with
        ``"<provider>."``.
        """
        return {
            provider.value: [
                m.value for m in Models if m.value.startswith(f"{provider.value}.")
            ]
            for provider in PROVIDERS
        }

    @classmethod
    def class_name(cls) -> str:
        return "BedrockEmbedding"

    @deprecated(
        version="0.9.48",
        reason=(
            "Use the provided kwargs in the constructor, "
            "set_credentials will be removed in future releases."
        ),
        action="once",
    )
    def set_credentials(
        self,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_profile: Optional[str] = None,
    ) -> None:
        """Rebuild the Bedrock client from explicit credentials or ``AWS_*`` env vars.

        Region, key id, secret and session token each fall back to their
        ``AWS_*`` environment variable. A value missing from both only
        triggers a warning. The ``None`` is passed on to ``boto3.Session``,
        which resolves credentials through its own provider chain (for
        example ``aws_profile``). Long-term keys without a session token
        therefore keep working. The new client reuses the stored
        ``botocore_config`` (retries and timeouts).
        """
        aws_region = aws_region or os.getenv("AWS_REGION")
        aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        aws_secret_access_key = aws_secret_access_key or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
        aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")

        # Previously each warning was followed by an assertion that always
        # failed, which made the optional session token (and profile-based
        # credentials) unusable. Missing values are now only reported.
        if aws_region is None:
            warnings.warn(
                "AWS_REGION not found. Set environment variable AWS_REGION or set aws_region"
            )

        if aws_access_key_id is None:
            warnings.warn(
                "AWS_ACCESS_KEY_ID not found. Set environment variable AWS_ACCESS_KEY_ID or set aws_access_key_id"
            )

        if aws_secret_access_key is None:
            warnings.warn(
                "AWS_SECRET_ACCESS_KEY not found. Set environment variable AWS_SECRET_ACCESS_KEY or set aws_secret_access_key"
            )

        if aws_session_token is None:
            warnings.warn(
                "AWS_SESSION_TOKEN not found. Set environment variable AWS_SESSION_TOKEN or set aws_session_token"
            )

        session = self._import_boto3().Session(
            profile_name=aws_profile,
            region_name=aws_region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )
        self._client = self._client_from_session(session, self.botocore_config)

    @classmethod
    @deprecated(
        version="0.9.48",
        reason=(
            "Use the provided kwargs in the constructor, "
            "from_credentials will be removed in future releases."
        ),
        action="once",
    )
    def from_credentials(
        cls,
        model_name: str = Models.TITAN_EMBEDDING,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_profile: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
    ) -> "BedrockEmbedding":
        """
        Instantiate using AWS credentials.

        Delegates to the constructor, so the client gets the same retry and
        timeout configuration as any directly constructed instance.

        Args:
            model_name (str) : Name of the model
            aws_region (str): AWS region where the service is located
            aws_access_key_id (str): AWS access key ID
            aws_secret_access_key (str): AWS secret access key
            aws_session_token (str): AWS session token
            aws_profile (str): AWS profile, when None, default profile is chosen automatically
            embed_batch_size (int): Batch size forwarded to the constructor
            callback_manager (CallbackManager): Callback manager forwarded to the constructor
            verbose (bool): Forwarded to the constructor as a keyword argument

        Example:
                .. code-block:: python

                    from llama_index.embeddings import BedrockEmbedding

                    # Define the model name
                    model_name = "your_model_name"

                    embeddings = BedrockEmbedding.from_credentials(
                        model_name,
                        aws_region=aws_region,
                        aws_access_key_id=aws_access_key_id,
                        aws_secret_access_key=aws_secret_access_key,
                        aws_session_token=aws_session_token,
                        aws_profile=aws_profile,
                    )

        """
        return cls(
            model=model_name,
            profile_name=aws_profile,
            region_name=aws_region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            verbose=verbose,
        )

    def _get_embedding(self, payload: str, type: Literal["text", "query"]) -> Embedding:
        """Embed ``payload`` with the configured model via ``invoke_model``.

        Args:
            payload: The text to embed.
            type: ``"text"`` for documents or ``"query"`` for search queries.
                The name shadows the builtin; it is kept for compatibility.

        Raises:
            ValueError: if no client is available or the model's provider is
                not supported.
        """
        if self._client is None:
            self.set_credentials()

        if self._client is None:
            raise ValueError("Client not set")

        provider = self.model.split(".")[0]
        # Validate the provider before the (billable) network call rather than
        # after it.
        identifiers = PROVIDER_SPECIFIC_IDENTIFIERS.get(provider)
        if identifiers is None:
            raise ValueError("Provider not supported")

        request_body = self._get_request_body(provider, payload, type)

        response = self._client.invoke_model(
            body=request_body,
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )

        resp = json.loads(response.get("body").read().decode("utf-8"))
        return identifiers["get_embeddings_func"](resp)

    def _get_query_embedding(self, query: str) -> Embedding:
        return self._get_embedding(query, "query")

    def _get_text_embedding(self, text: str) -> Embedding:
        return self._get_embedding(text, "text")

    def _get_request_body(
        self, provider: str, payload: str, type: Literal["text", "query"]
    ) -> Any:
        """Serialize the provider-specific JSON request body for ``payload``.

        Currently supported providers are amazon and cohere.

        amazon:
            ``{"inputText": payload}``, e.g. ``{"inputText": "Hello World!"}``

        cohere:
            ``{"texts": [payload], "input_type": ..., "truncate": "NONE"}``,
            where ``input_type`` is ``"search_document"`` when ``type`` is
            ``"text"`` and ``"search_query"`` when it is ``"query"``.

        Raises:
            ValueError: if ``provider`` is not supported.
        """
        if provider == PROVIDERS.AMAZON:
            request_body = json.dumps({"inputText": payload})
        elif provider == PROVIDERS.COHERE:
            input_types = {
                "text": "search_document",
                "query": "search_query",
            }
            request_body = json.dumps(
                {
                    "texts": [payload],
                    "input_type": input_types[type],
                    "truncate": "NONE",
                }
            )
        else:
            raise ValueError("Provider not supported")
        return request_body

    # NOTE: the async variants call the synchronous boto3 client directly, so
    # each request blocks the running event loop until it completes.
    async def _aget_query_embedding(self, query: str) -> Embedding:
        return self._get_embedding(query, "query")

    async def _aget_text_embedding(self, text: str) -> Embedding:
        return self._get_embedding(text, "text")
392

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.