llama-index

Форк
0
470 строк · 16.9 Кб
1
"""Redis Vector store index.
2

3
An index that is built on top of an existing vector store.
4
"""
5

6
import logging
7
from typing import TYPE_CHECKING, Any, Dict, List, Optional
8

9
import fsspec
10

11
from llama_index.legacy.bridge.pydantic import PrivateAttr
12
from llama_index.legacy.readers.redis.utils import (
13
    TokenEscaper,
14
    array_to_buffer,
15
    check_redis_modules_exist,
16
    convert_bytes,
17
    get_redis_query,
18
)
19
from llama_index.legacy.schema import (
20
    BaseNode,
21
    MetadataMode,
22
    NodeRelationship,
23
    RelatedNodeInfo,
24
    TextNode,
25
)
26
from llama_index.legacy.vector_stores.types import (
27
    BasePydanticVectorStore,
28
    MetadataFilters,
29
    VectorStoreQuery,
30
    VectorStoreQueryResult,
31
)
32
from llama_index.legacy.vector_stores.utils import (
33
    metadata_dict_to_node,
34
    node_to_metadata_dict,
35
)
36

37
# Logger named after this module; used for the store's informational/error logs.
_logger = logging.getLogger(name=__name__)
38

39

40
if TYPE_CHECKING:
41
    from redis.client import Redis as RedisType
42
    from redis.commands.search.field import VectorField
43

44

