llama-index

Форк
0
307 строк · 10.0 Кб
1
"""DashScope embeddings file."""
2

3
import logging
4
from enum import Enum
5
from http import HTTPStatus
6
from typing import Any, Dict, List, Optional, Union
7

8
from pydantic import PrivateAttr
9

10
from llama_index.legacy.embeddings.multi_modal_base import MultiModalEmbedding
11
from llama_index.legacy.schema import ImageType
12

13
logger = logging.getLogger(__name__)
14

15

16
class DashScopeTextEmbeddingType(str, Enum):
    """Values accepted for the `text_type` option of DashScope text embedding."""

    # The text is a query.
    TEXT_TYPE_QUERY = "query"
    # The text is a document (base text).
    TEXT_TYPE_DOCUMENT = "document"
21

22

23
class DashScopeTextEmbeddingModels(str, Enum):
    """Model names usable with the DashScope text embedding call."""

    # Each member's value is the model identifier string passed as `model`.
    TEXT_EMBEDDING_V1 = "text-embedding-v1"
    TEXT_EMBEDDING_V2 = "text-embedding-v2"
28

29

30
class DashScopeBatchTextEmbeddingModels(str, Enum):
    """Model names usable with the DashScope batch (async) text embedding call."""

    # Each member's value is the model identifier string passed as `model`.
    TEXT_EMBEDDING_ASYNC_V1 = "text-embedding-async-v1"
    TEXT_EMBEDDING_ASYNC_V2 = "text-embedding-async-v2"
35

36

37
# Size limits for embedding requests. Nothing in the visible code reads these;
# presumably they mirror the service-side limits — TODO confirm.
EMBED_MAX_INPUT_LENGTH: int = 2048
EMBED_MAX_BATCH_SIZE: int = 25
39

40

41
class DashScopeMultiModalEmbeddingModels(str, Enum):
    """Model names usable with the DashScope multimodal embedding call."""

    # The member's value is the model identifier string passed as `model`.
    MULTIMODAL_EMBEDDING_ONE_PEACE_V1 = "multimodal-embedding-one-peace-v1"
45

46

47
def get_text_embedding(
    model: str,
    text: Union[str, List[str]],
    api_key: Optional[str] = None,
    **kwargs: Any,
) -> List[List[float]]:
    """Call DashScope text embedding.

    ref: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details.

    Args:
        model (str): One of the `DashScopeTextEmbeddingModels` values.
        text (Union[str, List[str]]): A single text or a list of texts to embed.
            A single string is wrapped into a one-element list before the call.
        api_key (Optional[str]): The DashScope api key, forwarded to the SDK.
        **kwargs (Any): Extra options (e.g. ``text_type``) forwarded to
            ``dashscope.TextEmbedding.call`` as keyword arguments.

    Raises:
        ImportError: If the `dashscope` package is not installed.

    Returns:
        List[List[float]]: The embeddings taken from the response, or an empty
            list if the call did not return HTTP 200 (the failure is logged).
    """
    try:
        import dashscope
    except ImportError as err:
        raise ImportError("DashScope requires `pip install dashscope`") from err
    if isinstance(text, str):
        text = [text]
    # Unpack kwargs so that options such as `text_type` reach the SDK as real
    # keyword arguments instead of one nested parameter literally named "kwargs".
    response = dashscope.TextEmbedding.call(
        model=model, input=text, api_key=api_key, **kwargs
    )
    if response.status_code != HTTPStatus.OK:
        logger.error("Calling TextEmbedding failed, details: %s", response)
        return []
    return [item["embedding"] for item in response.output["embeddings"]]
83

84

85
def get_batch_text_embedding(
    model: str, url: str, api_key: Optional[str] = None, **kwargs: Any
) -> Optional[str]:
    """Call DashScope batch text embedding.

    Args:
        model (str): One of the `DashScopeBatchTextEmbeddingModels` values.
        url (str): The url of the file to embed, containing lines of text.
        api_key (Optional[str]): The DashScope api key, forwarded to the SDK.
        **kwargs (Any): Extra options (e.g. ``text_type``) forwarded to
            ``dashscope.BatchTextEmbedding.call`` as keyword arguments.

    Raises:
        ImportError: If the `dashscope` package is not installed.

    Returns:
        Optional[str]: The url of the embedding result, or None if the call
            did not return HTTP 200 (the failure is logged). Format ref:
            https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details
    """
    try:
        import dashscope
    except ImportError as err:
        raise ImportError("DashScope requires `pip install dashscope`") from err
    # Unpack kwargs so that options reach the SDK as real keyword arguments
    # instead of one nested parameter literally named "kwargs".
    response = dashscope.BatchTextEmbedding.call(
        model=model, url=url, api_key=api_key, **kwargs
    )
    if response.status_code != HTTPStatus.OK:
        logger.error("Calling BatchTextEmbedding failed, details: %s", response)
        return None
    return response.output["url"]
113

114

