# llama-index (fork) — 301 lines · 9.2 KB
# NOTE: the lines above/below were page chrome from the hosting site, not source code.
1
from typing import Any, Dict, List, Optional
2

3
import httpx
4
from openai import AsyncOpenAI, OpenAI
5

6
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
7
from llama_index.legacy.callbacks import CallbackManager
8
from llama_index.legacy.callbacks.base import CallbackManager
9
from llama_index.legacy.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
10
from llama_index.legacy.llms.anyscale_utils import (
11
    resolve_anyscale_credentials,
12
)
13
from llama_index.legacy.llms.openai_utils import create_retry_decorator
14

15
# Anyscale's OpenAI-compatible "Endpoints" API.
DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1"
DEFAULT_MODEL = "thenlper/gte-large"

# Retry policy shared by every embedding call in this module: at most 6
# attempts with randomized exponential backoff between 1s and 20s, giving
# up entirely once 60 seconds have elapsed.
embedding_retry_decorator = create_retry_decorator(
    max_retries=6,
    random_exponential=True,
    stop_after_delay_seconds=60,
    min_seconds=1,
    max_seconds=20,
)

@embedding_retry_decorator
def get_embedding(client: OpenAI, text: str, engine: str, **kwargs: Any) -> List[float]:
    """
    Embed a single piece of text and return its embedding vector.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    # Flatten newlines before sending (mirrors OpenAI's embedding utils).
    cleaned = text.replace("\n", " ")
    response = client.embeddings.create(input=[cleaned], model=engine, **kwargs)
    return response.data[0].embedding

@embedding_retry_decorator
async def aget_embedding(
    aclient: AsyncOpenAI, text: str, engine: str, **kwargs: Any
) -> List[float]:
    """
    Asynchronous counterpart of :func:`get_embedding` for a single text.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    # Flatten newlines before sending (mirrors OpenAI's embedding utils).
    cleaned = text.replace("\n", " ")
    response = await aclient.embeddings.create(input=[cleaned], model=engine, **kwargs)
    return response.data[0].embedding

@embedding_retry_decorator
def get_embeddings(
    client: OpenAI, list_of_text: List[str], engine: str, **kwargs: Any
) -> List[List[float]]:
    """
    Embed a batch of texts in one API call, returning one vector per input.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    # API-side limit on batch size.
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # Flatten newlines before sending (mirrors OpenAI's embedding utils).
    cleaned = [text.replace("\n", " ") for text in list_of_text]

    response = client.embeddings.create(input=cleaned, model=engine, **kwargs)
    return [item.embedding for item in response.data]

@embedding_retry_decorator
async def aget_embeddings(
    aclient: AsyncOpenAI,
    list_of_text: List[str],
    engine: str,
    **kwargs: Any,
) -> List[List[float]]:
    """
    Asynchronous counterpart of :func:`get_embeddings` for a batch of texts.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    # API-side limit on batch size.
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # Flatten newlines before sending (mirrors OpenAI's embedding utils).
    cleaned = [text.replace("\n", " ") for text in list_of_text]

    response = await aclient.embeddings.create(input=cleaned, model=engine, **kwargs)
    return [item.embedding for item in response.data]

