llama-index

Форк
0
341 строка · 12.8 Кб
1
"""Milvus vector store index.

An index that is built and stored within Milvus.

"""
6

7
import logging
8
from typing import Any, Dict, List, Optional, Union
9

10
from llama_index.legacy.schema import BaseNode, TextNode
11
from llama_index.legacy.vector_stores.types import (
12
    MetadataFilters,
13
    VectorStore,
14
    VectorStoreQuery,
15
    VectorStoreQueryMode,
16
    VectorStoreQueryResult,
17
)
18
from llama_index.legacy.vector_stores.utils import (
19
    DEFAULT_DOC_ID_KEY,
20
    DEFAULT_EMBEDDING_KEY,
21
    metadata_dict_to_node,
22
    node_to_metadata_dict,
23
)
24

25
logger: logging.Logger = logging.getLogger(__name__)

# Name of the primary-key field: node ids are written under it on insert,
# and it is the primary field of collections this module creates.
MILVUS_ID_FIELD = "id"
28

29

30
def _to_milvus_filter(standard_filters: MetadataFilters) -> List[str]:
31
    """Translate standard metadata filters to Milvus specific spec."""
32
    filters = []
33
    for filter in standard_filters.legacy_filters():
34
        if isinstance(filter.value, str):
35
            filters.append(str(filter.key) + " == " + '"' + str(filter.value) + '"')
36
        else:
37
            filters.append(str(filter.key) + " == " + str(filter.value))
38
    return filters
39

40