115
def get_multimodal_embedding(
    model: str, input: list, api_key: Optional[str] = None, **kwargs: Any
) -> List[float]:
    """Call DashScope multimodal embedding.

    ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details.

    Args:
        model (str): One of the `DashScopeMultiModalEmbeddingModels` values.
        input (list): The input of the embedding, eg:
             [{'factor': 1, 'text': '你好'},
             {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'},
             {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}]
        api_key (Optional[str]): The DashScope api key, forwarded to the SDK.
        **kwargs (Any): Extra options (e.g. ``auto_truncation``) forwarded to
            ``dashscope.MultiModalEmbedding.call`` as keyword arguments.

    Raises:
        ImportError: If the `dashscope` package is not installed.

    Returns:
        List[float]: The embedding from the response, or an empty list if the
            call did not return HTTP 200 (the failure is logged).
    """
    try:
        import dashscope
    except ImportError as err:
        raise ImportError("DashScope requires `pip install dashscope`") from err
    # Unpack kwargs so that options reach the SDK as real keyword arguments
    # instead of one nested parameter literally named "kwargs".
    response = dashscope.MultiModalEmbedding.call(
        model=model, input=input, api_key=api_key, **kwargs
    )
    if response.status_code != HTTPStatus.OK:
        logger.error("Calling MultiModalEmbedding failed, details: %s", response)
        return []
    return response.output["embedding"]
146

147

148
class DashScopeEmbedding(MultiModalEmbedding):
    """DashScope class for text and multimodal embedding.

    Args:
        model_name (str): Model name for embedding.
            Defaults to DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2.
                Options are:

                - DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V1
                - DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2
        text_type (str): The input type, one of ['query', 'document'].
            For asymmetric tasks such as retrieval, it is recommended to
            distinguish between query text (query) and base text (document)
            in order to achieve better retrieval results. Symmetric tasks
            such as clustering and classification do not need to specify it,
            and the default value "document" can be used.
        api_key (str): The DashScope api key.
    """

    _api_key: Optional[str] = PrivateAttr()
    _text_type: Optional[str] = PrivateAttr()

    def __init__(
        self,
        model_name: str = DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
        text_type: str = "document",
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # Run the base model initialisation first so the private attributes
        # are assigned on a fully constructed instance.
        super().__init__(
            model_name=model_name,
            **kwargs,
        )
        self._api_key = api_key
        self._text_type = text_type

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this embedding class."""
        return "DashScopeEmbedding"

    def _embed_single_text(self, text: str) -> List[float]:
        """Embed one text; return its vector, or [] if nothing came back."""
        embeddings = get_text_embedding(
            self.model_name,
            text,
            api_key=self._api_key,
            text_type=self._text_type,
        )
        return embeddings[0] if embeddings else []

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._embed_single_text(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._embed_single_text(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        return get_text_embedding(
            self.model_name,
            texts,
            api_key=self._api_key,
            text_type=self._text_type,
        )

    # TODO: use proper async methods
    async def _aget_text_embedding(self, query: str) -> List[float]:
        """Get text embedding (currently delegates to the synchronous call)."""
        return self._get_text_embedding(query)

    # TODO: use proper async methods
    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding (currently delegates to the synchronous call)."""
        return self._get_query_embedding(query)

    def get_batch_query_embedding(self, embedding_file_url: str) -> Optional[str]:
        """Get batch query embeddings.

        Args:
            embedding_file_url (str): The url of the file to embed, containing
                lines of text.

        Returns:
            Optional[str]: The url of the embedding result, or None if the
                call failed. Format ref:
                https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details.
        """
        return get_batch_text_embedding(
            self.model_name,
            embedding_file_url,
            api_key=self._api_key,
            text_type=self._text_type,
        )

    def get_batch_text_embedding(self, embedding_file_url: str) -> Optional[str]:
        """Get batch text embeddings.

        Args:
            embedding_file_url (str): The url of the file to embed, containing
                lines of text.

        Returns:
            Optional[str]: The url of the embedding result, or None if the
                call failed. Format ref:
                https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details.
        """
        # Inside a method body this name resolves to the module-level
        # function, not to this method.
        return get_batch_text_embedding(
            self.model_name,
            embedding_file_url,
            api_key=self._api_key,
            text_type=self._text_type,
        )

    def _get_image_embedding(self, img_file_path: ImageType) -> List[float]:
        """
        Embed the input image synchronously.
        """
        image_input = [{"image": img_file_path}]
        return get_multimodal_embedding(
            self.model_name, input=image_input, api_key=self._api_key
        )

    async def _aget_image_embedding(self, img_file_path: ImageType) -> List[float]:
        """
        Embed the input image (currently delegates to the synchronous call).
        """
        return self._get_image_embedding(img_file_path=img_file_path)

    def get_multimodal_embedding(
        self, input: List[Dict], auto_truncation: bool = False
    ) -> List[float]:
        """Call DashScope multimodal embedding.
        ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details.

        Args:
            input (List[Dict]): The input of the multimodal embedding, eg:
                [{'factor': 1, 'text': '你好'},
                {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'},
                {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}]
            auto_truncation (bool): Passed through to the module-level
                `get_multimodal_embedding` as an extra option. Defaults to False.

        Raises:
            ImportError: Need install dashscope package.

        Returns:
            List[float]: The embedding result, or an empty list if the call failed.
        """
        # Inside a method body this name resolves to the module-level
        # function, not to this method.
        return get_multimodal_embedding(
            self.model_name,
            input=input,
            api_key=self._api_key,
            auto_truncation=auto_truncation,
        )
308

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.