# llama-index (scraped page metadata: 301 lines · 9.2 KB)
1from typing import Any, Dict, List, Optional
2
3import httpx
4from openai import AsyncOpenAI, OpenAI
5
6from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
7from llama_index.legacy.callbacks import CallbackManager
8from llama_index.legacy.callbacks.base import CallbackManager
9from llama_index.legacy.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
10from llama_index.legacy.llms.anyscale_utils import (
11resolve_anyscale_credentials,
12)
13from llama_index.legacy.llms.openai_utils import create_retry_decorator
14
# Default Anyscale Endpoints base URL and default embedding model.
DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1"
DEFAULT_MODEL = "thenlper/gte-large"

# Shared retry policy for all embedding calls in this module: up to 6
# attempts with randomized exponential backoff (1-20s per wait), giving up
# entirely after 60 seconds.
embedding_retry_decorator = create_retry_decorator(
    max_retries=6,
    random_exponential=True,
    stop_after_delay_seconds=60,
    min_seconds=1,
    max_seconds=20,
)
25
26
@embedding_retry_decorator
def get_embedding(client: OpenAI, text: str, engine: str, **kwargs: Any) -> List[float]:
    """Return the embedding vector for a single string.

    Newlines are flattened to spaces before the request is sent.

    NOTE: adapted from OpenAI's embedding utils
    (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
    to avoid pulling in heavy optional dependencies such as matplotlib,
    plotly, scipy, and sklearn.
    """
    sanitized = text.replace("\n", " ")
    response = client.embeddings.create(input=[sanitized], model=engine, **kwargs)
    return response.data[0].embedding
44
45
@embedding_retry_decorator
async def aget_embedding(
    aclient: AsyncOpenAI, text: str, engine: str, **kwargs: Any
) -> List[float]:
    """Asynchronously return the embedding vector for a single string.

    Newlines are flattened to spaces before the request is sent.

    NOTE: adapted from OpenAI's embedding utils
    (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
    to avoid pulling in heavy optional dependencies such as matplotlib,
    plotly, scipy, and sklearn.
    """
    sanitized = text.replace("\n", " ")
    response = await aclient.embeddings.create(
        input=[sanitized], model=engine, **kwargs
    )
    return response.data[0].embedding
