llama-index

Форк
0
362 строки · 13.0 Кб
"""
Astra DB Vector store index.

An index based on a DB table with vector search capabilities,
powered by the astrapy library.

"""
8

9
import json
10
import logging
11
from typing import Any, Dict, List, Optional, cast
12
from warnings import warn
13

14
from llama_index.legacy.bridge.pydantic import PrivateAttr
15
from llama_index.legacy.indices.query.embedding_utils import get_top_k_mmr_embeddings
16
from llama_index.legacy.schema import BaseNode, MetadataMode
17
from llama_index.legacy.vector_stores.types import (
18
    BasePydanticVectorStore,
19
    ExactMatchFilter,
20
    FilterOperator,
21
    MetadataFilter,
22
    MetadataFilters,
23
    VectorStoreQuery,
24
    VectorStoreQueryMode,
25
    VectorStoreQueryResult,
26
)
27
from llama_index.legacy.vector_stores.utils import (
28
    metadata_dict_to_node,
29
    node_to_metadata_dict,
30
)
31

32
# Logger shared by every operation of this module.
_logger: logging.Logger = logging.getLogger(__name__)

# In MMR mode, ``similarity_top_k * DEFAULT_MMR_PREFETCH_FACTOR`` candidates
# are prefetched unless the caller passes ``mmr_prefetch_k`` or
# ``mmr_prefetch_factor`` to ``query``.
DEFAULT_MMR_PREFETCH_FACTOR: float = 4.0
# Upper bound on the number of documents per ``insert_many`` call.
MAX_INSERT_BATCH_SIZE: int = 20

# Document fields excluded from indexing when the collection is created;
# per the collection warnings in this module, indexed fields restrict how
# much text an entry can store.
NON_INDEXED_FIELDS: List[str] = ["metadata._node_content", "content"]
38

39

