# llama-index (legacy) — OpenAI embeddings module (428 lines, 14.2 KB)
"""OpenAI embeddings file."""

from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import httpx
from openai import AsyncOpenAI, OpenAI

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.embeddings.base import BaseEmbedding
from llama_index.legacy.llms.openai_utils import (
    create_retry_decorator,
    resolve_openai_credentials,
)

# Shared retry policy for every embedding call below: up to 6 attempts with
# randomized exponential backoff (waits clamped to 1-20s), giving up entirely
# once 60s have elapsed since the first attempt.
embedding_retry_decorator = create_retry_decorator(
    max_retries=6,
    random_exponential=True,
    stop_after_delay_seconds=60,
    min_seconds=1,
    max_seconds=20,
)

class OpenAIEmbeddingMode(str, Enum):
    """OpenAI embedding mode.

    Selects which engine family a (mode, model) pair resolves to: in
    text-search mode the first-generation models use distinct query/doc
    engines (see the two mode/model dicts below).
    """

    SIMILARITY_MODE = "similarity"
    TEXT_SEARCH_MODE = "text_search"

class OpenAIEmbeddingModelType(str, Enum):
    """OpenAI embedding model type.

    These are the user-facing model identifiers accepted by
    ``OpenAIEmbedding``; they are resolved to concrete engine names via
    ``get_engine`` and the mode/model dicts below.
    """

    # First-generation model families (resolved to "-001" engines).
    DAVINCI = "davinci"
    CURIE = "curie"
    BABBAGE = "babbage"
    ADA = "ada"
    # Later models, where the engine name equals the model name.
    TEXT_EMBED_ADA_002 = "text-embedding-ada-002"
    TEXT_EMBED_3_LARGE = "text-embedding-3-large"
    TEXT_EMBED_3_SMALL = "text-embedding-3-small"

class OpenAIEmbeddingModeModel(str, Enum):
    """OpenAI embedding mode model.

    Concrete engine identifiers sent to the embeddings API. First-generation
    families have separate similarity / search-query / search-doc engines;
    newer models use a single engine for every mode.
    """

    # davinci
    TEXT_SIMILARITY_DAVINCI = "text-similarity-davinci-001"
    TEXT_SEARCH_DAVINCI_QUERY = "text-search-davinci-query-001"
    TEXT_SEARCH_DAVINCI_DOC = "text-search-davinci-doc-001"

    # curie
    TEXT_SIMILARITY_CURIE = "text-similarity-curie-001"
    TEXT_SEARCH_CURIE_QUERY = "text-search-curie-query-001"
    TEXT_SEARCH_CURIE_DOC = "text-search-curie-doc-001"

    # babbage
    TEXT_SIMILARITY_BABBAGE = "text-similarity-babbage-001"
    TEXT_SEARCH_BABBAGE_QUERY = "text-search-babbage-query-001"
    TEXT_SEARCH_BABBAGE_DOC = "text-search-babbage-doc-001"

    # ada
    TEXT_SIMILARITY_ADA = "text-similarity-ada-001"
    TEXT_SEARCH_ADA_QUERY = "text-search-ada-query-001"
    TEXT_SEARCH_ADA_DOC = "text-search-ada-doc-001"

    # text-embedding-ada-002
    TEXT_EMBED_ADA_002 = "text-embedding-ada-002"

    # text-embedding-3-large
    TEXT_EMBED_3_LARGE = "text-embedding-3-large"

    # text-embedding-3-small
    TEXT_EMBED_3_SMALL = "text-embedding-3-small"

# convenient shorthand
79
OAEM = OpenAIEmbeddingMode
80
OAEMT = OpenAIEmbeddingModelType
81
OAEMM = OpenAIEmbeddingModeModel
82

83
EMBED_MAX_TOKEN_LIMIT = 2048
84

85

