llama-index
153 строки · 5.7 Кб
1from typing import Any, Dict, List, Optional2
3from llama_index.legacy.bridge.pydantic import Field, PrivateAttr4from llama_index.legacy.callbacks.base import CallbackManager5from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE6from llama_index.legacy.core.embeddings.base import BaseEmbedding, Embedding7from llama_index.legacy.embeddings.sagemaker_embedding_endpoint_utils import (8BaseIOHandler,9IOHandler,10)
11from llama_index.legacy.types import PydanticProgramMode12from llama_index.legacy.utilities.aws_utils import get_aws_service_client13
# Module-level IOHandler instance used as the default ``content_handler`` for
# SageMakerEmbedding (both the field default and the ``__init__`` default).
DEFAULT_IO_HANDLER = IOHandler()

class SageMakerEmbedding(BaseEmbedding):
    """Embedding model that calls an Amazon SageMaker inference endpoint.

    Inputs are serialized by ``content_handler``, sent to ``endpoint_name``
    through a ``sagemaker-runtime`` client's ``invoke_endpoint`` call, and the
    response body is deserialized back into embeddings by the same handler.
    Only the synchronous embedding methods are implemented.
    """

    endpoint_name: str = Field(description="SageMaker Embedding endpoint name")
    endpoint_kwargs: Dict[str, Any] = Field(
        default={},
        description="Additional kwargs for the invoke_endpoint request.",
    )
    model_kwargs: Dict[str, Any] = Field(
        default={},
        description="kwargs to pass to the model.",
    )
    content_handler: BaseIOHandler = Field(
        default=DEFAULT_IO_HANDLER,
        description="used to serialize input, deserialize output, and remove a prefix.",
    )

    profile_name: Optional[str] = Field(
        default=None,
        description="The name of aws profile to use. If not given, then the default profile is used.",
    )
    # The three credential fields are not populated by ``__init__``: explicit
    # credentials are only handed to the AWS client, so they never become part
    # of the model's field values.
    aws_access_key_id: Optional[str] = Field(
        default=None, description="AWS Access Key ID to use"
    )
    aws_secret_access_key: Optional[str] = Field(
        default=None, description="AWS Secret Access Key to use"
    )
    aws_session_token: Optional[str] = Field(
        default=None, description="AWS Session Token to use"
    )
    aws_region_name: Optional[str] = Field(
        default=None,
        description="AWS region name to use. Uses region configured in AWS CLI if not passed",
    )
    # ``ge`` is pydantic's lower-bound constraint; the previous ``gte`` keyword
    # was not a recognised constraint and was silently ignored.
    max_retries: Optional[int] = Field(
        default=3,
        description="The maximum number of API retries.",
        ge=0,
    )
    timeout: Optional[float] = Field(
        default=60.0,
        description="The timeout, in seconds, for API requests.",
        ge=0,
    )
    # ``sagemaker-runtime`` client created in ``__init__``.
    _client: Any = PrivateAttr()
    # Value of the ``verbose`` constructor argument; not read within this class.
    _verbose: bool = PrivateAttr()

    def __init__(
        self,
        endpoint_name: str,
        endpoint_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        content_handler: BaseIOHandler = DEFAULT_IO_HANDLER,
        profile_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region_name: Optional[str] = None,
        max_retries: Optional[int] = 3,
        timeout: Optional[float] = 60.0,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        verbose: bool = False,
    ) -> None:
        """Validate the configuration and create the SageMaker runtime client.

        Args:
            endpoint_name: Name of the SageMaker endpoint to invoke. Required.
            endpoint_kwargs: Extra keyword arguments forwarded to every
                ``invoke_endpoint`` call. ``None`` means no extra arguments.
            model_kwargs: Default model kwargs merged into every request
                payload. ``None`` means no defaults.
            content_handler: Serializes requests and deserializes responses.
            profile_name: AWS profile passed to the client factory.
            aws_access_key_id: Explicit access key passed to the client factory.
            aws_secret_access_key: Explicit secret key passed to the client factory.
            aws_session_token: Explicit session token passed to the client factory.
            region_name: AWS region passed to the client factory; recorded on
                the model as ``aws_region_name``.
            max_retries: Maximum API retries for the client; validated as >= 0.
            timeout: Request timeout in seconds for the client; validated as >= 0.
            embed_batch_size: Forwarded to ``BaseEmbedding``.
            callback_manager: Forwarded to ``BaseEmbedding``.
            pydantic_program_mode: Forwarded to ``BaseEmbedding``.
            verbose: Stored on the instance as ``_verbose``.

        Raises:
            ValueError: If ``endpoint_name`` is empty or missing.
        """
        if not endpoint_name:
            raise ValueError(
                "Missing required argument:`endpoint_name`"
                " Please specify the endpoint_name"
            )

        # Validate and record the public fields first, so an invalid
        # configuration (e.g. a negative timeout) fails before any AWS client
        # is built. Credentials are deliberately not stored as field values.
        super().__init__(
            endpoint_name=endpoint_name,
            endpoint_kwargs=endpoint_kwargs or {},
            model_kwargs=model_kwargs or {},
            content_handler=content_handler,
            profile_name=profile_name,
            aws_region_name=region_name,
            max_retries=max_retries,
            timeout=timeout,
            embed_batch_size=embed_batch_size,
            pydantic_program_mode=pydantic_program_mode,
            callback_manager=callback_manager,
        )

        self._client = get_aws_service_client(
            service_name="sagemaker-runtime",
            profile_name=profile_name,
            region_name=region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            max_retries=max_retries,
            timeout=timeout,
        )
        self._verbose = verbose

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this embedding class."""
        return "SageMakerEmbedding"

    def _get_embedding(self, payload: List[str], **kwargs: Any) -> List[Embedding]:
        """Embed ``payload`` with a single ``invoke_endpoint`` request.

        Per-call ``kwargs`` override same-named entries in ``self.model_kwargs``.
        ``self.endpoint_kwargs`` are forwarded to ``invoke_endpoint`` as extra
        request arguments.

        Returns:
            Whatever ``content_handler.deserialize_output`` produces for the
            response body. This is presumably one embedding per input string;
            verify against the handler in use.
        """
        model_kwargs = {**self.model_kwargs, **kwargs}

        request_body = self.content_handler.serialize_input(
            request=payload, model_kwargs=model_kwargs
        )

        response = self._client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=request_body,
            ContentType=self.content_handler.content_type,
            Accept=self.content_handler.accept,
            **self.endpoint_kwargs,
        )["Body"]

        return self.content_handler.deserialize_output(response=response)

    def _get_query_embedding(self, query: str, **kwargs: Any) -> Embedding:
        """Embed one query string, with newlines replaced by spaces."""
        query = query.replace("\n", " ")
        return self._get_embedding([query], **kwargs)[0]

    def _get_text_embedding(self, text: str, **kwargs: Any) -> Embedding:
        """Embed one text string, with newlines replaced by spaces."""
        text = text.replace("\n", " ")
        return self._get_embedding([text], **kwargs)[0]

    def _get_text_embeddings(self, texts: List[str], **kwargs: Any) -> List[Embedding]:
        """Embed a batch of texts with one endpoint request.

        Newlines are replaced with spaces, matching the single-text methods.
        An empty batch returns ``[]`` without calling the endpoint.
        """
        if not texts:
            return []

        cleaned = [text.replace("\n", " ") for text in texts]
        # The whole batch goes to the endpoint in a single invoke_endpoint call.
        return self._get_embedding(cleaned, **kwargs)

    async def _aget_query_embedding(self, query: str, **kwargs: Any) -> Embedding:
        """Async query embedding is not supported."""
        raise NotImplementedError(
            "SageMakerEmbedding does not support async embedding."
        )

    async def _aget_text_embedding(self, text: str, **kwargs: Any) -> Embedding:
        """Async text embedding is not supported."""
        raise NotImplementedError(
            "SageMakerEmbedding does not support async embedding."
        )

    async def _aget_text_embeddings(
        self, texts: List[str], **kwargs: Any
    ) -> List[Embedding]:
        """Async batch embedding is not supported."""
        raise NotImplementedError(
            "SageMakerEmbedding does not support async embedding."
        )