llama-index

Форк
0
183 строки · 5.2 Кб
1
import logging
2
import math
3
from typing import Any, List
4

5
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
6
from llama_index.legacy.vector_stores.types import (
7
    MetadataFilters,
8
    VectorStore,
9
    VectorStoreQuery,
10
    VectorStoreQueryResult,
11
)
12
from llama_index.legacy.vector_stores.utils import (
13
    legacy_metadata_dict_to_node,
14
    metadata_dict_to_node,
15
    node_to_metadata_dict,
16
)
17

18
logger = logging.getLogger(__name__)
19

20

21
def _to_bagel_filter(standard_filters: MetadataFilters) -> dict:
    """
    Convert llama-index metadata filters into a Bagel ``where`` mapping.

    Every filter yielded by ``standard_filters.legacy_filters()`` contributes a
    single ``key -> value`` entry. When a key repeats, the later filter wins.

    Args:
        standard_filters: Filters taken from a ``VectorStoreQuery``.

    Returns:
        A flat dict mapping metadata key to the value it must equal.
    """
    return {
        metadata_filter.key: metadata_filter.value
        for metadata_filter in standard_filters.legacy_filters()
    }
29

30

31
class BagelVectorStore(VectorStore):
    """
    Vector store backed by a Bagel ``Cluster``.

    Nodes are written with ``Cluster.add`` (embedding, serialized metadata and
    plain node text). They are removed via ``Cluster.get`` + ``Cluster.delete``,
    keyed on the ``doc_id`` metadata field, and retrieved with ``Cluster.find``.
    """

    # Node text is persisted as the record's "documents" entry (see ``add``),
    # which lets ``query`` restore each node's content from the results.
    stores_text: bool = True
    # Forwarded as ``flat_metadata`` to ``node_to_metadata_dict`` in ``add``.
    flat_metadata: bool = True

    def __init__(self, collection: Any, **kwargs: Any) -> None:
        """
        Initialize BagelVectorStore.

        Args:
            collection: Bagel collection; must be a ``bagel.api.Cluster.Cluster``.
            **kwargs: Additional arguments (currently unused).

        Raises:
            ImportError: If the ``bagel`` package cannot be imported.
            ValueError: If ``collection`` is not a Bagel ``Cluster``.
        """
        try:
            from bagel.api.Cluster import Cluster
        except ImportError as exc:
            raise ImportError("Bagel is not installed. Please install bagel.") from exc

        if not isinstance(collection, Cluster):
            raise ValueError("Collection must be a bagel Cluster.")

        self._collection = collection

    def _require_collection(self) -> Any:
        """Return the wrapped cluster, raising ``ValueError`` if it is unset."""
        if self._collection is None:
            raise ValueError("collection not set")
        return self._collection

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """
        Add a list of nodes with embeddings to the vector store.

        Args:
            nodes: List of nodes with embeddings.
            add_kwargs: Additional arguments (currently unused).

        Returns:
            List of node ids, in the same order as ``nodes``.

        Raises:
            ValueError: If the collection is not set.
        """
        collection = self._require_collection()

        # Nothing to write: skip the backend round-trip for an empty batch.
        if not nodes:
            return []

        ids: List[str] = []
        embeddings: List[Any] = []
        metadatas: List[dict] = []
        documents: List[str] = []

        for node in nodes:
            ids.append(node.node_id)
            embeddings.append(node.get_embedding())
            # Text is stripped from the metadata dict because it is stored
            # separately as the document body below.
            metadatas.append(
                node_to_metadata_dict(
                    node,
                    remove_text=True,
                    flat_metadata=self.flat_metadata,
                )
            )
            documents.append(node.get_content(metadata_mode=MetadataMode.NONE) or "")

        collection.add(
            ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents
        )

        return ids

    def delete(self, ref_doc_id: str, **kwargs: Any) -> None:
        """
        Delete every record whose ``doc_id`` metadata equals ``ref_doc_id``.

        Args:
            ref_doc_id: Reference document id.
            kwargs: Additional arguments (currently unused).

        Raises:
            ValueError: If the collection is not set.
        """
        collection = self._require_collection()

        results = collection.get(where={"doc_id": ref_doc_id})
        matching_ids = results.get("ids") if results else None
        # Only delete when there is at least one match. An empty id list is at
        # best a wasted call (NOTE(review): confirm Bagel's handling of ids=[]).
        if matching_ids:
            collection.delete(ids=matching_ids)

    @property
    def client(self) -> Any:
        """
        Get the Bagel cluster.
        """
        return self._collection

    @staticmethod
    def _result_to_node(node_id: str, text: str, metadata: dict) -> BaseNode:
        """Rebuild a node from one stored record (id, document text, metadata)."""
        try:
            node = metadata_dict_to_node(metadata)
            node.set_content(text)
        except Exception:
            # NOTE: deprecated legacy logic for backward compatibility with
            # records written in the older metadata layout.
            legacy_metadata, node_info, relationships = legacy_metadata_dict_to_node(
                metadata
            )
            node = TextNode(
                text=text,
                id_=node_id,
                metadata=legacy_metadata,
                start_char_idx=node_info.get("start", None),
                end_char_idx=node_info.get("end", None),
                relationships=relationships,
            )
        return node

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """
        Query the vector store.

        Args:
            query: Query to run. ``query_embedding``, ``similarity_top_k`` and
                ``filters`` are used.
            kwargs: Additional arguments forwarded to ``Cluster.find``. A
                ``where`` entry is used as the raw Bagel filter when
                ``query.filters`` is not set.

        Returns:
            Query result holding nodes, similarity scores and node ids.
            Scores are ``exp(-distance)``, so closer matches score higher
            (1.0 at distance 0, assuming non-negative distances).

        Raises:
            ValueError: If the collection is not set, or if both
                ``query.filters`` and a ``where`` kwarg are supplied.
        """
        collection = self._require_collection()

        if query.filters is not None:
            if "where" in kwargs:
                raise ValueError("Cannot specify both filters and where")
            where = _to_bagel_filter(query.filters)
        else:
            # Pop (not get): ``where`` is passed explicitly below, so leaving it
            # in ``kwargs`` would forward it twice and raise TypeError.
            where = kwargs.pop("where", {})

        results = collection.find(
            query_embeddings=query.query_embedding,
            where=where,
            n_results=query.similarity_top_k,
            **kwargs,
        )

        logger.debug("query results: %s", results)

        nodes: List[BaseNode] = []
        similarities: List[float] = []
        ids: List[str] = []

        # Each result field is a list of per-query lists; only the first
        # (presumably the single query embedding sent above) is read.
        for node_id, text, metadata, distance in zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            node = self._result_to_node(node_id, text, metadata)
            # exp(-d) decreases as distance grows, so nearer nodes rank higher.
            # ``1 - exp(-d)`` would invert the ordering.
            similarity = math.exp(-distance)

            nodes.append(node)
            similarities.append(similarity)
            ids.append(node_id)

            logger.debug("node: %s", node)
            logger.debug("similarity: %s", similarity)
            logger.debug("id: %s", node_id)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
184

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.