86
# Maps (mode, model) -> engine used to embed *queries*. In text-search mode
# the first-generation families resolve to their dedicated "-query-001"
# engines; newer models use the same engine for queries and documents.
_QUERY_MODE_MODEL_DICT = {
    (OAEM.SIMILARITY_MODE, "davinci"): OAEMM.TEXT_SIMILARITY_DAVINCI,
    (OAEM.SIMILARITY_MODE, "curie"): OAEMM.TEXT_SIMILARITY_CURIE,
    (OAEM.SIMILARITY_MODE, "babbage"): OAEMM.TEXT_SIMILARITY_BABBAGE,
    (OAEM.SIMILARITY_MODE, "ada"): OAEMM.TEXT_SIMILARITY_ADA,
    (OAEM.SIMILARITY_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
    (OAEM.SIMILARITY_MODE, "text-embedding-3-small"): OAEMM.TEXT_EMBED_3_SMALL,
    (OAEM.SIMILARITY_MODE, "text-embedding-3-large"): OAEMM.TEXT_EMBED_3_LARGE,
    (OAEM.TEXT_SEARCH_MODE, "davinci"): OAEMM.TEXT_SEARCH_DAVINCI_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "curie"): OAEMM.TEXT_SEARCH_CURIE_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "babbage"): OAEMM.TEXT_SEARCH_BABBAGE_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "ada"): OAEMM.TEXT_SEARCH_ADA_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-3-large"): OAEMM.TEXT_EMBED_3_LARGE,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-3-small"): OAEMM.TEXT_EMBED_3_SMALL,
}

# Maps (mode, model) -> engine used to embed *documents/text*. In text-search
# mode the first-generation families resolve to their dedicated "-doc-001"
# engines; newer models use the same engine for queries and documents.
_TEXT_MODE_MODEL_DICT = {
    (OAEM.SIMILARITY_MODE, "davinci"): OAEMM.TEXT_SIMILARITY_DAVINCI,
    (OAEM.SIMILARITY_MODE, "curie"): OAEMM.TEXT_SIMILARITY_CURIE,
    (OAEM.SIMILARITY_MODE, "babbage"): OAEMM.TEXT_SIMILARITY_BABBAGE,
    (OAEM.SIMILARITY_MODE, "ada"): OAEMM.TEXT_SIMILARITY_ADA,
    (OAEM.SIMILARITY_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
    (OAEM.SIMILARITY_MODE, "text-embedding-3-small"): OAEMM.TEXT_EMBED_3_SMALL,
    (OAEM.SIMILARITY_MODE, "text-embedding-3-large"): OAEMM.TEXT_EMBED_3_LARGE,
    (OAEM.TEXT_SEARCH_MODE, "davinci"): OAEMM.TEXT_SEARCH_DAVINCI_DOC,
    (OAEM.TEXT_SEARCH_MODE, "curie"): OAEMM.TEXT_SEARCH_CURIE_DOC,
    (OAEM.TEXT_SEARCH_MODE, "babbage"): OAEMM.TEXT_SEARCH_BABBAGE_DOC,
    (OAEM.TEXT_SEARCH_MODE, "ada"): OAEMM.TEXT_SEARCH_ADA_DOC,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-3-large"): OAEMM.TEXT_EMBED_3_LARGE,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-3-small"): OAEMM.TEXT_EMBED_3_SMALL,
}

@embedding_retry_decorator
def get_embedding(client: OpenAI, text: str, engine: str, **kwargs: Any) -> List[float]:
    """Embed a single string with ``engine`` and return its vector.

    Vendored from OpenAI's ``embeddings_utils``
    (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
    to avoid pulling in its heavy dependencies (matplotlib, plotly, scipy,
    sklearn). Newlines are flattened to spaces before the request, matching
    the upstream helper.
    """
    cleaned = text.replace("\n", " ")
    response = client.embeddings.create(input=[cleaned], model=engine, **kwargs)
    return response.data[0].embedding

@embedding_retry_decorator
async def aget_embedding(
    aclient: AsyncOpenAI, text: str, engine: str, **kwargs: Any
) -> List[float]:
    """Asynchronously embed a single string with ``engine``.

    Vendored from OpenAI's ``embeddings_utils``
    (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
    to avoid pulling in its heavy dependencies (matplotlib, plotly, scipy,
    sklearn). Newlines are flattened to spaces before the request, matching
    the upstream helper.
    """
    cleaned = text.replace("\n", " ")
    response = await aclient.embeddings.create(input=[cleaned], model=engine, **kwargs)
    return response.data[0].embedding