41
class MilvusVectorStore(VectorStore):
    """The Milvus Vector Store.

    In this vector store we store the text, its embedding and
    its metadata in a Milvus collection. This implementation
    allows the use of an already existing collection.
    It also supports creating a new one if the collection doesn't
    exist or if `overwrite` is set to True.

    Args:
        uri (str, optional): The URI to connect to, comes in the form of
            "http://address:port".
        token (str, optional): The token for log in. Empty if not using rbac, if
            using rbac it will most likely be "username:password".
        collection_name (str, optional): The name of the collection where data will be
            stored. Defaults to "llamalection".
        dim (int, optional): The dimension of the embedding vectors for the collection.
            Required if creating a new collection.
        embedding_field (str, optional): The name of the embedding field for the
            collection, defaults to DEFAULT_EMBEDDING_KEY.
        doc_id_field (str, optional): The name of the doc_id field for the collection,
            defaults to DEFAULT_DOC_ID_KEY.
        similarity_metric (str, optional): The similarity metric to use. Matching
            is case-insensitive: "ip" maps to IP and "l2"/"euclidean" map to L2.
            Any other value is upper-cased and passed to Milvus as-is, which
            validates it. Defaults to "IP".
        consistency_level (str, optional): Which consistency level to use for a newly
            created collection. Defaults to "Strong".
        overwrite (bool, optional): Whether to overwrite existing collection with same
            name. When True, an existing vector index is also rebuilt after
            every `add`. Defaults to False.
        text_key (str, optional): Name of the entity field that holds the node
            text. Used when bringing your own collection. Defaults to None.
        index_config (dict, optional): The configuration used for building the
            Milvus index. An optional "index_type" key selects the index type
            (FLAT if absent); all remaining keys are passed as index params.
            Defaults to None.
        search_config (dict, optional): The configuration used for searching
            the Milvus index. Note that this must be compatible with the index
            type specified by `index_config`. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to `pymilvus.MilvusClient`
            (e.g. server_pem_path).

    Raises:
        ImportError: Unable to import `pymilvus`.
        ValueError: `dim` is missing when a new collection has to be created.
        MilvusException: Error communicating with Milvus, more can be found in logging
            under Debug.

    Returns:
        MilvusVectorStore: Vector store that supports add, delete, and query.
    """

    stores_text: bool = True
    stores_node: bool = True

    def __init__(
        self,
        uri: str = "http://localhost:19530",
        token: str = "",
        collection_name: str = "llamalection",
        dim: Optional[int] = None,
        embedding_field: str = DEFAULT_EMBEDDING_KEY,
        doc_id_field: str = DEFAULT_DOC_ID_KEY,
        similarity_metric: str = "IP",
        consistency_level: str = "Strong",
        overwrite: bool = False,
        text_key: Optional[str] = None,
        index_config: Optional[dict] = None,
        search_config: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        import_err_msg = (
            "`pymilvus` package not found, please run `pip install pymilvus`"
        )
        try:
            import pymilvus  # noqa
        except ImportError:
            raise ImportError(import_err_msg)

        from pymilvus import Collection, MilvusClient

        self.collection_name = collection_name
        self.dim = dim
        self.embedding_field = embedding_field
        self.doc_id_field = doc_id_field
        self.consistency_level = consistency_level
        self.overwrite = overwrite
        self.text_key = text_key
        self.index_config: Dict[str, Any] = index_config.copy() if index_config else {}
        # Note: The search configuration is set at construction to avoid having
        # to change the API for usage of the vector store (i.e. to pass the
        # search config along with the rest of the query).
        self.search_config: Dict[str, Any] = (
            search_config.copy() if search_config else {}
        )

        # Select the similarity metric. Known aliases are normalized; anything
        # else is upper-cased and left for Milvus to validate, so the attribute
        # is always set (previously an unknown metric left it undefined).
        metric_aliases = {"ip": "IP", "l2": "L2", "euclidean": "L2"}
        self.similarity_metric = metric_aliases.get(
            similarity_metric.lower(), similarity_metric.upper()
        )

        # Connect to Milvus instance
        self.milvusclient = MilvusClient(
            uri=uri,
            token=token,
            **kwargs,  # pass additional arguments such as server_pem_path
        )

        # Delete previous collection if overwriting
        collection_exists = self.collection_name in self.client.list_collections()
        if self.overwrite and collection_exists:
            self.milvusclient.drop_collection(self.collection_name)
            collection_exists = False

        # Create the collection if it does not exist
        created_collection = False
        if not collection_exists:
            if self.dim is None:
                raise ValueError("Dim argument required for collection creation.")
            self.milvusclient.create_collection(
                collection_name=self.collection_name,
                dimension=self.dim,
                primary_field_name=MILVUS_ID_FIELD,
                vector_field_name=self.embedding_field,
                id_type="string",
                metric_type=self.similarity_metric,
                max_length=65_535,
                consistency_level=self.consistency_level,
            )
            created_collection = True

        self.collection = Collection(
            self.collection_name, using=self.milvusclient._using
        )
        # For a collection created just now, force a build so a user-supplied
        # `index_config` actually takes effect. A pre-existing collection's
        # index is only touched when `overwrite` is set.
        self._create_index_if_required(
            force=created_collection and bool(self.index_config)
        )

        if created_collection:
            logger.debug(
                f"Successfully created a new collection: {self.collection_name}"
            )
        else:
            logger.debug(f"Using existing collection: {self.collection_name}")

    @property
    def client(self) -> Any:
        """Get client."""
        return self.milvusclient

    @staticmethod
    def _quote_milvus_string(value: Any) -> str:
        """Render ``value`` as a double-quoted Milvus string literal.

        Backslashes and double quotes are escaped so the value cannot close
        the literal early and alter the surrounding filter expression.
        """
        # Escape backslashes first so the ones added for quotes are not doubled.
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add the embeddings and their nodes into Milvus.

        Each node is stored as its metadata dict (from
        ``node_to_metadata_dict``) plus its id under ``MILVUS_ID_FIELD`` and
        its embedding under ``self.embedding_field``. The collection is
        flushed after the insert.

        Args:
            nodes (List[BaseNode]): List of nodes with embeddings
                to insert.

        Raises:
            MilvusException: Failed to insert data.

        Returns:
            List[str]: List of ids inserted (empty, with no Milvus calls, when
            ``nodes`` is empty).
        """
        if not nodes:
            return []

        insert_list: List[Dict[str, Any]] = []
        insert_ids: List[str] = []

        # Process that data we are going to insert
        for node in nodes:
            entry = node_to_metadata_dict(node)
            entry[MILVUS_ID_FIELD] = node.node_id
            entry[self.embedding_field] = node.embedding

            insert_ids.append(node.node_id)
            insert_list.append(entry)

        # Insert the data into milvus
        self.collection.insert(insert_list)
        self.collection.flush()
        self._create_index_if_required()
        logger.debug(
            f"Successfully inserted embeddings into: {self.collection_name} "
            f"Num Inserted: {len(insert_list)}"
        )
        return insert_ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        Matching primary keys are looked up first with a filter on
        ``self.doc_id_field`` and then deleted by primary key.

        Args:
            ref_doc_id (str): The doc_id of the document to delete. A list of
                doc_ids is also accepted.

        Raises:
            MilvusException: Failed to delete the doc.
        """
        # Adds ability for multiple doc delete in future.
        doc_ids: List[str]
        if isinstance(ref_doc_id, list):
            doc_ids = ref_doc_id  # type: ignore
        else:
            doc_ids = [ref_doc_id]

        if not doc_ids:
            return

        # Begin by querying for the primary keys to delete
        quoted_ids = [self._quote_milvus_string(doc_id) for doc_id in doc_ids]
        entries = self.milvusclient.query(
            collection_name=self.collection_name,
            filter=f"{self.doc_id_field} in [{','.join(quoted_ids)}]",
        )
        ids = [entry[MILVUS_ID_FIELD] for entry in entries]
        # Skip the delete round trip when nothing matched.
        if ids:
            self.milvusclient.delete(collection_name=self.collection_name, pks=ids)
        logger.debug(f"Successfully deleted embedding with doc_id: {doc_ids}")

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): The query. The fields used are
                ``query_embedding``, ``similarity_top_k``, ``filters``,
                ``doc_ids``, ``node_ids`` and ``output_fields`` (all entity
                fields are returned when ``output_fields`` is None). All
                filter clauses are combined with ``and``.

        Raises:
            ValueError: The query mode is not DEFAULT, no query embedding was
                given, or ``text_key`` is set but missing from a returned
                entity.

        Returns:
            VectorStoreQueryResult: Nodes, raw Milvus distances and ids, in the
            order Milvus returned them.
        """
        if query.mode != VectorStoreQueryMode.DEFAULT:
            raise ValueError(f"Milvus does not support {query.mode} yet.")
        if query.query_embedding is None:
            raise ValueError("Milvus query requires a query embedding.")

        expr: List[str] = []

        # Parse the filter
        if query.filters is not None:
            expr.extend(_to_milvus_filter(query.filters))

        # Parse any docs we are filtering on
        if query.doc_ids:
            expr_list = [self._quote_milvus_string(entry) for entry in query.doc_ids]
            expr.append(f"{self.doc_id_field} in [{','.join(expr_list)}]")

        # Parse any nodes we are filtering on
        if query.node_ids:
            expr_list = [self._quote_milvus_string(entry) for entry in query.node_ids]
            expr.append(f"{MILVUS_ID_FIELD} in [{','.join(expr_list)}]")

        # Limit output fields
        output_fields = (
            query.output_fields if query.output_fields is not None else ["*"]
        )

        # Convert to string expression ("" when there are no clauses)
        string_expr = " and ".join(expr)

        # Perform the search
        res = self.milvusclient.search(
            collection_name=self.collection_name,
            data=[query.query_embedding],
            filter=string_expr,
            limit=query.similarity_top_k,
            output_fields=output_fields,
            search_params=self.search_config,
        )

        logger.debug(
            f"Successfully searched embedding in collection: {self.collection_name}"
            f" Num Results: {len(res[0])}"
        )

        nodes = []
        similarities = []
        ids = []

        # Parse the results
        for hit in res[0]:
            if not self.text_key:
                node = metadata_dict_to_node(
                    {
                        "_node_content": hit["entity"].get("_node_content", None),
                        "_node_type": hit["entity"].get("_node_type", None),
                    }
                )
            else:
                # Index (not .get) so a missing field actually raises and is
                # reported, instead of silently producing a text of None.
                try:
                    text = hit["entity"][self.text_key]
                except (KeyError, TypeError) as err:
                    raise ValueError(
                        "The passed in text_key value does not exist "
                        "in the retrieved entity."
                    ) from err
                node = TextNode(
                    text=text,
                )
            nodes.append(node)
            similarities.append(hit["distance"])
            ids.append(hit["id"])

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    def _create_index_if_required(self, force: bool = False) -> None:
        """(Re)build the vector index on ``self.embedding_field`` when required.

        A build happens when ``force`` is True, or when ``self.overwrite`` is
        True and the collection already has an index. ``add`` calls this
        without ``force``, so with ``overwrite`` set the index is rebuilt after
        every insert. Any existing index is released and dropped first; the
        new one uses ``self.index_config`` (its "index_type" key, default
        FLAT, plus the remaining keys as params) and ``self.similarity_metric``,
        and the collection is loaded afterwards.
        """
        if not (force or self.overwrite):
            return
        has_index = self.collection.has_index()
        if not (force or has_index):
            return

        if has_index:
            self.collection.release()
            self.collection.drop_index()

        base_params: Dict[str, Any] = self.index_config.copy()
        index_type: str = base_params.pop("index_type", "FLAT")
        index_params: Dict[str, Union[str, Dict[str, Any]]] = {
            "params": base_params,
            "metric_type": self.similarity_metric,
            "index_type": index_type,
        }
        self.collection.create_index(
            self.embedding_field, index_params=index_params
        )
        self.collection.load()
342

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.