llama-index

Форк
0
318 строк · 11.7 Кб
1
"""Cassandra / Astra DB Vector store index.
2

3
An index based on a DB table with vector search capabilities,
4
powered by the cassIO library
5

6
"""
7

8
import logging
9
from typing import Any, Dict, Iterable, List, Optional, TypeVar, cast
10

11
from llama_index.legacy.indices.query.embedding_utils import (
12
    get_top_k_mmr_embeddings,
13
)
14
from llama_index.legacy.schema import BaseNode, MetadataMode
15
from llama_index.legacy.vector_stores.types import (
16
    ExactMatchFilter,
17
    MetadataFilters,
18
    VectorStore,
19
    VectorStoreQuery,
20
    VectorStoreQueryMode,
21
    VectorStoreQueryResult,
22
)
23
from llama_index.legacy.vector_stores.utils import (
24
    metadata_dict_to_node,
25
    node_to_metadata_dict,
26
)
27

28
_logger = logging.getLogger(__name__)
29

30
DEFAULT_MMR_PREFETCH_FACTOR = 4.0
31
DEFAULT_INSERTION_BATCH_SIZE = 20
32

33
T = TypeVar("T")
34

35

36
def _batch_iterable(iterable: Iterable[T], batch_size: int) -> Iterable[Iterable[T]]:
37
    this_batch = []
38
    for entry in iterable:
39
        this_batch.append(entry)
40
        if len(this_batch) == batch_size:
41
            yield this_batch
42
            this_batch = []
43
    if this_batch:
44
        yield this_batch
45

46