40
class AstraDBVectorStore(BasePydanticVectorStore):
    """
    Astra DB Vector Store.

    An abstraction of an Astra table with
    vector-similarity-search. Documents, and their embeddings, are stored
    in an Astra table and a vector-capable index is used for searches.
    The table does not need to exist beforehand: if necessary it will
    be created behind the scenes.

    All Astra operations are done through the astrapy library.

    Args:
        collection_name (str): collection name to use. If not existing, it will be created.
        token (str): The Astra DB Application Token to use.
        api_endpoint (str): The Astra DB JSON API endpoint for your database.
        embedding_dimension (int): length of the embedding vectors in use.
        namespace (Optional[str]): The namespace to use. If not provided,
            ``None`` is passed to astrapy, which selects its default
            namespace ('default_keyspace').
        ttl_seconds (Optional[int]): expiration time for inserted entries.
            Default is no expiration. Note: the value is currently only
            stored on the instance; no operation of this class applies it.

    """

    stores_text: bool = True
    flat_metadata: bool = True

    _embedding_dimension: int = PrivateAttr()
    _ttl_seconds: Optional[int] = PrivateAttr()
    _astra_db: Any = PrivateAttr()
    _astra_db_collection: Any = PrivateAttr()

    def __init__(
        self,
        *,
        collection_name: str,
        token: str,
        api_endpoint: str,
        embedding_dimension: int,
        namespace: Optional[str] = None,
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """
        Connect to Astra DB and create (or attach to) the target collection.

        Raises:
            ImportError: if ``astrapy`` is not installed.
            APIRequestError: (from astrapy) if collection creation fails and
                no preexisting collection named ``collection_name`` is found.
        """
        super().__init__()

        import_err_msg = (
            "`astrapy` package not found, please run `pip install --upgrade astrapy`"
        )

        # astrapy is an optional dependency: import it lazily.
        try:
            from astrapy.api import APIRequestError
            from astrapy.db import AstraDB
        except ImportError:
            raise ImportError(import_err_msg)

        self._embedding_dimension = embedding_dimension
        # Stored only; no method of this class reads it.
        self._ttl_seconds = ttl_seconds

        _logger.debug("Creating the Astra DB table")

        self._astra_db = AstraDB(
            api_endpoint=api_endpoint, token=token, namespace=namespace
        )

        try:
            # Create and connect to the collection, keeping long-text fields
            # out of the index.
            self._astra_db_collection = self._astra_db.create_collection(
                collection_name=collection_name,
                dimension=embedding_dimension,
                options={"indexing": {"deny": NON_INDEXED_FIELDS}},
            )
        except APIRequestError:
            # Possibly the collection is preexisting with different (e.g.
            # legacy) indexing settings: look it up and verify.
            get_coll_response = self._astra_db.get_collections(
                options={"explain": True}
            )
            collections = (get_coll_response.get("status") or {}).get(
                "collections"
            ) or []
            pre_collection = next(
                (
                    collection
                    for collection in collections
                    if collection.get("name") == collection_name
                ),
                None,
            )
            if pre_collection is None:
                # Not explained by a preexisting collection: propagate the
                # original API error.
                raise

            pre_col_options = pre_collection.get("options") or {}
            if "indexing" not in pre_col_options:
                # No "indexing" options at all: a legacy collection.
                warn(
                    (
                        f"Collection '{collection_name}' is detected as legacy"
                        " and has indexing turned on for all fields. This"
                        " implies stricter limitations on the amount of text"
                        " each entry can store. Consider reindexing anew on a"
                        " fresh collection to be able to store longer texts."
                    ),
                    UserWarning,
                    stacklevel=2,
                )
            else:
                # Indexing options exist but differ from ours: warn and
                # proceed at the user's risk.
                options_json = json.dumps(pre_col_options["indexing"])
                warn(
                    (
                        f"Collection '{collection_name}' has unexpected 'indexing'"
                        f" settings (options.indexing = {options_json})."
                        " This can result in odd behaviour when running "
                        " metadata filtering and/or unwarranted limitations"
                        " on storing long texts. Consider reindexing anew on a"
                        " fresh collection."
                    ),
                    UserWarning,
                    stacklevel=2,
                )
            # In both cases, attach to the existing collection as-is.
            self._astra_db_collection = self._astra_db.collection(
                collection_name=collection_name,
            )

    def _node_to_document(self, node: BaseNode) -> Dict[str, Any]:
        """Build the Astra DB document (id, text, metadata, vector) for one node."""
        metadata = node_to_metadata_dict(
            node,
            remove_text=True,
            flat_metadata=self.flat_metadata,
        )
        return {
            "_id": node.node_id,
            "content": node.get_content(metadata_mode=MetadataMode.NONE),
            "metadata": metadata,
            "$vector": node.get_embedding(),
        }

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """
        Add nodes to index.

        Each node becomes one document (``_id`` = node id, ``content`` = node
        text, ``metadata`` = node metadata without text, ``$vector`` = node
        embedding). Documents are written with ``insert_many`` in batches of
        at most ``MAX_INSERT_BATCH_SIZE``.

        Args:
            nodes: List[BaseNode]: list of node with embeddings
            add_kwargs: accepted for interface compatibility; unused.

        Returns:
            List[str]: ids of the added nodes, in input order.

        """
        documents = [self._node_to_document(node) for node in nodes]

        _logger.debug("Adding %d rows to table", len(documents))

        batch_starts = range(0, len(documents), MAX_INSERT_BATCH_SIZE)
        for batch_number, start in enumerate(batch_starts, start=1):
            batch = documents[start : start + MAX_INSERT_BATCH_SIZE]
            _logger.debug(
                "Processing batch #%d of size %d", batch_number, len(batch)
            )
            self._astra_db_collection.insert_many(batch)

        return [str(document["_id"]) for document in documents]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        Args:
            ref_doc_id (str): The id of the document to delete.
            delete_kwargs: forwarded unchanged to astrapy's ``delete``.

        """
        _logger.debug("Deleting a document from the Astra table")

        # NOTE(review): ``ref_doc_id`` is passed as the document ``id`` to
        # astrapy's ``delete``, while ``add`` stores node ids as ``_id``.
        # Nodes whose id differs from their ref doc id are therefore
        # presumably not removed by this call; confirm whether deleting by a
        # metadata ref-doc-id field was intended.
        self._astra_db_collection.delete(id=ref_doc_id, **delete_kwargs)

    @property
    def client(self) -> Any:
        """Return the underlying Astra vector table object."""
        return self._astra_db_collection

    @staticmethod
    def _query_filters_to_dict(query_filters: MetadataFilters) -> Dict[str, Any]:
        """
        Translate llama-index metadata filters into an Astra DB filter dict.

        Only equality filters (legacy ``ExactMatchFilter`` or ``MetadataFilter``
        with ``FilterOperator.EQ``) are supported, and when there is more than
        one they must be combined with AND: every filter becomes a
        ``"metadata.<key>": value`` entry of a single filter dict.

        Raises:
            NotImplementedError: for non-EQ filters, or for several filters
                combined with a condition other than AND (which a single
                flat dict cannot express).
        """
        from llama_index.legacy.vector_stores.types import FilterCondition

        # Allow only legacy ExactMatchFilter and MetadataFilter with FilterOperator.EQ
        if not all(
            (
                isinstance(f, ExactMatchFilter)
                or (isinstance(f, MetadataFilter) and f.operator == FilterOperator.EQ)
            )
            for f in query_filters.filters
        ):
            raise NotImplementedError(
                "Only filters with operator=FilterOperator.EQ are supported"
            )
        # A flat dict of equalities means AND; silently running an OR as AND
        # would return wrong results.
        if len(query_filters.filters) > 1 and query_filters.condition not in (
            None,
            FilterCondition.AND,
        ):
            raise NotImplementedError(
                "Only filters combined with condition=FilterCondition.AND are supported"
            )
        return {f"metadata.{f.key}": f.value for f in query_filters.filters}

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """
        Query index for top k most similar nodes.

        Supports ``VectorStoreQueryMode.DEFAULT`` (plain vector search) and
        ``VectorStoreQueryMode.MMR`` (prefetch a larger candidate set, then
        re-rank it with ``get_top_k_mmr_embeddings``).

        Args:
            query: the query; ``query_embedding`` must be set.
            kwargs: only read in MMR mode:
                ``mmr_prefetch_factor`` (multiplier on ``similarity_top_k``,
                default ``DEFAULT_MMR_PREFETCH_FACTOR``), ``mmr_prefetch_k``
                (absolute candidate count; mutually exclusive with the
                factor) and ``mmr_threshold`` (used when
                ``query.mmr_threshold`` is falsy).

        Raises:
            NotImplementedError: unsupported query mode or filters.
            ValueError: missing query embedding, or both prefetch kwargs set.
        """
        if query.mode not in (VectorStoreQueryMode.DEFAULT, VectorStoreQueryMode.MMR):
            raise NotImplementedError(f"Query mode {query.mode} not available.")

        if query.query_embedding is None:
            raise ValueError("A query embedding is required to query Astra DB.")
        query_embedding = cast(List[float], query.query_embedding)

        query_metadata = (
            self._query_filters_to_dict(query.filters)
            if query.filters is not None
            else {}
        )

        if query.mode == VectorStoreQueryMode.DEFAULT:
            matches = self._astra_db_collection.vector_find(
                vector=query_embedding,
                limit=query.similarity_top_k,
                filter=query_metadata,
            )
            top_k_scores = [match["$similarity"] for match in matches]
        else:
            # MMR: fetch more candidates than needed, then re-rank them.
            prefetch_factor = kwargs.get("mmr_prefetch_factor")
            prefetch_k_arg = kwargs.get("mmr_prefetch_k")
            if prefetch_factor is not None and prefetch_k_arg is not None:
                raise ValueError(
                    "'mmr_prefetch_factor' and 'mmr_prefetch_k' "
                    "cannot coexist in a call to query()"
                )
            if prefetch_k_arg is not None:
                prefetch_k0 = int(prefetch_k_arg)
            else:
                # An explicit None factor also falls back to the default.
                if prefetch_factor is None:
                    prefetch_factor = DEFAULT_MMR_PREFETCH_FACTOR
                prefetch_k0 = int(query.similarity_top_k * prefetch_factor)
            # Never fetch fewer candidates than the results we must return.
            prefetch_k = max(prefetch_k0, query.similarity_top_k)

            prefetch_matches = self._astra_db_collection.vector_find(
                vector=query_embedding,
                limit=prefetch_k,
                filter=query_metadata,
            )

            mmr_threshold = query.mmr_threshold or kwargs.get("mmr_threshold")

            # Candidates are identified by their position in prefetch_matches.
            pf_match_embeddings = [match["$vector"] for match in prefetch_matches]
            pf_match_indices = list(range(len(pf_match_embeddings)))

            mmr_similarities, mmr_indices = get_top_k_mmr_embeddings(
                query_embedding,
                pf_match_embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=pf_match_indices,
                mmr_threshold=mmr_threshold,
            )

            matches = [prefetch_matches[mmr_index] for mmr_index in mmr_indices]
            top_k_scores = mmr_similarities

        top_k_nodes = []
        top_k_ids = []
        for match in matches:
            # Without a llama-generated "_node_content" field, synthesize one
            # from the raw document so metadata_dict_to_node can build a node.
            if "_node_content" not in match["metadata"]:
                match["metadata"]["_node_content"] = json.dumps(match)

            top_k_nodes.append(
                metadata_dict_to_node(match["metadata"], text=match["content"])
            )
            top_k_ids.append(match["_id"])

        return VectorStoreQueryResult(
            nodes=top_k_nodes,
            similarities=top_k_scores,
            ids=top_k_ids,
        )
363

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.