# llama-index
# 391 lines · 13.8 KB
import json
import os
import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence

from deprecated import deprecated

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE
from llama_index.legacy.core.embeddings.base import BaseEmbedding, Embedding
from llama_index.legacy.core.llms.types import ChatMessage
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
16
class PROVIDERS(str, Enum):
    """Model providers understood by the Bedrock embedding wrapper.

    Each value equals the part of a Bedrock model id before its first ``.``
    (for example ``"amazon"`` in ``"amazon.titan-embed-text-v1"``). Because
    the enum mixes in ``str``, members compare equal to their plain-string
    values.
    """

    AMAZON = "amazon"
    COHERE = "cohere"
21
class Models(str, Enum):
    """Bedrock embedding model ids known to this module.

    The provider of each model is the prefix of its id before the first ``.``.
    """

    # ``amazon`` provider (Titan family).
    TITAN_EMBEDDING = "amazon.titan-embed-text-v1"
    TITAN_EMBEDDING_G1_TEXT_02 = "amazon.titan-embed-g1-text-02"
    # ``cohere`` provider.
    COHERE_EMBED_ENGLISH_V3 = "cohere.embed-english-v3"
    COHERE_EMBED_MULTILINGUAL_V3 = "cohere.embed-multilingual-v3"
28
def _amazon_embedding_from_response(response: Dict[str, Any]) -> Any:
    """Return the ``embedding`` entry of a decoded ``amazon`` response body."""
    return response.get("embedding")


def _cohere_embedding_from_response(response: Dict[str, Any]) -> Any:
    """Return the first item of ``embeddings`` in a decoded ``cohere`` response."""
    # Only one text is sent per request, so only the first vector matters.
    return response.get("embeddings")[0]