45
class RedisVectorStore(BasePydanticVectorStore):
    """Vector store backed by Redis and the RediSearch module.

    Every node is written as a Redis hash under the key
    ``"{index_prefix}{prefix_ending}_{node_id}"``. The hash holds the node id,
    its ``doc_id``, its text, its embedding (as a byte buffer) and its
    serialized metadata. :meth:`add` creates a RediSearch index over that key
    prefix if it does not exist yet. :meth:`delete` and :meth:`query` search
    through that index.
    """

    stores_text = True
    stores_node = True
    flat_metadata = False

    _tokenizer: Any = PrivateAttr()
    _redis_client: Any = PrivateAttr()
    _prefix: str = PrivateAttr()
    _index_name: str = PrivateAttr()
    _index_args: Dict[str, Any] = PrivateAttr()
    _metadata_fields: List[str] = PrivateAttr()
    _overwrite: bool = PrivateAttr()
    _vector_field: str = PrivateAttr()
    _vector_key: str = PrivateAttr()

    def __init__(
        self,
        index_name: str,
        index_prefix: str = "llama_index",
        prefix_ending: str = "/vector",
        index_args: Optional[Dict[str, Any]] = None,
        metadata_fields: Optional[List[str]] = None,
        redis_url: str = "redis://localhost:6379",
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize RedisVectorStore.

        For index arguments that can be passed to RediSearch, see
        https://redis.io/docs/stack/search/reference/vectors/

        The index arguments depend on the chosen index type. Two index types
        are available:
            - FLAT: a flat index that uses brute force search
            - HNSW: a hierarchical navigable small world graph index

        Args:
            index_name (str): Name of the index.
            index_prefix (str): Prefix for the index. Defaults to "llama_index".
                The actual prefix used by Redis will be
                "{index_prefix}{prefix_ending}".
            prefix_ending (str): Prefix ending for the index. Be careful when
                changing this: https://github.com/jerryjliu/llama_index/pull/6665.
                Defaults to "/vector".
            index_args (Dict[str, Any]): Arguments for the index (see
                ``_create_vector_field``). The keys "vector_field" and
                "vector_key" set the name of the indexed vector field and of
                the hash key holding the embedding (both default to "vector").
                The dict is copied, so the caller's dict is never modified.
                Defaults to None.
            metadata_fields (List[str]): List of metadata fields to store in the
                index (only supports TAG fields).
            redis_url (str): URL for the redis instance.
                Defaults to "redis://localhost:6379".
            overwrite (bool): Whether to drop and recreate the index if it
                already exists the first time nodes are added. Defaults to False.
            kwargs (Any): Additional arguments to pass to the redis client.

        Raises:
            ValueError: If redis-py is not installed.
            ValueError: If the connection check fails or RediSearch is not
                installed.

        Examples:
            >>> from llama_index.legacy.vector_stores.redis import RedisVectorStore
            >>> # Create a RedisVectorStore
            >>> vector_store = RedisVectorStore(
            ...     index_name="my_index",
            ...     index_prefix="llama_index",
            ...     index_args={"algorithm": "HNSW", "m": 16, "ef_construction": 200,
            ...                 "distance_metric": "cosine"},
            ...     redis_url="redis://localhost:6379/",
            ...     overwrite=True)
        """
        try:
            import redis
        except ImportError as e:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            ) from e
        try:
            # connect to redis from url
            self._redis_client = redis.from_url(redis_url, **kwargs)
            # check if redis has redisearch module installed
            check_redis_modules_exist(self._redis_client)
        except ValueError as e:
            raise ValueError(f"Redis failed to connect: {e}") from e

        # index identifiers
        self._prefix = index_prefix + prefix_ending
        self._index_name = index_name
        # Copy: ``add`` records "dims" in this dict and must not mutate the
        # dict the caller passed in.
        self._index_args = dict(index_args) if index_args is not None else {}
        self._metadata_fields = (
            list(metadata_fields) if metadata_fields is not None else []
        )
        self._overwrite = overwrite
        self._vector_field = str(self._index_args.get("vector_field", "vector"))
        self._vector_key = str(self._index_args.get("vector_key", "vector"))
        self._tokenizer = TokenEscaper()
        # Private attributes are assigned before the pydantic initializer,
        # matching the original construction order.
        super().__init__()

    @property
    def client(self) -> "RedisType":
        """Return the redis client instance."""
        return self._redis_client

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to the index.

        If the RediSearch index does not exist, it is created. Its vector
        dimension is taken from the first node's embedding. If the index exists
        and ``overwrite`` was requested, the index and its documents are dropped
        and recreated, but only on the first such call. Otherwise every batch
        added afterwards would wipe out the previous ones.

        Args:
            nodes (List[BaseNode]): List of nodes with embeddings.

        Returns:
            List[str]: List of ids of the documents added to the index.
        """
        # nothing to do (and no embedding to size the index from)
        if not nodes:
            return []

        # set vector dim for creation if index doesn't exist
        self._index_args["dims"] = len(nodes[0].get_embedding())

        if self._index_exists():
            if self._overwrite:
                self.delete_index()
                self._create_index()
                # Overwrite once; later calls must append to the fresh index.
                self._overwrite = False
            else:
                _logger.info(f"Adding document to existing index {self._index_name}")
        else:
            self._create_index()

        ids = []
        for node in nodes:
            mapping = {
                "id": node.node_id,
                "doc_id": node.ref_doc_id,
                "text": node.get_content(metadata_mode=MetadataMode.NONE),
                self._vector_key: array_to_buffer(node.get_embedding()),
            }
            additional_metadata = node_to_metadata_dict(
                node, remove_text=True, flat_metadata=self.flat_metadata
            )
            mapping.update(additional_metadata)

            ids.append(node.node_id)
            key = "_".join([self._prefix, str(node.node_id)])
            self._redis_client.hset(key, mapping=mapping)  # type: ignore

        _logger.info(f"Added {len(ids)} documents to index {self._index_name}")
        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete all nodes whose ``doc_id`` equals ``ref_doc_id``.

        The keys of all matching documents are collected page by page before
        deleting. A bare RediSearch query returns only the first 10 hits, so
        documents split into many nodes would otherwise be deleted only
        partially.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
        """
        from redis.commands.search.query import Query

        page_size = 1000
        # use tokenizer to escape dashes in query
        query_str = "@doc_id:{%s}" % self._tokenizer.escape(ref_doc_id)
        search_index = self._redis_client.ft(self._index_name)

        # find all documents that match a doc_id (only their keys are needed)
        doc_keys: List[str] = []
        offset = 0
        while True:
            results = search_index.search(
                Query(query_str).no_content().paging(offset, page_size)
            )
            doc_keys.extend(doc.id for doc in results.docs)
            offset += len(results.docs)
            if not results.docs or offset >= results.total:
                break

        if not doc_keys:
            # don't raise an error but warn the user that document wasn't found
            # could be a result of eviction policy
            _logger.warning(
                f"Document with doc_id {ref_doc_id} not found "
                f"in index {self._index_name}"
            )
            return

        # delete in bounded chunks so no single DEL command grows unbounded
        for start in range(0, len(doc_keys), page_size):
            self._redis_client.delete(*doc_keys[start : start + page_size])
        _logger.info(
            f"Deleted {len(doc_keys)} documents from index {self._index_name}"
        )

    def delete_index(self) -> None:
        """Delete the index and all documents."""
        _logger.info(f"Deleting index {self._index_name}")
        self._redis_client.ft(self._index_name).dropindex(delete_documents=True)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query the index.

        Args:
            query (VectorStoreQuery): query object

        Returns:
            VectorStoreQueryResult: query result

        Raises:
            ValueError: If query.query_embedding is None.
            redis.exceptions.RedisError: If there is an error querying the index.
            redis.exceptions.TimeoutError: If there is a timeout querying the index.
            ValueError: If no documents are found when querying the index.
        """
        from redis.exceptions import RedisError
        from redis.exceptions import TimeoutError as RedisTimeoutError

        # validate input before building anything
        if not query.query_embedding:
            raise ValueError("Query embedding is required for querying.")

        return_fields = [
            "id",
            "doc_id",
            "text",
            self._vector_key,
            "vector_score",
            "_node_content",
        ]

        filters = _to_redis_filters(query.filters) if query.filters is not None else "*"

        _logger.info(f"Using filters: {filters}")

        redis_query = get_redis_query(
            return_fields=return_fields,
            top_k=query.similarity_top_k,
            vector_field=self._vector_field,
            filters=filters,
        )

        query_params = {
            "vector": array_to_buffer(query.query_embedding),
        }
        _logger.info(f"Querying index {self._index_name}")

        try:
            results = self._redis_client.ft(self._index_name).search(
                redis_query, query_params=query_params  # type: ignore
            )
        except RedisTimeoutError as e:
            _logger.error(f"Query timed out on {self._index_name}: {e}")
            raise
        except RedisError as e:
            _logger.error(f"Error querying {self._index_name}: {e}")
            raise

        if not results.docs:
            raise ValueError(
                f"No docs found on index '{self._index_name}' with "
                f"prefix '{self._prefix}' and filters '{filters}'. "
                "* Did you originally create the index with a different prefix? "
                "* Did you index your metadata fields when you created the index?"
            )

        # keys are written by ``add`` as "{prefix}_{node_id}"
        key_prefix = self._prefix + "_"
        ids = []
        nodes = []
        scores = []
        for doc in results.docs:
            try:
                node = metadata_dict_to_node({"_node_content": doc._node_content})
                node.text = doc.text
            except Exception:
                # TODO: Legacy support for old metadata format
                node = TextNode(
                    text=doc.text,
                    id_=doc.id,
                    embedding=None,
                    relationships={
                        NodeRelationship.SOURCE: RelatedNodeInfo(node_id=doc.doc_id)
                    },
                )
            # strip only the leading key prefix; the node id itself may
            # contain the prefix text
            if doc.id.startswith(key_prefix):
                ids.append(doc.id[len(key_prefix) :])
            else:
                ids.append(doc.id)
            nodes.append(node)
            # vector_score is a distance; 1 - distance is the similarity for
            # the COSINE metric (the default in _create_vector_field)
            scores.append(1 - float(doc.vector_score))
        _logger.info(f"Found {len(nodes)} results for query with id {ids}")

        return VectorStoreQueryResult(nodes=nodes, ids=ids, similarities=scores)

    def persist(
        self,
        persist_path: str,
        fs: Optional[fsspec.AbstractFileSystem] = None,
        in_background: bool = True,
    ) -> None:
        """Persist the vector store to disk via Redis ``BGSAVE``/``SAVE``.

        Args:
            persist_path (str): Path to persist the vector store to. (doesn't apply)
            in_background (bool, optional): Persist in background. Defaults to True.
            fs (fsspec.AbstractFileSystem, optional): Filesystem to persist to.
                (doesn't apply)

        Raises:
            redis.exceptions.RedisError: If there is an error
                                         persisting the index to disk.
        """
        from redis.exceptions import RedisError

        try:
            if in_background:
                _logger.info("Saving index to disk in background")
                self._redis_client.bgsave()
            else:
                _logger.info("Saving index to disk")
                self._redis_client.save()

        except RedisError as e:
            _logger.error(f"Error saving index to disk: {e}")
            raise

    def _create_index(self) -> None:
        """Create the RediSearch index over hashes stored under ``self._prefix``.

        The schema holds "text" (TEXT), "doc_id" and "id" (TAG), the vector
        field, and one TAG field per configured metadata field. A metadata field
        whose name is already in the schema is skipped with a warning, because
        a duplicate name makes FT.CREATE fail.
        """
        # should never be called outside class and hence should not raise importerror
        from redis.commands.search.field import TagField, TextField
        from redis.commands.search.indexDefinition import IndexDefinition, IndexType

        # Create Index
        default_fields = [
            TextField("text", weight=1.0),
            TagField("doc_id", sortable=False),
            TagField("id", sortable=False),
        ]
        # add vector field to list of index fields. Create lazily to allow user
        # to specify index and search attributes in creation.
        fields = [
            *default_fields,
            self._create_vector_field(self._vector_field, **self._index_args),
        ]

        # add metadata fields to list of index fields or we won't be able to search them
        schema_names = {"text", "doc_id", "id", self._vector_field}
        for metadata_field in self._metadata_fields:
            if metadata_field in schema_names:
                _logger.warning(
                    f"Skipping metadata field '{metadata_field}': a field with "
                    "this name is already part of the index schema"
                )
                continue
            schema_names.add(metadata_field)
            # TODO: allow addition of text fields as metadata
            fields.append(TagField(metadata_field, sortable=False))

        _logger.info(f"Creating index {self._index_name}")
        self._redis_client.ft(self._index_name).create_index(
            fields=fields,
            definition=IndexDefinition(
                prefix=[self._prefix], index_type=IndexType.HASH
            ),  # TODO support JSON
        )

    def _index_exists(self) -> bool:
        """Return True if ``self._index_name`` is listed by ``FT._LIST``."""
        indices = convert_bytes(self._redis_client.execute_command("FT._LIST"))
        return self._index_name in indices

    def _create_vector_field(
        self,
        name: str,
        dims: int = 1536,
        algorithm: str = "FLAT",
        datatype: str = "FLOAT32",
        distance_metric: str = "COSINE",
        initial_cap: int = 20000,
        block_size: int = 1000,
        m: int = 16,
        ef_construction: int = 200,
        ef_runtime: int = 10,
        epsilon: float = 0.8,
        **kwargs: Any,
    ) -> "VectorField":
        """Create a RediSearch VectorField.

        Args:
            name (str): The name of the field.
            dims (int): The dimensionality of the vector.
            algorithm (str): The indexing algorithm; "HNSW" (case-insensitive)
                selects HNSW, any other value selects FLAT.
            datatype (str): The type of the vector. default: FLOAT32
            distance_metric (str): The distance metric used to compare vectors.
            initial_cap (int): The initial capacity of the index.
            block_size (int): The block size of the index (FLAT only).
            m (int): The number of outgoing edges in the HNSW graph.
            ef_construction (int): Number of maximum allowed potential outgoing
                edge candidates for each node in the graph, during the graph
                building (HNSW only).
            ef_runtime (int): The number of maximum top candidates to hold during
                the KNN search (HNSW only).
            epsilon (float): Passed to RediSearch as the HNSW "EPSILON"
                attribute (HNSW only).
            kwargs (Any): Other index arguments (e.g. "vector_field",
                "vector_key"); ignored here.

        Returns:
            A RediSearch VectorField.

        Raises:
            ValueError: If redis-py rejects the field definition.
        """
        from redis import DataError
        from redis.commands.search.field import VectorField

        try:
            if algorithm.upper() == "HNSW":
                return VectorField(
                    name,
                    "HNSW",
                    {
                        "TYPE": datatype.upper(),
                        "DIM": dims,
                        "DISTANCE_METRIC": distance_metric.upper(),
                        "INITIAL_CAP": initial_cap,
                        "M": m,
                        "EF_CONSTRUCTION": ef_construction,
                        "EF_RUNTIME": ef_runtime,
                        "EPSILON": epsilon,
                    },
                )
            else:
                return VectorField(
                    name,
                    "FLAT",
                    {
                        "TYPE": datatype.upper(),
                        "DIM": dims,
                        "DISTANCE_METRIC": distance_metric.upper(),
                        "INITIAL_CAP": initial_cap,
                        "BLOCK_SIZE": block_size,
                    },
                )
        except DataError as e:
            raise ValueError(
                f"Failed to create Redis index vector field with error: {e}"
            ) from e
454

455

456
# Currently only supports exact TAG matches ({...} denotes a tag query).
# A field must be declared as a metadata field when the index is created
# before it can be used as a filter, otherwise the query returns no results.
def _to_redis_filters(metadata_filters: MetadataFilters) -> str:
    """Translate ``metadata_filters`` into a RediSearch filter expression.

    Each filter becomes an exact TAG match ``@key:{value}``, with the value
    escaped by ``TokenEscaper``. The clauses are joined with ``" & "`` and
    wrapped in parentheses.

    Args:
        metadata_filters (MetadataFilters): Filters to translate. Only the
            filters returned by ``legacy_filters()`` are used.

    Returns:
        str: The filter expression. When there are no filters it returns
        ``"*"`` (match everything), the same default ``query`` uses for
        ``filters=None``, because an empty ``"()"`` is not a valid query.
    """
    tokenizer = TokenEscaper()

    filter_strings = []
    for metadata_filter in metadata_filters.legacy_filters():
        escaped_value = tokenizer.escape(str(metadata_filter.value))
        # braces make this a TAG query, i.e. an exact match on the tag value
        filter_strings.append(f"@{metadata_filter.key}:{{{escaped_value}}}")

    if not filter_strings:
        return "*"

    joined_filter_strings = " & ".join(filter_strings)
    return f"({joined_filter_strings})"
471

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.