67
68
@embedding_retry_decorator
def get_embeddings(
    client: OpenAI, list_of_text: List[str], engine: str, **kwargs: Any
) -> List[List[float]]:
    """Return embedding vectors for a batch of strings (max 2048).

    Newlines are flattened to spaces in every input before the request.

    NOTE: adapted from OpenAI's embedding utils
    (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
    to avoid pulling in heavy optional dependencies such as matplotlib,
    plotly, scipy, and sklearn.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    sanitized = [entry.replace("\n", " ") for entry in list_of_text]
    response = client.embeddings.create(input=sanitized, model=engine, **kwargs)
    return [item.embedding for item in response.data]
89
90
@embedding_retry_decorator
async def aget_embeddings(
    aclient: AsyncOpenAI,
    list_of_text: List[str],
    engine: str,
    **kwargs: Any,
) -> List[List[float]]:
    """Asynchronously return embedding vectors for a batch of strings (max 2048).

    Newlines are flattened to spaces in every input before the request.

    NOTE: adapted from OpenAI's embedding utils
    (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
    to avoid pulling in heavy optional dependencies such as matplotlib,
    plotly, scipy, and sklearn.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    sanitized = [entry.replace("\n", " ") for entry in list_of_text]
    response = await aclient.embeddings.create(
        input=sanitized, model=engine, **kwargs
    )
    return [item.embedding for item in response.data]
116
117
class AnyscaleEmbedding(BaseEmbedding):
    """
    Anyscale class for embeddings.

    Wraps the OpenAI-compatible Anyscale Endpoints API. Sync/async clients
    are constructed lazily from the stored credentials and are reused
    between requests unless ``reuse_client`` is disabled.

    Args:
        model (str): Model for embedding.
            Defaults to "thenlper/gte-large"
    """

    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the OpenAI API."
    )

    api_key: str = Field(description="The Anyscale API key.")
    api_base: str = Field(description="The base URL for Anyscale API.")
    api_version: str = Field(description="The version for OpenAI API.")

    # FIX: the original used `gte=0`, which is not a pydantic constraint
    # keyword and was silently ignored; `ge=0` actually enforces the
    # intended non-negative bound.
    max_retries: int = Field(
        default=10, description="Maximum number of retries.", ge=0
    )
    timeout: float = Field(default=60.0, description="Timeout for each request.", ge=0)
    default_headers: Optional[Dict[str, str]] = Field(
        default=None, description="The default headers for API requests."
    )
    reuse_client: bool = Field(
        default=True,
        description=(
            "Reuse the Anyscale client between requests. When doing anything with large "
            "volumes of async API calls, setting this to false can improve stability."
        ),
    )

    # Private state: engine (model) names used for query/text embeddings,
    # the lazily built clients, and an optional user-supplied HTTP client.
    _query_engine: Optional[str] = PrivateAttr()
    _text_engine: Optional[str] = PrivateAttr()
    _client: Optional[OpenAI] = PrivateAttr()
    _aclient: Optional[AsyncOpenAI] = PrivateAttr()
    _http_client: Optional[httpx.Client] = PrivateAttr()

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = DEFAULT_API_BASE,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        reuse_client: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        default_headers: Optional[Dict[str, str]] = None,
        http_client: Optional[httpx.Client] = None,
        **kwargs: Any,
    ) -> None:
        additional_kwargs = additional_kwargs or {}

        # Resolve credentials not passed explicitly (e.g. from environment;
        # see resolve_anyscale_credentials).
        api_key, api_base, api_version = resolve_anyscale_credentials(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
        )

        # An explicit `model_name` kwarg (the BaseEmbedding field name)
        # takes precedence over the `model` alias.
        if "model_name" in kwargs:
            model_name = kwargs.pop("model_name")
        else:
            model_name = model

        # Both query and text embeddings use the same engine.
        self._query_engine = model_name
        self._text_engine = model_name

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=model_name,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            reuse_client=reuse_client,
            timeout=timeout,
            default_headers=default_headers,
            **kwargs,
        )

        # Clients are created on first use (see _get_client/_get_aclient).
        self._client = None
        self._aclient = None
        self._http_client = http_client

    def _get_client(self) -> OpenAI:
        """Return a sync OpenAI client; cached when ``reuse_client`` is True."""
        if not self.reuse_client:
            return OpenAI(**self._get_credential_kwargs())

        if self._client is None:
            self._client = OpenAI(**self._get_credential_kwargs())
        return self._client

    def _get_aclient(self) -> AsyncOpenAI:
        """Return an async OpenAI client; cached when ``reuse_client`` is True."""
        if not self.reuse_client:
            return AsyncOpenAI(**self._get_credential_kwargs())

        if self._aclient is None:
            self._aclient = AsyncOpenAI(**self._get_credential_kwargs())
        return self._aclient

    @classmethod
    def class_name(cls) -> str:
        """Class name used for (de)serialization."""
        return "AnyscaleEmbedding"

    def _get_credential_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments used to construct the (Async)OpenAI clients."""
        return {
            "api_key": self.api_key,
            "base_url": self.api_base,
            "max_retries": self.max_retries,
            "timeout": self.timeout,
            "default_headers": self.default_headers,
            "http_client": self._http_client,
        }

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        client = self._get_client()
        return get_embedding(
            client,
            query,
            engine=self._query_engine,
            **self.additional_kwargs,
        )

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        aclient = self._get_aclient()
        return await aget_embedding(
            aclient,
            query,
            engine=self._query_engine,
            **self.additional_kwargs,
        )

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        client = self._get_client()
        return get_embedding(
            client,
            text,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        aclient = self._get_aclient()
        return await aget_embedding(
            aclient,
            text,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overridden for batch queries.
        """
        client = self._get_client()
        return get_embeddings(
            client,
            texts,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        aclient = self._get_aclient()
        return await aget_embeddings(
            aclient,
            texts,
            engine=self._text_engine,
            **self.additional_kwargs,
        )
302