47
class CassandraVectorStore(VectorStore):
    """
    Cassandra Vector Store.

    An abstraction of a Cassandra table with
    vector-similarity-search. Documents, and their embeddings, are stored
    in a Cassandra table and a vector-capable index is used for searches.
    The table does not need to exist beforehand: if necessary it will
    be created behind the scenes.

    All Cassandra operations are done through the CassIO library.

    Note: in recent versions, only `table` and `embedding_dimension` can be
    passed positionally. Please revise your code if needed.
    This is to accommodate for a leaner usage, whereby the DB connection
    is set globally through a `cassio.init(...)` call: then, the DB details
    are not to be specified anymore when creating a vector store, unless
    desired.

    Args:
        table (str): table name to use. If not existing, it will be created.
        embedding_dimension (int): length of the embedding vectors in use.
        session (optional, cassandra.cluster.Session): the Cassandra session
            to use.
            Can be omitted, or equivalently set to None, to use the
            DB connection set globally through cassio.init() beforehand.
        keyspace (optional, str): name of the Cassandra keyspace to work in.
            Can be omitted, or equivalently set to None, to use the
            DB connection set globally through cassio.init() beforehand.
        ttl_seconds (optional, int): expiration time for inserted entries.
            Default is no expiration (None).
        insertion_batch_size (optional, int): how many vectors are inserted
            concurrently, for use by bulk inserts. Defaults to 20.
    """

    stores_text: bool = True
    flat_metadata: bool = True

    def __init__(
        self,
        table: str,
        embedding_dimension: int,
        *,
        session: Optional[Any] = None,
        keyspace: Optional[str] = None,
        ttl_seconds: Optional[int] = None,
        insertion_batch_size: int = DEFAULT_INSERTION_BATCH_SIZE,
    ) -> None:
        import_err_msg = (
            "`cassio` package not found, please run `pip install --upgrade cassio`"
        )
        try:
            from cassio.table import ClusteredMetadataVectorCassandraTable
        except ImportError:
            raise ImportError(import_err_msg)

        self._session = session
        self._keyspace = keyspace
        self._table = table
        self._embedding_dimension = embedding_dimension
        self._ttl_seconds = ttl_seconds
        self._insertion_batch_size = insertion_batch_size

        _logger.debug("Creating the Cassandra table")
        # The table is partitioned by ref_doc_id (partition key) with the
        # node_id as clustering column, hence the two-component primary key.
        self.vector_table = ClusteredMetadataVectorCassandraTable(
            session=self._session,
            keyspace=self._keyspace,
            table=self._table,
            vector_dimension=self._embedding_dimension,
            primary_key_type=["TEXT", "TEXT"],
            # a conservative choice here, to make everything searchable
            # except the bulky "_node_content" key (it'd make little sense to):
            metadata_indexing=("default_to_searchable", ["_node_content"]),
        )

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            List[str]: the node_ids of the inserted nodes, in input order.
        """
        node_ids = []
        node_contents = []
        node_metadatas = []
        node_embeddings = []
        for node in nodes:
            # Text is stored in its own column (`body_blob`), so it is
            # removed from the metadata dict before serialization.
            metadata = node_to_metadata_dict(
                node,
                remove_text=True,
                flat_metadata=self.flat_metadata,
            )
            node_ids.append(node.node_id)
            node_contents.append(node.get_content(metadata_mode=MetadataMode.NONE))
            node_metadatas.append(metadata)
            node_embeddings.append(node.get_embedding())

        _logger.debug(f"Adding {len(node_ids)} rows to table")
        # Concurrent batching of inserts: within each batch, the async
        # writes are fired together and then awaited before moving on.
        insertion_tuples = zip(node_ids, node_contents, node_metadatas, node_embeddings)
        for insertion_batch in _batch_iterable(
            insertion_tuples, batch_size=self._insertion_batch_size
        ):
            futures = []
            for (
                node_id,
                node_content,
                node_metadata,
                node_embedding,
            ) in insertion_batch:
                # ref_doc_id becomes the partition key, enabling
                # per-document deletion via `delete_partition`.
                node_ref_doc_id = node_metadata["ref_doc_id"]
                futures.append(
                    self.vector_table.put_async(
                        row_id=node_id,
                        body_blob=node_content,
                        vector=node_embedding,
                        metadata=node_metadata,
                        partition_id=node_ref_doc_id,
                        ttl_seconds=self._ttl_seconds,
                    )
                )
            for future in futures:
                _ = future.result()

        return node_ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete all nodes belonging to the given ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        _logger.debug("Deleting a document from the Cassandra table")
        # Nodes are partitioned by ref_doc_id, so a single partition
        # deletion removes every node of the document.
        self.vector_table.delete_partition(
            partition_id=ref_doc_id,
        )

    @property
    def client(self) -> Any:
        """Return the underlying cassIO vector table object."""
        return self.vector_table

    @staticmethod
    def _query_filters_to_dict(query_filters: MetadataFilters) -> Dict[str, Any]:
        """Convert supported metadata filters into a cassIO metadata dict.

        Raises:
            NotImplementedError: if any filter is not an `ExactMatchFilter`.
        """
        if any(
            not isinstance(f, ExactMatchFilter) for f in query_filters.legacy_filters()
        ):
            raise NotImplementedError("Only `ExactMatchFilter` filters are supported")
        return {f.key: f.value for f in query_filters.filters}

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """
        Query index for top k most similar nodes.

        Supported query modes: 'default' (most similar vectors) and 'mmr'.

        Args:
            query (VectorStoreQuery): the basic query definition. Defines:
                mode (VectorStoreQueryMode): one of the supported modes
                query_embedding (List[float]): query embedding to search against
                similarity_top_k (int): top k most similar nodes
                mmr_threshold (Optional[float]): this is the 0-to-1 MMR lambda.
                    If present, takes precedence over the kwargs parameter.
                    Ignored unless for MMR queries.

        Args for query.mode == 'mmr' (ignored otherwise):
            mmr_threshold (Optional[float]): this is the 0-to-1 lambda for MMR.
                Note that in principle mmr_threshold could come in the query
            mmr_prefetch_factor (Optional[float]): factor applied to top_k
                for prefetch pool size. Defaults to 4.0
            mmr_prefetch_k (Optional[int]): prefetch pool size. This cannot be
                passed together with mmr_prefetch_factor

        Raises:
            NotImplementedError: for unsupported query modes.
            ValueError: if both mmr_prefetch_factor and mmr_prefetch_k
                are passed.
        """
        _available_query_modes = [
            VectorStoreQueryMode.DEFAULT,
            VectorStoreQueryMode.MMR,
        ]
        if query.mode not in _available_query_modes:
            raise NotImplementedError(f"Query mode {query.mode} not available.")
        #
        query_embedding = cast(List[float], query.query_embedding)

        # metadata filtering
        if query.filters is not None:
            query_metadata = self._query_filters_to_dict(query.filters)
        else:
            query_metadata = {}

        _logger.debug(
            f"Running ANN search on the Cassandra table (query mode: {query.mode})"
        )
        if query.mode == VectorStoreQueryMode.DEFAULT:
            matches = list(
                self.vector_table.metric_ann_search(
                    vector=query_embedding,
                    n=query.similarity_top_k,
                    metric="cos",
                    metric_threshold=None,
                    metadata=query_metadata,
                )
            )
            top_k_scores = [match["distance"] for match in matches]
        elif query.mode == VectorStoreQueryMode.MMR:
            # Querying a larger number of vectors and then doing MMR on them.
            if (
                kwargs.get("mmr_prefetch_factor") is not None
                and kwargs.get("mmr_prefetch_k") is not None
            ):
                raise ValueError(
                    "'mmr_prefetch_factor' and 'mmr_prefetch_k' "
                    "cannot coexist in a call to query()"
                )
            else:
                if kwargs.get("mmr_prefetch_k") is not None:
                    prefetch_k0 = int(kwargs["mmr_prefetch_k"])
                else:
                    prefetch_k0 = int(
                        query.similarity_top_k
                        * kwargs.get("mmr_prefetch_factor", DEFAULT_MMR_PREFETCH_FACTOR)
                    )
            # Never prefetch fewer candidates than are to be returned.
            prefetch_k = max(prefetch_k0, query.similarity_top_k)
            #
            prefetch_matches = list(
                self.vector_table.metric_ann_search(
                    vector=query_embedding,
                    n=prefetch_k,
                    metric="cos",
                    metric_threshold=None,  # this is not `mmr_threshold`
                    metadata=query_metadata,
                )
            )
            #
            # Compare against None (not truthiness): an explicit
            # mmr_threshold of 0.0 must not be discarded.
            mmr_threshold = query.mmr_threshold
            if mmr_threshold is None:
                mmr_threshold = kwargs.get("mmr_threshold")
            if prefetch_matches:
                pf_match_indices, pf_match_embeddings = zip(
                    *enumerate(match["vector"] for match in prefetch_matches)
                )
            else:
                pf_match_indices, pf_match_embeddings = [], []
            pf_match_indices = list(pf_match_indices)
            pf_match_embeddings = list(pf_match_embeddings)
            mmr_similarities, mmr_indices = get_top_k_mmr_embeddings(
                query_embedding,
                pf_match_embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=pf_match_indices,
                mmr_threshold=mmr_threshold,
            )
            #
            matches = [prefetch_matches[mmr_index] for mmr_index in mmr_indices]
            top_k_scores = mmr_similarities

        top_k_nodes = []
        top_k_ids = []
        for match in matches:
            # Reconstruct the node from its serialized metadata, then
            # re-attach the text stored in the `body_blob` column.
            node = metadata_dict_to_node(match["metadata"])
            node.set_content(match["body_blob"])
            top_k_nodes.append(node)
            top_k_ids.append(match["row_id"])

        return VectorStoreQueryResult(
            nodes=top_k_nodes,
            similarities=top_k_scores,
            ids=top_k_ids,
        )
319

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.