llama-index
137 lines · 5.0 KB
1import logging2from typing import Any, List, Optional3
4from llama_index.legacy.bridge.pydantic import Field, PrivateAttr5from llama_index.legacy.core.embeddings.base import (6DEFAULT_EMBED_BATCH_SIZE,7BaseEmbedding,8Embedding,9)
10
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)


# For bge models that Gradient AI provides, it is suggested to add the
# instruction for retrieval.
# Reference: https://huggingface.co/BAAI/bge-large-en-v1.5#model-list
QUERY_INSTRUCTION_FOR_RETRIEVAL = "Represent this sentence for searching relevant passages:"

# Default batch size for GradientEmbedding's `embed_batch_size` field.
GRADIENT_EMBED_BATCH_SIZE: int = 32768
22
class GradientEmbedding(BaseEmbedding):
    """GradientAI embedding models.

    This class provides an interface to generate embeddings using a model
    deployed in Gradient AI. At the initialization it requires a model_id
    of the model deployed in the cluster.

    Note:
        Requires `gradientai` package to be available in the PYTHONPATH. It can
        be installed with `pip install gradientai`.
    """

    # Only the lower bound (> 0) is enforced locally, by both the Field
    # validator and the explicit check in __init__.
    embed_batch_size: int = Field(default=GRADIENT_EMBED_BATCH_SIZE, gt=0)

    # Handle to the gradientai.Gradient client created in __init__.
    _gradient: Any = PrivateAttr()
    # Embeddings model fetched from Gradient AI by slug.
    _model: Any = PrivateAttr()

    @classmethod
    def class_name(cls) -> str:
        return "GradientEmbedding"

    def __init__(
        self,
        *,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        gradient_model_slug: str,
        gradient_access_token: Optional[str] = None,
        gradient_workspace_id: Optional[str] = None,
        gradient_host: Optional[str] = None,
        **kwargs: Any,
    ):
        """Initializes the GradientEmbedding class.

        During the initialization the `gradientai` package is imported. Using
        the access token, workspace id and the slug of the model, the model is
        fetched from Gradient AI and prepared to use.

        Args:
            embed_batch_size (int, optional): The batch size for embedding
                generation. Defaults to `DEFAULT_EMBED_BATCH_SIZE`; must be > 0.
            gradient_model_slug (str): The model slug of the model in the
                Gradient AI account.
            gradient_access_token (str, optional): The access token of the
                Gradient AI account, if `None` read from the environment
                variable `GRADIENT_ACCESS_TOKEN`.
            gradient_workspace_id (str, optional): The workspace ID of the
                Gradient AI account, if `None` read from the environment
                variable `GRADIENT_WORKSPACE_ID`.
            gradient_host (str, optional): The host of the Gradient AI API.
                Defaults to None, which means the default host is used.

        Raises:
            ImportError: If the `gradientai` package is not available in the
                PYTHONPATH.
            ValueError: If the model cannot be fetched from Gradient AI.
        """
        if embed_batch_size <= 0:
            raise ValueError(f"Embed batch size {embed_batch_size} must be > 0.")

        try:
            import gradientai
        except ImportError:
            raise ImportError("GradientEmbedding requires `pip install gradientai`.")

        self._gradient = gradientai.Gradient(
            access_token=gradient_access_token,
            workspace_id=gradient_workspace_id,
            host=gradient_host,
        )

        try:
            self._model = self._gradient.get_embeddings_model(slug=gradient_model_slug)
        except gradientai.openapi.client.exceptions.UnauthorizedException as e:
            logger.error(f"Error while loading model {gradient_model_slug}.")
            # Release the client before failing so no connection leaks.
            self._gradient.close()
            raise ValueError("Unable to fetch the requested embeddings model") from e

        super().__init__(
            embed_batch_size=embed_batch_size, model_name=gradient_model_slug, **kwargs
        )

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """
        Embed the input sequence of text asynchronously.
        """
        inputs = [{"input": text} for text in texts]

        # Fix: await the coroutine first, THEN read `.embeddings`. The
        # original `await self._model.aembed(...).embeddings` accessed the
        # attribute on the un-awaited coroutine object and could never work.
        result = (await self._model.aembed(inputs=inputs)).embeddings

        return [e.embedding for e in result]

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """
        Embed the input sequence of text.
        """
        inputs = [{"input": text} for text in texts]

        result = self._model.embed(inputs=inputs).embeddings

        return [e.embedding for e in result]

    def _get_text_embedding(self, text: str) -> Embedding:
        """Alias for _get_text_embeddings() with single text input."""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """Alias for _aget_text_embeddings() with single text input."""
        embedding = await self._aget_text_embeddings([text])
        return embedding[0]

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Embed a query, prepending the bge retrieval instruction."""
        embedding = await self._aget_text_embeddings(
            [f"{QUERY_INSTRUCTION_FOR_RETRIEVAL} {query}"]
        )
        return embedding[0]

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed a query, prepending the bge retrieval instruction."""
        return self._get_text_embeddings(
            [f"{QUERY_INSTRUCTION_FOR_RETRIEVAL} {query}"]
        )[0]