@embedding_retry_decorator
def get_embeddings(
    client: OpenAI, list_of_text: List[str], engine: str, **kwargs: Any
) -> List[List[float]]:
    """Get embeddings for a batch of texts.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    Args:
        client: Synchronous OpenAI client.
        list_of_text: Texts to embed; at most 2048 per request.
        engine: Model/engine name passed to the embeddings endpoint.
        **kwargs: Extra parameters forwarded to ``embeddings.create``.

    Returns:
        One embedding vector per input text, in input order.

    Raises:
        ValueError: If more than 2048 texts are supplied.
    """
    # Raise a real exception instead of ``assert`` so the batch-size check
    # is not stripped when running under ``python -O``.
    if len(list_of_text) > 2048:
        raise ValueError("The batch size should not be larger than 2048.")

    # Flatten newlines to spaces, consistent with the single-text helpers.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = client.embeddings.create(input=list_of_text, model=engine, **kwargs).data
    return [d.embedding for d in data]

@embedding_retry_decorator
async def aget_embeddings(
    aclient: AsyncOpenAI,
    list_of_text: List[str],
    engine: str,
    **kwargs: Any,
) -> List[List[float]]:
    """Asynchronously get embeddings for a batch of texts.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    Args:
        aclient: Asynchronous OpenAI client.
        list_of_text: Texts to embed; at most 2048 per request.
        engine: Model/engine name passed to the embeddings endpoint.
        **kwargs: Extra parameters forwarded to ``embeddings.create``.

    Returns:
        One embedding vector per input text, in input order.

    Raises:
        ValueError: If more than 2048 texts are supplied.
    """
    # Raise a real exception instead of ``assert`` so the batch-size check
    # is not stripped when running under ``python -O``.
    if len(list_of_text) > 2048:
        raise ValueError("The batch size should not be larger than 2048.")

    # Flatten newlines to spaces, consistent with the single-text helpers.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (
        await aclient.embeddings.create(input=list_of_text, model=engine, **kwargs)
    ).data
    return [d.embedding for d in data]

def get_engine(
    mode: str,
    model: str,
    mode_model_dict: Dict[Tuple[OpenAIEmbeddingMode, str], OpenAIEmbeddingModeModel],
) -> OpenAIEmbeddingModeModel:
    """Resolve a (mode, model) pair to a concrete embedding engine.

    Both arguments are validated by coercion into their enums; an unsupported
    combination raises ``ValueError``.
    """
    combination = (OpenAIEmbeddingMode(mode), OpenAIEmbeddingModelType(model))
    engine = mode_model_dict.get(combination)
    if engine is None:
        raise ValueError(f"Invalid mode, model combination: {combination}")
    return engine

