llama-index

Форк
0
/
sagemaker_embedding_endpoint.py 
153 строки · 5.7 Кб
1
from typing import Any, Dict, List, Optional
2

3
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
4
from llama_index.legacy.callbacks.base import CallbackManager
5
from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE
6
from llama_index.legacy.core.embeddings.base import BaseEmbedding, Embedding
7
from llama_index.legacy.embeddings.sagemaker_embedding_endpoint_utils import (
8
    BaseIOHandler,
9
    IOHandler,
10
)
11
from llama_index.legacy.types import PydanticProgramMode
12
from llama_index.legacy.utilities.aws_utils import get_aws_service_client
13

14
# Default (de)serialization handler used by SageMakerEmbedding whenever no
# explicit ``content_handler`` is supplied (both the field default and the
# ``__init__`` default refer to this object).
# NOTE(review): this single instance is shared by every embedding built with the
# default handler — assumes IOHandler keeps no per-request state; confirm in
# sagemaker_embedding_endpoint_utils.
DEFAULT_IO_HANDLER = IOHandler()
15

16

17
class SageMakerEmbedding(BaseEmbedding):
    """Embedding model served by an Amazon SageMaker inference endpoint.

    Each request is serialized by ``content_handler``, sent through the
    ``sagemaker-runtime`` client's ``invoke_endpoint`` call, and the response
    body is deserialized back into embeddings by the same handler. Newlines in
    input text are replaced with spaces before sending.
    """

    endpoint_name: str = Field(description="SageMaker Embedding endpoint name")
    endpoint_kwargs: Dict[str, Any] = Field(
        default={},
        description="Additional kwargs for the invoke_endpoint request.",
    )
    model_kwargs: Dict[str, Any] = Field(
        default={},
        description="kwargs to pass to the model.",
    )
    content_handler: BaseIOHandler = Field(
        default=DEFAULT_IO_HANDLER,
        description="used to serialize input, deserialize output, and remove a prefix.",
    )

    # Connection settings. ``__init__`` stores the non-secret ones on the model.
    # The raw credentials (access key, secret key, session token) are only handed
    # to the AWS client factory and are intentionally left unset here, so they do
    # not show up among the model's field values (e.g. when it is serialized).
    profile_name: Optional[str] = Field(
        default=None,
        description="The name of aws profile to use. If not given, then the default profile is used.",
    )
    aws_access_key_id: Optional[str] = Field(
        default=None, description="AWS Access Key ID to use"
    )
    aws_secret_access_key: Optional[str] = Field(
        default=None, description="AWS Secret Access Key to use"
    )
    aws_session_token: Optional[str] = Field(
        default=None, description="AWS Session Token to use"
    )
    aws_region_name: Optional[str] = Field(
        default=None,
        description="AWS region name to use. Uses region configured in AWS CLI if not passed",
    )
    # ``ge`` is pydantic's inclusive lower-bound constraint. A keyword such as
    # ``gte`` would be accepted as opaque extra metadata and never enforced.
    max_retries: Optional[int] = Field(
        default=3,
        description="The maximum number of API retries.",
        ge=0,
    )
    timeout: Optional[float] = Field(
        default=60.0,
        description="The timeout, in seconds, for API requests.",
        ge=0,
    )

    # Client returned by get_aws_service_client("sagemaker-runtime", ...).
    _client: Any = PrivateAttr()
    # Stored from the ``verbose`` argument; not read anywhere in this class.
    _verbose: bool = PrivateAttr()

    def __init__(
        self,
        endpoint_name: str,
        endpoint_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        content_handler: BaseIOHandler = DEFAULT_IO_HANDLER,
        profile_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region_name: Optional[str] = None,
        max_retries: Optional[int] = 3,
        timeout: Optional[float] = 60.0,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        verbose: bool = False,
    ):
        """Validate the configuration and create the SageMaker runtime client.

        Args:
            endpoint_name: Name of the SageMaker endpoint to invoke.
            endpoint_kwargs: Extra keyword arguments forwarded to every
                ``invoke_endpoint`` call. ``None`` means no extras.
            model_kwargs: Default model parameters handed to
                ``content_handler.serialize_input``. ``None`` means none.
            content_handler: Serializes requests and deserializes responses.
            profile_name: AWS profile name for the client factory; stored.
            aws_access_key_id: AWS access key ID for the client factory only.
            aws_secret_access_key: AWS secret key for the client factory only.
            aws_session_token: AWS session token for the client factory only.
            region_name: AWS region for the client factory; stored as
                ``aws_region_name``.
            max_retries: Maximum number of API retries (must be >= 0 or None).
            timeout: Request timeout in seconds (must be >= 0 or None).
            embed_batch_size: Passed through to ``BaseEmbedding``.
            callback_manager: Passed through to ``BaseEmbedding``.
            pydantic_program_mode: Passed through to ``BaseEmbedding``.
            verbose: Stored on the instance as ``_verbose``.

        Raises:
            ValueError: If ``endpoint_name`` is empty.
        """
        if not endpoint_name:
            raise ValueError(
                "Missing required argument:`endpoint_name`"
                " Please specify the endpoint_name"
            )

        super().__init__(
            endpoint_name=endpoint_name,
            # ``or {}`` gives every instance its own fresh dict when the caller
            # passes None (the default) or an empty mapping.
            endpoint_kwargs=endpoint_kwargs or {},
            model_kwargs=model_kwargs or {},
            content_handler=content_handler,
            profile_name=profile_name,
            aws_region_name=region_name,
            max_retries=max_retries,
            timeout=timeout,
            embed_batch_size=embed_batch_size,
            pydantic_program_mode=pydantic_program_mode,
            callback_manager=callback_manager,
        )

        # Built only after model validation has succeeded, so invalid settings
        # (e.g. a negative timeout) are rejected before any client is created.
        self._client = get_aws_service_client(
            service_name="sagemaker-runtime",
            profile_name=profile_name,
            region_name=region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            max_retries=max_retries,
            timeout=timeout,
        )
        self._verbose = verbose

    @classmethod
    def class_name(cls) -> str:
        """Return the name that identifies this embedding class."""
        return "SageMakerEmbedding"

    @staticmethod
    def _normalize_text(text: str) -> str:
        """Replace newlines with spaces before text is sent to the endpoint."""
        return text.replace("\n", " ")

    def _get_embedding(self, payload: List[str], **kwargs: Any) -> List[Embedding]:
        """Embed ``payload`` with a single ``invoke_endpoint`` request.

        Args:
            payload: Input strings, serialized together into one request body.
            **kwargs: Per-call model parameters; they override entries of
                ``self.model_kwargs`` that have the same key.

        Returns:
            The value of ``content_handler.deserialize_output`` for the response
            body — presumably one embedding per input string, in input order
            (this depends on the handler; confirm for custom handlers).
        """
        model_kwargs = {**self.model_kwargs, **kwargs}

        request_body = self.content_handler.serialize_input(
            request=payload, model_kwargs=model_kwargs
        )

        response = self._client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=request_body,
            ContentType=self.content_handler.content_type,
            Accept=self.content_handler.accept,
            **self.endpoint_kwargs,
        )

        return self.content_handler.deserialize_output(response=response["Body"])

    def _get_query_embedding(self, query: str, **kwargs: Any) -> Embedding:
        """Embed one query string (newlines replaced by spaces)."""
        return self._get_embedding([self._normalize_text(query)], **kwargs)[0]

    def _get_text_embedding(self, text: str, **kwargs: Any) -> Embedding:
        """Embed one document text (newlines replaced by spaces)."""
        return self._get_embedding([self._normalize_text(text)], **kwargs)[0]

    def _get_text_embeddings(self, texts: List[str], **kwargs: Any) -> List[Embedding]:
        """Embed all ``texts`` synchronously in one endpoint request.

        The whole list is sent as a single payload rather than one request per
        text, so the endpoint must accept a batch of ``len(texts)`` inputs.
        """
        return self._get_embedding(
            [self._normalize_text(text) for text in texts], **kwargs
        )

    # Async variants: ``invoke_endpoint`` is called without ``await`` above, so
    # these fall back to the synchronous request and block the running event
    # loop while it is in flight. They exist so async callers work rather than
    # failing outright; replace them if a native async client becomes available.

    async def _aget_query_embedding(self, query: str, **kwargs: Any) -> Embedding:
        """Async query embedding; delegates to the synchronous implementation."""
        return self._get_query_embedding(query, **kwargs)

    async def _aget_text_embedding(self, text: str, **kwargs: Any) -> Embedding:
        """Async text embedding; delegates to the synchronous implementation."""
        return self._get_text_embedding(text, **kwargs)

    async def _aget_text_embeddings(
        self, texts: List[str], **kwargs: Any
    ) -> List[Embedding]:
        """Async batch embedding; sends one synchronous batched request."""
        return self._get_text_embeddings(texts, **kwargs)
154

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.