llama-index

Форк
0
137 строк · 5.0 Кб
1
import logging
2
from typing import Any, List, Optional
3

4
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
5
from llama_index.legacy.core.embeddings.base import (
6
    DEFAULT_EMBED_BATCH_SIZE,
7
    BaseEmbedding,
8
    Embedding,
9
)
10

11
# Logger for this module; messages are emitted under the module's import path.
logger = logging.getLogger(__name__)


# Gradient AI serves BGE embedding models, whose model card suggests prefixing
# retrieval queries with the instruction below. In this module it is prepended
# only to query texts, never to document texts.
# Reference: https://huggingface.co/BAAI/bge-large-en-v1.5#model-list
QUERY_INSTRUCTION_FOR_RETRIEVAL: str = (
    "Represent this sentence for searching relevant passages:"
)

# Field-level default for ``GradientEmbedding.embed_batch_size``. The class
# constructor always passes its own ``embed_batch_size`` argument (defaulting to
# ``DEFAULT_EMBED_BATCH_SIZE``), which takes precedence over this value.
GRADIENT_EMBED_BATCH_SIZE: int = 32768
21

22

23
class GradientEmbedding(BaseEmbedding):
    """GradientAI embedding models.

    This class provides an interface to generate embeddings using a model
    deployed in Gradient AI. At initialization it requires the slug of a model
    available in the Gradient AI account.

    Note:
        Requires the `gradientai` package to be available in the PYTHONPATH. It
        can be installed with `pip install gradientai`.
    """

    # NOTE(review): the constructor's own default (DEFAULT_EMBED_BATCH_SIZE)
    # differs from this field default; __init__ always passes its value through.
    embed_batch_size: int = Field(default=GRADIENT_EMBED_BATCH_SIZE, gt=0)

    # `gradientai.Gradient` client created in __init__.
    _gradient: Any = PrivateAttr()
    # Model handle returned by `Gradient.get_embeddings_model`; exposes
    # `embed` (sync) and `aembed` (async).
    _model: Any = PrivateAttr()

    @classmethod
    def class_name(cls) -> str:
        """Return the class name string ``"GradientEmbedding"``."""
        return "GradientEmbedding"

    def __init__(
        self,
        *,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        gradient_model_slug: str,
        gradient_access_token: Optional[str] = None,
        gradient_workspace_id: Optional[str] = None,
        gradient_host: Optional[str] = None,
        **kwargs: Any,
    ):
        """Initialize the GradientEmbedding class.

        During initialization the `gradientai` package is imported. Using the
        access token, workspace id and the slug of the model, the model is fetched
        from Gradient AI and prepared for use.

        Args:
            embed_batch_size (int, optional): The batch size for embedding
                generation. Defaults to ``DEFAULT_EMBED_BATCH_SIZE``; must be > 0.
            gradient_model_slug (str): The model slug of the model in the
                Gradient AI account.
            gradient_access_token (str, optional): The access token of the
                Gradient AI account; if `None`, read from the environment
                variable `GRADIENT_ACCESS_TOKEN`.
            gradient_workspace_id (str, optional): The workspace ID of the
                Gradient AI account; if `None`, read from the environment
                variable `GRADIENT_WORKSPACE_ID`.
            gradient_host (str, optional): The host of the Gradient AI API.
                Defaults to None, which means the default host is used.
            **kwargs: Forwarded to ``BaseEmbedding.__init__``.

        Raises:
            ValueError: If `embed_batch_size` is not positive, or if fetching the
                model fails with an unauthorized error.
            ImportError: If the `gradientai` package is not available in the
                PYTHONPATH.
            Exception: Any other error raised while fetching the model is
                re-raised unchanged after the Gradient client is closed.
        """
        if embed_batch_size <= 0:
            raise ValueError(f"Embed batch size {embed_batch_size}  must be > 0.")

        try:
            import gradientai
        except ImportError as e:
            raise ImportError(
                "GradientEmbedding requires `pip install gradientai`."
            ) from e

        self._gradient = gradientai.Gradient(
            access_token=gradient_access_token,
            workspace_id=gradient_workspace_id,
            host=gradient_host,
        )

        try:
            self._model = self._gradient.get_embeddings_model(slug=gradient_model_slug)
        except gradientai.openapi.client.exceptions.UnauthorizedException as e:
            logger.error("Error while loading model %s.", gradient_model_slug)
            self._gradient.close()
            raise ValueError("Unable to fetch the requested embeddings model") from e
        except Exception:
            # Any other failure would leave the freshly created client open;
            # release it and let the original exception propagate unchanged.
            self._gradient.close()
            raise

        super().__init__(
            embed_batch_size=embed_batch_size, model_name=gradient_model_slug, **kwargs
        )

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Embed a batch of texts asynchronously.

        Args:
            texts: The texts to embed; each is sent as one ``{"input": text}`` item.

        Returns:
            The embeddings from the model response, one per returned item (an
            empty list when `texts` is empty, without calling the API).
        """
        if not texts:
            return []

        inputs = [{"input": text} for text in texts]

        # `await` applies to the whole attribute chain that follows it, so the
        # call must be awaited *before* `.embeddings` is read. Otherwise the
        # attribute would be looked up on the un-awaited coroutine and fail.
        response = await self._model.aembed(inputs=inputs)

        return [e.embedding for e in response.embeddings]

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Embed a batch of texts synchronously.

        Args:
            texts: The texts to embed; each is sent as one ``{"input": text}`` item.

        Returns:
            The embeddings from the model response, one per returned item (an
            empty list when `texts` is empty, without calling the API).
        """
        if not texts:
            return []

        inputs = [{"input": text} for text in texts]

        response = self._model.embed(inputs=inputs)

        return [e.embedding for e in response.embeddings]

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed a single text by delegating to `_get_text_embeddings`."""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """Embed a single text asynchronously via `_aget_text_embeddings`."""
        embeddings = await self._aget_text_embeddings([text])
        return embeddings[0]

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Asynchronously embed `query`, prefixed with the retrieval instruction."""
        embeddings = await self._aget_text_embeddings(
            [f"{QUERY_INSTRUCTION_FOR_RETRIEVAL} {query}"]
        )
        return embeddings[0]

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed `query`, prefixed with the retrieval instruction."""
        return self._get_text_embeddings(
            [f"{QUERY_INSTRUCTION_FOR_RETRIEVAL} {query}"]
        )[0]
138

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.