llama-index

Форк
0
229 строк · 7.4 Кб
1
"""MongoDB Vector store index.
2

3
An index that is built on top of an existing vector store.
4

5
"""
6

7
import logging
8
import os
9
from typing import Any, Dict, List, Optional, cast
10

11
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
12
from llama_index.legacy.vector_stores.types import (
13
    MetadataFilters,
14
    VectorStore,
15
    VectorStoreQuery,
16
    VectorStoreQueryResult,
17
)
18
from llama_index.legacy.vector_stores.utils import (
19
    legacy_metadata_dict_to_node,
20
    metadata_dict_to_node,
21
    node_to_metadata_dict,
22
)
23

24
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
25

26

27
def _to_mongodb_filter(standard_filters: MetadataFilters) -> Dict:
    """Build a MongoDB filter document from standard metadata filters.

    Each entry returned by ``standard_filters.legacy_filters()`` contributes
    one ``key: value`` pair. If the same key occurs more than once, the last
    occurrence wins.

    Args:
        standard_filters: The metadata filters to translate.

    Returns:
        A dict mapping each filter key to its filter value.
    """
    return {
        entry.key: entry.value for entry in standard_filters.legacy_filters()
    }
33

34

35
class MongoDBAtlasVectorSearch(VectorStore):
    """MongoDB Atlas Vector Store.

    To use, you should have both:
    - the ``pymongo`` python package installed
    - a connection string associated with a MongoDB Atlas Cluster
    that has an Atlas Vector Search index

    Each node is stored as one document that holds its id, embedding, text
    and metadata under the field names configured in ``__init__``.

    """

    stores_text: bool = True
    flat_metadata: bool = True

    def __init__(
        self,
        mongodb_client: Optional[Any] = None,
        db_name: str = "default_db",
        collection_name: str = "default_collection",
        index_name: str = "default",
        id_key: str = "id",
        embedding_key: str = "embedding",
        text_key: str = "text",
        metadata_key: str = "metadata",
        insert_kwargs: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the vector store.

        Args:
            mongodb_client: A MongoDB client. When omitted, a client is
            created from the ``MONGO_URI`` environment variable.
            db_name: A MongoDB database name.
            collection_name: A MongoDB collection name.
            index_name: A MongoDB Atlas Vector Search index name.
            id_key: The data field to use as the id.
            embedding_key: A MongoDB field that will contain
            the embedding for each document.
            text_key: A MongoDB field that will contain the text for each document.
            metadata_key: A MongoDB field that will contain
            the metadata for each document.
            insert_kwargs: The kwargs used during `insert`.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ImportError: If ``pymongo`` (or ``importlib.metadata``) cannot
            be imported.
            ValueError: If no client is passed and ``MONGO_URI`` is not set.
        """
        import_err_msg = "`pymongo` package not found, please run `pip install pymongo`"
        try:
            from importlib.metadata import version

            from pymongo import MongoClient
            from pymongo.driver_info import DriverInfo
        except ImportError as err:
            raise ImportError(import_err_msg) from err

        if mongodb_client is not None:
            self._mongodb_client = cast(MongoClient, mongodb_client)
        else:
            if "MONGO_URI" not in os.environ:
                raise ValueError(
                    "Must specify MONGO_URI via env variable "
                    "if not directly passing in client."
                )
            # Identify this library to the server through the driver handshake.
            self._mongodb_client = MongoClient(
                os.environ["MONGO_URI"],
                driver=DriverInfo(name="llama-index", version=version("llama-index")),
            )

        self._collection = self._mongodb_client[db_name][collection_name]
        self._index_name = index_name
        self._embedding_key = embedding_key
        self._id_key = id_key
        self._text_key = text_key
        self._metadata_key = metadata_key
        self._insert_kwargs = insert_kwargs or {}

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings
            **add_kwargs: Accepted for interface compatibility; not used.
            Insert options come from the ``insert_kwargs`` given at
            construction time.

        Returns:
            A List of ids for successfully added nodes. Empty when ``nodes``
            is empty, in which case nothing is sent to MongoDB.

        """
        if not nodes:
            # pymongo's ``insert_many`` rejects an empty document list, so
            # skip the round trip instead of raising on a no-op call.
            return []

        ids = []
        data_to_insert = []
        for node in nodes:
            # The text is stored in its own field, so strip it from metadata.
            metadata = node_to_metadata_dict(
                node, remove_text=True, flat_metadata=self.flat_metadata
            )

            entry = {
                self._id_key: node.node_id,
                self._embedding_key: node.get_embedding(),
                self._text_key: node.get_content(metadata_mode=MetadataMode.NONE) or "",
                self._metadata_key: metadata,
            }
            data_to_insert.append(entry)
            ids.append(node.node_id)
        logger.debug("Inserting data into MongoDB: %s", data_to_insert)
        insert_result = self._collection.insert_many(
            data_to_insert, **self._insert_kwargs
        )
        logger.debug("Result of insert: %s", insert_result)
        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete all nodes that belong to the document ``ref_doc_id``.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
            **delete_kwargs: Extra keyword arguments forwarded to the
            collection's ``delete_many`` call.

        """
        # A source document may have been split into several nodes, each
        # stored as its own MongoDB document; remove every one of them, not
        # just the first match.
        self._collection.delete_many(
            filter={self._metadata_key + ".ref_doc_id": ref_doc_id}, **delete_kwargs
        )

    @property
    def client(self) -> Any:
        """Return MongoDB client."""
        return self._mongodb_client

    def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Run a ``$vectorSearch`` aggregation and convert the hits to nodes.

        Args:
            query: The query; its ``query_embedding``, ``similarity_top_k``
            and optional ``filters`` are used.

        Returns:
            A VectorStoreQueryResult with the matched nodes, their scores
            and their ids, in the order returned by the aggregation.
        """
        params: Dict[str, Any] = {
            "queryVector": query.query_embedding,
            "path": self._embedding_key,
            # Consider ten times more candidates than the results requested.
            "numCandidates": query.similarity_top_k * 10,
            "limit": query.similarity_top_k,
            "index": self._index_name,
        }
        if query.filters:
            params["filter"] = _to_mongodb_filter(query.filters)

        query_field = {"$vectorSearch": params}

        pipeline = [
            query_field,
            {
                "$project": {
                    "score": {"$meta": "vectorSearchScore"},
                    # Drop the stored embedding from the returned documents.
                    self._embedding_key: 0,
                }
            },
        ]
        logger.debug("Running query pipeline: %s", pipeline)
        cursor = self._collection.aggregate(pipeline)  # type: ignore
        top_k_nodes = []
        top_k_ids = []
        top_k_scores = []
        for res in cursor:
            text = res.pop(self._text_key)
            score = res.pop("score")
            node_id = res.pop(self._id_key)
            metadata_dict = res.pop(self._metadata_key)

            try:
                node = metadata_dict_to_node(metadata_dict)
                node.set_content(text)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(
                    metadata_dict
                )

                node = TextNode(
                    text=text,
                    id_=node_id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )

            top_k_ids.append(node_id)
            top_k_nodes.append(node)
            top_k_scores.append(score)
        result = VectorStoreQueryResult(
            nodes=top_k_nodes, similarities=top_k_scores, ids=top_k_ids
        )
        logger.debug("Result of query: %s", result)
        return result

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query: a VectorStoreQuery object.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            A VectorStoreQueryResult containing the results of the query.
        """
        return self._query(query)
230

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.