class OpenAIEmbedding(BaseEmbedding):
    """OpenAI class for embeddings.

    Args:
        mode (str): Mode for embedding.
            Defaults to OpenAIEmbeddingMode.TEXT_SEARCH_MODE.
            Options are:

            - OpenAIEmbeddingMode.SIMILARITY_MODE
            - OpenAIEmbeddingMode.TEXT_SEARCH_MODE

        model (str): Model for embedding.
            Defaults to OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002.
            Options are:

            - OpenAIEmbeddingModelType.DAVINCI
            - OpenAIEmbeddingModelType.CURIE
            - OpenAIEmbeddingModelType.BABBAGE
            - OpenAIEmbeddingModelType.ADA
            - OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002
    """

    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the OpenAI API."
    )

    api_key: str = Field(description="The OpenAI API key.")
    api_base: str = Field(description="The base URL for OpenAI API.")
    api_version: str = Field(description="The version for OpenAI API.")

    # FIX: pydantic's numeric lower-bound keyword is ``ge``, not ``gte``.
    # The original ``gte=0`` was an unknown kwarg, so the bound was never
    # actually enforced; ``ge=0`` enforces it as intended.
    max_retries: int = Field(
        default=10, description="Maximum number of retries.", ge=0
    )
    timeout: float = Field(default=60.0, description="Timeout for each request.", ge=0)
    default_headers: Optional[Dict[str, str]] = Field(
        default=None, description="The default headers for API requests."
    )
    reuse_client: bool = Field(
        default=True,
        description=(
            "Reuse the OpenAI client between requests. When doing anything with large "
            "volumes of async API calls, setting this to false can improve stability."
        ),
    )
    dimensions: Optional[int] = Field(
        default=None,
        description=(
            "The number of dimensions on the output embedding vectors. "
            "Works only with v3 embedding models."
        ),
    )

    # Engines resolved from (mode, model); queries and documents may use
    # different engines for first-generation text-search models.
    _query_engine: OpenAIEmbeddingModeModel = PrivateAttr()
    _text_engine: OpenAIEmbeddingModeModel = PrivateAttr()
    # Lazily-created clients (None until first use; always None when
    # reuse_client is False).
    _client: Optional[OpenAI] = PrivateAttr()
    _aclient: Optional[AsyncOpenAI] = PrivateAttr()
    _http_client: Optional[httpx.Client] = PrivateAttr()

    def __init__(
        self,
        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
        embed_batch_size: int = 100,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        reuse_client: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        default_headers: Optional[Dict[str, str]] = None,
        http_client: Optional[httpx.Client] = None,
        **kwargs: Any,
    ) -> None:
        additional_kwargs = additional_kwargs or {}
        if dimensions is not None:
            # ``dimensions`` is forwarded to the API via additional_kwargs
            # (only meaningful for v3 embedding models).
            additional_kwargs["dimensions"] = dimensions

        # Fill in any missing credentials from the environment / defaults.
        api_key, api_base, api_version = resolve_openai_credentials(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
        )

        # NOTE: private engine attrs are assigned before super().__init__();
        # order preserved from the original implementation.
        self._query_engine = get_engine(mode, model, _QUERY_MODE_MODEL_DICT)
        self._text_engine = get_engine(mode, model, _TEXT_MODE_MODEL_DICT)

        if "model_name" in kwargs:
            # Explicit model_name overrides both engines, bypassing the enum
            # mapping — presumably for custom deployment names; confirm with
            # callers before relying on this.
            model_name = kwargs.pop("model_name")
            self._query_engine = self._text_engine = model_name
        else:
            model_name = model

        super().__init__(
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            callback_manager=callback_manager,
            model_name=model_name,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            reuse_client=reuse_client,
            timeout=timeout,
            default_headers=default_headers,
            **kwargs,
        )

        self._client = None
        self._aclient = None
        self._http_client = http_client

    def _get_client(self) -> OpenAI:
        """Return a sync client; a fresh one per call unless reuse_client."""
        if not self.reuse_client:
            return OpenAI(**self._get_credential_kwargs())

        if self._client is None:
            self._client = OpenAI(**self._get_credential_kwargs())
        return self._client

    def _get_aclient(self) -> AsyncOpenAI:
        """Return an async client; a fresh one per call unless reuse_client."""
        if not self.reuse_client:
            return AsyncOpenAI(**self._get_credential_kwargs())

        if self._aclient is None:
            self._aclient = AsyncOpenAI(**self._get_credential_kwargs())
        return self._aclient

    @classmethod
    def class_name(cls) -> str:
        return "OpenAIEmbedding"

    def _get_credential_kwargs(self) -> Dict[str, Any]:
        """Build the kwargs used to construct OpenAI/AsyncOpenAI clients."""
        return {
            "api_key": self.api_key,
            "base_url": self.api_base,
            "max_retries": self.max_retries,
            "timeout": self.timeout,
            "default_headers": self.default_headers,
            "http_client": self._http_client,
        }

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        client = self._get_client()
        return get_embedding(
            client,
            query,
            engine=self._query_engine,
            **self.additional_kwargs,
        )

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        aclient = self._get_aclient()
        return await aget_embedding(
            aclient,
            query,
            engine=self._query_engine,
            **self.additional_kwargs,
        )

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        client = self._get_client()
        return get_embedding(
            client,
            text,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        aclient = self._get_aclient()
        return await aget_embedding(
            aclient,
            text,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings in one batched API call.

        By default, this is a wrapper around _get_text_embedding.
        Can be overridden for batch queries.

        """
        client = self._get_client()
        return get_embeddings(
            client,
            texts,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        aclient = self._get_aclient()
        return await aget_embeddings(
            aclient,
            texts,
            engine=self._text_engine,
            **self.additional_kwargs,
        )