# Per-provider hooks that pull the embedding out of the JSON-decoded
# ``invoke_model`` response; keyed by the provider prefix of the model id.
PROVIDER_SPECIFIC_IDENTIFIERS = {
    PROVIDERS.AMAZON.value: {
        "get_embeddings_func": _amazon_embedding_from_response,
    },
    PROVIDERS.COHERE.value: {
        "get_embeddings_func": _cohere_embedding_from_response,
    },
}
37
38
class BedrockEmbedding(BaseEmbedding):
    """Embedding model served through the AWS Bedrock ``invoke_model`` API.

    The provider (``amazon`` or ``cohere``) is the part of ``model`` before
    its first ``.``; it selects both the request body format and how the
    embedding is read back from the response.
    """

    model: str = Field(description="The modelId of the Bedrock model to use.")
    profile_name: Optional[str] = Field(
        description="The name of aws profile to use. If not given, then the default profile is used.",
        exclude=True,
    )
    aws_access_key_id: Optional[str] = Field(
        description="AWS Access Key ID to use", exclude=True
    )
    aws_secret_access_key: Optional[str] = Field(
        description="AWS Secret Access Key to use", exclude=True
    )
    aws_session_token: Optional[str] = Field(
        description="AWS Session Token to use", exclude=True
    )
    region_name: Optional[str] = Field(
        description="AWS region name to use. Uses region configured in AWS CLI if not passed",
        exclude=True,
    )
    botocore_session: Optional[Any] = Field(
        description="Use this Botocore session instead of creating a new default one.",
        exclude=True,
    )
    botocore_config: Optional[Any] = Field(
        description="Custom configuration object to use instead of the default generated one.",
        exclude=True,
    )
    max_retries: int = Field(
        default=10, description="The maximum number of API retries.", gt=0
    )
    timeout: float = Field(
        default=60.0,
        description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the bedrock client."
    )
    # The boto3 Bedrock client; assigned in ``__init__`` (or ``set_credentials``).
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = Models.TITAN_EMBEDDING,
        profile_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region_name: Optional[str] = None,
        client: Optional[Any] = None,
        botocore_session: Optional[Any] = None,
        botocore_config: Optional[Any] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        callback_manager: Optional[CallbackManager] = None,
        # base class
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ):
        """Create the Bedrock client and initialize the model fields.

        Args:
            model: Bedrock model id, e.g. ``"amazon.titan-embed-text-v1"``.
            profile_name, aws_access_key_id, aws_secret_access_key,
            aws_session_token, region_name, botocore_session: forwarded to
                ``boto3.Session``.
            client: a ready-made Bedrock client to use instead of building one.
            botocore_config: botocore ``Config`` for the client. When omitted,
                one is built from ``max_retries`` (used as ``max_attempts``,
                ``"standard"`` retry mode) and ``timeout`` (connect and read).
            additional_kwargs: stored on the instance; defaults to ``{}``.
            max_retries, timeout: see ``botocore_config``.
            callback_manager, system_prompt, messages_to_prompt,
            completion_to_prompt, pydantic_program_mode, output_parser, kwargs:
                forwarded to the base class.

        Raises:
            ImportError: if ``boto3``/``botocore`` are not installed.
        """
        additional_kwargs = additional_kwargs or {}

        try:
            import boto3
            from botocore.config import Config
        except ImportError as err:
            raise ImportError(
                "boto3 package not found, install with 'pip install boto3'"
            ) from err

        if botocore_config is None:
            config = Config(
                retries={"max_attempts": max_retries, "mode": "standard"},
                connect_timeout=timeout,
                read_timeout=timeout,
            )
        else:
            config = botocore_config

        session_kwargs = {
            "profile_name": profile_name,
            "region_name": region_name,
            "aws_access_key_id": aws_access_key_id,
            "aws_secret_access_key": aws_secret_access_key,
            "aws_session_token": aws_session_token,
            "botocore_session": botocore_session,
        }
        session = boto3.Session(**session_kwargs)

        if client is not None:
            self._client = client
        else:
            self._client = self._create_client(session, config)

        super().__init__(
            model=model,
            max_retries=max_retries,
            timeout=timeout,
            botocore_config=config,
            profile_name=profile_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            region_name=region_name,
            botocore_session=botocore_session,
            additional_kwargs=additional_kwargs,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            **kwargs,
        )

    @staticmethod
    def _create_client(session: Any, config: Optional[Any] = None) -> Any:
        """Create a Bedrock client from a boto3 ``session``.

        Prior to general availability, custom boto3 wheel files were
        distributed that used the ``bedrock`` service to invoke models. Fall
        back to that service when ``bedrock-runtime`` is unavailable so
        installs still using those wheels keep working.
        """
        if "bedrock-runtime" in session.get_available_services():
            return session.client("bedrock-runtime", config=config)
        return session.client("bedrock", config=config)

    @staticmethod
    def list_supported_models() -> Dict[str, List[str]]:
        """Map each provider name to the model ids in ``Models`` it serves.

        A model belongs to a provider when the part of its id before the
        first ``.`` equals the provider name. This is the same rule
        ``_get_embedding`` uses to pick the provider for ``self.model``.
        """
        return {
            provider.value: [
                m.value for m in Models if m.value.split(".")[0] == provider.value
            ]
            for provider in PROVIDERS
        }

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier ``"BedrockEmbedding"``."""
        return "BedrockEmbedding"

    @deprecated(
        version="0.9.48",
        reason=(
            "Use the provided kwargs in the constructor, "
            "set_credentials will be removed in future releases."
        ),
        action="once",
    )
    def set_credentials(
        self,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_profile: Optional[str] = None,
    ) -> None:
        """Rebuild the Bedrock client from explicit or environment credentials.

        ``aws_region``, ``aws_access_key_id``, ``aws_secret_access_key`` and
        ``aws_session_token`` each fall back to the environment variables
        ``AWS_REGION``, ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY`` and
        ``AWS_SESSION_TOKEN`` when not passed. ``aws_profile`` has no
        environment fallback here.

        The new client reuses ``self.botocore_config``, so the retry and
        timeout settings chosen at construction time still apply.

        Raises:
            AssertionError: if no access key id or secret access key is found.
            ImportError: if ``boto3`` is not installed.
        """
        aws_region = aws_region or os.getenv("AWS_REGION")
        aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        aws_secret_access_key = aws_secret_access_key or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
        aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")

        if aws_region is None:
            warnings.warn(
                "AWS_REGION not found. Set environment variable AWS_REGION or set aws_region"
            )

        if aws_access_key_id is None:
            warnings.warn(
                "AWS_ACCESS_KEY_ID not found. Set environment variable AWS_ACCESS_KEY_ID or set aws_access_key_id"
            )
        assert aws_access_key_id is not None

        if aws_secret_access_key is None:
            warnings.warn(
                "AWS_SECRET_ACCESS_KEY not found. Set environment variable AWS_SECRET_ACCESS_KEY or set aws_secret_access_key"
            )
        assert aws_secret_access_key is not None

        # A session token only accompanies temporary credentials. Long-term
        # access keys work without one, so a missing token is not fatal.
        if aws_session_token is None:
            warnings.warn(
                "AWS_SESSION_TOKEN not found. Set environment variable AWS_SESSION_TOKEN or set aws_session_token"
            )

        try:
            import boto3
        except ImportError as err:
            raise ImportError(
                "boto3 package not found, install with 'pip install boto3'"
            ) from err

        session = boto3.Session(
            profile_name=aws_profile,
            region_name=aws_region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )
        self._client = self._create_client(session, self.botocore_config)

    @classmethod
    @deprecated(
        version="0.9.48",
        reason=(
            "Use the provided kwargs in the constructor, "
            "from_credentials will be removed in future releases."
        ),
        action="once",
    )
    def from_credentials(
        cls,
        model_name: str = Models.TITAN_EMBEDDING,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_profile: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
    ) -> "BedrockEmbedding":
        """
        Instantiate using AWS credentials.

        Deprecated: pass the same values to the constructor instead. This
        method now delegates to it, so the client gets the constructor's
        retry/timeout configuration.

        Args:
            model_name (str): Bedrock model id.
            aws_region (str): AWS region where the service is located.
            aws_access_key_id (str): AWS access key ID.
            aws_secret_access_key (str): AWS secret access key.
            aws_session_token (str): AWS session token.
            aws_profile (str): AWS profile; when None, the default profile is
                chosen automatically.
            embed_batch_size (int): forwarded to the base class.
            callback_manager (CallbackManager): forwarded to the base class.
            verbose (bool): forwarded to the base class.

        Example:
            .. code-block:: python

                from llama_index.embeddings import BedrockEmbedding

                embeddings = BedrockEmbedding.from_credentials(
                    model_name="amazon.titan-embed-text-v1",
                    aws_region=aws_region,
                    aws_access_key_id=aws_access_key_id,
                    aws_secret_access_key=aws_secret_access_key,
                    aws_session_token=aws_session_token,
                    aws_profile=aws_profile,
                )

        """
        return cls(
            model=model_name,
            profile_name=aws_profile,
            region_name=aws_region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            verbose=verbose,
        )

    def _get_embedding(self, payload: str, type: Literal["text", "query"]) -> Embedding:
        """Embed ``payload`` with one ``invoke_model`` call.

        Args:
            payload: the text to embed.
            type: ``"text"`` for documents or ``"query"`` for search queries.
                Only the ``cohere`` request body distinguishes the two. (The
                name shadows the builtin but is kept for keyword callers.)

        Raises:
            ValueError: if no client is available or the model's provider is
                not supported.
        """
        if self._client is None:
            # Legacy path: try to build a client from environment credentials.
            self.set_credentials()

        if self._client is None:
            raise ValueError("Client not set")

        provider = self.model.split(".")[0]
        identifiers = PROVIDER_SPECIFIC_IDENTIFIERS.get(provider)
        if identifiers is None:
            # Fail before sending a request whose response could not be parsed.
            raise ValueError("Provider not supported")
        request_body = self._get_request_body(provider, payload, type)

        response = self._client.invoke_model(
            body=request_body,
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )

        resp = json.loads(response.get("body").read().decode("utf-8"))
        return identifiers["get_embeddings_func"](resp)

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed a search query."""
        return self._get_embedding(query, "query")

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed a document text."""
        return self._get_embedding(text, "text")

    def _get_request_body(
        self, provider: str, payload: str, type: Literal["text", "query"]
    ) -> str:
        """Build the JSON request body expected by ``provider``.

        Currently supported providers are ``amazon`` and ``cohere``. In both
        cases ``payload`` is a single string, e.g. ``"Hello World!"``.

        amazon:
            ``{"inputText": payload}``

        cohere:
            ``{"texts": [payload], "input_type": ..., "truncate": "NONE"}``,
            where ``input_type`` is ``"search_document"`` for ``type="text"``
            and ``"search_query"`` for ``type="query"``.

        Raises:
            ValueError: if ``provider`` is not supported.
        """
        if provider == PROVIDERS.AMAZON:
            return json.dumps({"inputText": payload})
        if provider == PROVIDERS.COHERE:
            input_types = {
                "text": "search_document",
                "query": "search_query",
            }
            return json.dumps(
                {
                    "texts": [payload],
                    "input_type": input_types[type],
                    "truncate": "NONE",
                }
            )
        raise ValueError("Provider not supported")

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Async variant of ``_get_query_embedding``.

        NOTE: this delegates to the synchronous implementation, so the Bedrock
        call blocks the event loop while it runs.
        """
        return self._get_embedding(query, "query")

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """Async variant of ``_get_text_embedding``.

        NOTE: this delegates to the synchronous implementation, so the Bedrock
        call blocks the event loop while it runs.
        """
        return self._get_embedding(text, "text")