class AnyscaleEmbedding(BaseEmbedding):
    """
    Anyscale class for embeddings.

    Talks to Anyscale's OpenAI-compatible Endpoints API, so it reuses the
    regular ``openai`` client objects with a different base URL.

    Args:
        model (str): Model for embedding.
            Defaults to "thenlper/gte-large"
    """

    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the OpenAI API."
    )

    api_key: str = Field(description="The Anyscale API key.")
    api_base: str = Field(description="The base URL for Anyscale API.")
    api_version: str = Field(description="The version for OpenAI API.")

    # NOTE: pydantic's keyword for a ">= 0" constraint is `ge`; the previous
    # `gte=0` was an unknown extra kwarg that pydantic silently ignored, so
    # these lower bounds were never actually enforced.
    max_retries: int = Field(
        default=10, description="Maximum number of retries.", ge=0
    )
    timeout: float = Field(default=60.0, description="Timeout for each request.", ge=0)
    default_headers: Optional[Dict[str, str]] = Field(
        default=None, description="The default headers for API requests."
    )
    reuse_client: bool = Field(
        default=True,
        description=(
            "Reuse the Anyscale client between requests. When doing anything with large "
            "volumes of async API calls, setting this to false can improve stability."
        ),
    )

    # Engine (model) name used for query vs. document embeddings; both are
    # set to the same model in __init__.
    _query_engine: Optional[str] = PrivateAttr()
    _text_engine: Optional[str] = PrivateAttr()
    # Cached sync/async clients, created lazily (see _get_client/_get_aclient).
    _client: Optional[OpenAI] = PrivateAttr()
    _aclient: Optional[AsyncOpenAI] = PrivateAttr()
    _http_client: Optional[httpx.Client] = PrivateAttr()

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = DEFAULT_API_BASE,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        reuse_client: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        default_headers: Optional[Dict[str, str]] = None,
        http_client: Optional[httpx.Client] = None,
        **kwargs: Any,
    ) -> None:
        """Resolve credentials, pick the engine name, and initialize fields."""
        additional_kwargs = additional_kwargs or {}

        # Fill in any credentials not passed explicitly (e.g. from env vars).
        api_key, api_base, api_version = resolve_anyscale_credentials(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
        )

        # An explicit `model_name` kwarg (BaseEmbedding's field) takes
        # precedence over the `model` parameter.
        if "model_name" in kwargs:
            model_name = kwargs.pop("model_name")
        else:
            model_name = model

        # The same engine is used for both query and document embeddings.
        self._query_engine = model_name
        self._text_engine = model_name

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=model_name,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            reuse_client=reuse_client,
            timeout=timeout,
            default_headers=default_headers,
            **kwargs,
        )

        # Clients are created lazily on first use.
        self._client = None
        self._aclient = None
        self._http_client = http_client

    def _get_client(self) -> OpenAI:
        """Return a sync client; cached on the instance when reuse_client is set."""
        if not self.reuse_client:
            return OpenAI(**self._get_credential_kwargs())

        if self._client is None:
            self._client = OpenAI(**self._get_credential_kwargs())
        return self._client

    def _get_aclient(self) -> AsyncOpenAI:
        """Return an async client; cached on the instance when reuse_client is set."""
        if not self.reuse_client:
            return AsyncOpenAI(**self._get_credential_kwargs())

        if self._aclient is None:
            self._aclient = AsyncOpenAI(**self._get_credential_kwargs())
        return self._aclient

    @classmethod
    def class_name(cls) -> str:
        return "AnyscaleEmbedding"

    def _get_credential_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments used to construct OpenAI/AsyncOpenAI clients."""
        return {
            "api_key": self.api_key,
            "base_url": self.api_base,
            "max_retries": self.max_retries,
            "timeout": self.timeout,
            "default_headers": self.default_headers,
            "http_client": self._http_client,
        }

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        client = self._get_client()
        return get_embedding(
            client,
            query,
            engine=self._query_engine,
            **self.additional_kwargs,
        )

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        aclient = self._get_aclient()
        return await aget_embedding(
            aclient,
            query,
            engine=self._query_engine,
            **self.additional_kwargs,
        )

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        client = self._get_client()
        return get_embedding(
            client,
            text,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        aclient = self._get_aclient()
        return await aget_embedding(
            aclient,
            text,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Get text embeddings.

        Embeds the whole batch in a single API call via `get_embeddings`
        (it does not fan out to `_get_text_embedding` per item).
        """
        client = self._get_client()
        return get_embeddings(
            client,
            texts,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        aclient = self._get_aclient()
        return await aget_embeddings(
            aclient,
            texts,
            engine=self._text_engine,
            **self.additional_kwargs,
        )
# (Removed non-source text: website cookie-consent boilerplate captured during page extraction.)