llama-index

Форк
0
179 строк · 6.1 Кб
1
from typing import Any, List
2

3
from llama_index.legacy.bridge.pydantic import PrivateAttr
4
from llama_index.legacy.embeddings.base import BaseEmbedding
5

6

7
class ElasticsearchEmbedding(BaseEmbedding):
    """Elasticsearch embedding models.

    This class provides an interface to generate embeddings using a model deployed
    in an Elasticsearch cluster. It requires a client object exposing
    ``infer_trained_model`` (the factory classmethods build an ``MlClient``) and
    the model_id of the model deployed in the cluster.

    In Elasticsearch you need to have an embedding model loaded and deployed.
    - https://www.elastic.co
        /guide/en/elasticsearch/reference/current/infer-trained-model.html
    - https://www.elastic.co
        /guide/en/machine-learning/current/ml-nlp-deploy-models.html

    Attributes:
        model_id: ID of the trained model deployed in the Elasticsearch cluster.
        input_field: Key under which the input text is placed in each document
            sent to the inference API.
    """

    # Declared as a pydantic private attribute so the client object is not a
    # validated model field; set in ``__init__``.
    _client: Any = PrivateAttr()
    model_id: str
    input_field: str

    @classmethod
    def class_name(cls) -> str:
        """Return the registered name of this embedding class."""
        return "ElasticsearchEmbedding"

    def __init__(
        self,
        client: Any,
        model_id: str,
        input_field: str = "text_field",
        **kwargs: Any,
    ):
        """Create the embedding wrapper around an existing client.

        Args:
            client: Object whose ``infer_trained_model(model_id=..., docs=...)``
                method is used to compute embeddings (e.g. an ``MlClient``).
            model_id: ID of the model deployed in the Elasticsearch cluster.
            input_field: Key for the input text in each inference document.
            **kwargs: Forwarded to ``BaseEmbedding``.
        """
        super().__init__(model_id=model_id, input_field=input_field, **kwargs)
        # Private state is assigned only after the pydantic model has been
        # initialised, so the private-attribute storage already exists.
        self._client = client

    @classmethod
    def from_es_connection(
        cls,
        model_id: str,
        es_connection: Any,
        input_field: str = "text_field",
    ) -> "ElasticsearchEmbedding":
        """
        Instantiate embeddings from an existing Elasticsearch connection.

        This method provides a way to create an instance of the ElasticsearchEmbedding
        class using an existing Elasticsearch connection. The connection object is used
        to create an MlClient, which is then used to initialize the
        ElasticsearchEmbedding instance.

        Args:
            model_id (str): The model_id of the model deployed in the Elasticsearch
                cluster.
            es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch
                connection object.
            input_field (str, optional): The name of the key for the input text field
                in the document. Defaults to 'text_field'.

        Returns:
            ElasticsearchEmbedding: An instance of the ElasticsearchEmbedding class.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.

        Example:
            .. code-block:: python

                from elasticsearch import Elasticsearch

                from llama_index.legacy.embeddings import ElasticsearchEmbedding

                # Define the model ID and input field name (if different from default)
                model_id = "your_model_id"
                # Optional, only if different from 'text_field'
                input_field = "your_input_field"

                # Create Elasticsearch connection
                es_connection = Elasticsearch(hosts=["localhost:9200"], basic_auth=("user", "password"))

                # Instantiate ElasticsearchEmbedding using the existing connection
                embeddings = ElasticsearchEmbedding.from_es_connection(
                    model_id,
                    es_connection,
                    input_field=input_field,
                )
        """
        try:
            from elasticsearch.client import MlClient
        except ImportError as err:
            raise ImportError(
                "elasticsearch package not found, install with "
                "'pip install elasticsearch'"
            ) from err

        client = MlClient(es_connection)
        return cls(client, model_id, input_field=input_field)

    @classmethod
    def from_credentials(
        cls,
        model_id: str,
        es_url: str,
        es_username: str,
        es_password: str,
        input_field: str = "text_field",
    ) -> "ElasticsearchEmbedding":
        """Instantiate embeddings from Elasticsearch credentials.

        A new ``Elasticsearch`` connection is created with basic authentication
        and wrapped in an ``MlClient``.

        Args:
            model_id (str): The model_id of the model deployed in the Elasticsearch
                cluster.
            es_url (str): The Elasticsearch url to connect to.
            es_username (str): Elasticsearch username.
            es_password (str): Elasticsearch password.
            input_field (str): The name of the key for the input text field in the
                document. Defaults to 'text_field'.

        Returns:
            ElasticsearchEmbedding: An instance of the ElasticsearchEmbedding class.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.

        Example:
            .. code-block:: python

                from llama_index.legacy.embeddings import ElasticsearchEmbedding

                # Define the model ID and input field name (if different from default)
                model_id = "your_model_id"
                # Optional, only if different from 'text_field'
                input_field = "your_input_field"

                embeddings = ElasticsearchEmbedding.from_credentials(
                    model_id,
                    input_field=input_field,
                    es_url="foo",
                    es_username="bar",
                    es_password="baz",
                )
        """
        try:
            from elasticsearch import Elasticsearch
            from elasticsearch.client import MlClient
        except ImportError as err:
            raise ImportError(
                "elasticsearch package not found, install with "
                "'pip install elasticsearch'"
            ) from err

        es_connection = Elasticsearch(
            hosts=[es_url],
            basic_auth=(es_username, es_password),
        )

        client = MlClient(es_connection)
        return cls(client, model_id, input_field=input_field)

    def _get_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for a single text via the trained-model inference API.

        Args:
            text (str): The text to generate an embedding for.

        Returns:
            List[float]: The ``predicted_value`` of the first inference result.
        """
        response = self._client.infer_trained_model(
            model_id=self.model_id,
            docs=[{self.input_field: text}],
        )

        # One document was sent, so only the first inference result is relevant.
        return response["inference_results"][0]["predicted_value"]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a document text; identical to query embedding for this model."""
        return self._get_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a query text; identical to text embedding for this model."""
        return self._get_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async entry point; delegates to the synchronous (blocking) call."""
        return self._get_query_embedding(query)
177

178

179
# Alias: the plural name refers to the same class as ElasticsearchEmbedding.
ElasticsearchEmbeddings = ElasticsearchEmbedding
180

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.