llama-index

Fork
0
194 lines · 5.6 KB
1
import logging
import math
from collections import defaultdict
from typing import Any, List

from llama_index.legacy.constants import DEFAULT_EMBEDDING_DIM
from llama_index.legacy.schema import BaseNode, TextNode
from llama_index.legacy.vector_stores.types import (
    MetadataFilters,
    VectorStore,
    VectorStoreQuery,
    VectorStoreQueryResult,
)
from llama_index.legacy.vector_stores.utils import (
    legacy_metadata_dict_to_node,
    metadata_dict_to_node,
    node_to_metadata_dict,
)

# Module-level logger, standard `logging.getLogger(__name__)` pattern.
logger = logging.getLogger(__name__)
21

22

23
class SupabaseVectorStore(VectorStore):
    """Supabase Vector.

    In this vector store, embeddings are stored in Postgres table using pgvector.

    During query time, the index uses pgvector/Supabase to query for the top
    k most similar nodes.

    Args:
        postgres_connection_string (str):
            postgres connection string

        collection_name (str):
            name of the collection to store the embeddings in

        dimension (int):
            dimension of the embeddings; only used when the collection does
            not exist yet and has to be created

    """

    stores_text = True
    # Metadata is passed through to `node_to_metadata_dict` un-flattened;
    # the backing store accepts nested metadata dicts.
    flat_metadata = False

    def __init__(
        self,
        postgres_connection_string: str,
        collection_name: str,
        dimension: int = DEFAULT_EMBEDDING_DIM,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Raises:
            ImportError: if the optional `vecs` dependency is not installed.
        """
        import_err_msg = "`vecs` package not found, please run `pip install vecs`"
        try:
            import vecs
            from vecs.collection import CollectionNotFound
        except ImportError:
            raise ImportError(import_err_msg)

        client = vecs.create_client(postgres_connection_string)

        try:
            self._collection = client.get_collection(name=collection_name)
        except CollectionNotFound:
            # Lazily create the collection with the requested dimension.
            logger.info(
                f"Collection {collection_name} does not exist, "
                f"try creating one with dimension={dimension}"
            )
            self._collection = client.create_collection(
                name=collection_name, dimension=dimension
            )

    @property
    def client(self) -> None:
        """Get client.

        NOTE: the underlying vecs client is not retained on the instance,
        so this property always returns None.
        """
        return

    def _to_vecs_filters(self, filters: MetadataFilters) -> Any:
        """Convert llama filters to vecs filters. $eq is the only supported operator."""
        vecs_filter: Any = defaultdict(list)
        # vecs expects the boolean condition as a "$"-prefixed key,
        # e.g. "$and" / "$or".
        filter_cond = f"${filters.condition}"

        for f in filters.legacy_filters():
            vecs_filter[filter_cond].append({f.key: {"$eq": f.value}})
        return vecs_filter

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            List[str]: node ids of the upserted records.

        Raises:
            ValueError: if the collection was never initialized.
        """
        if self._collection is None:
            raise ValueError("Collection not initialized")

        data = []
        ids = []

        for node in nodes:
            # NOTE: keep text in metadata dict since there's no special field in
            #       Supabase Vector.
            metadata_dict = node_to_metadata_dict(
                node, remove_text=False, flat_metadata=self.flat_metadata
            )

            data.append((node.node_id, node.get_embedding(), metadata_dict))
            ids.append(node.node_id)

        self._collection.upsert(records=data)

        return ids

    def get_by_id(self, doc_id: str, **kwargs: Any) -> list:
        """Get row ids by doc id.

        Args:
            doc_id (str): document id

        Returns:
            list: row ids only (values and metadata are excluded).
        """
        filters = {"doc_id": {"$eq": doc_id}}

        # NOTE: with include_value/include_metadata disabled, the query
        # returns a bare list of row ids.
        return self._collection.query(
            data=None,
            filters=filters,
            include_value=False,
            include_metadata=False,
            **kwargs,
        )

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete doc.

        Args:
            ref_doc_id (str): document id

        """
        row_ids = self.get_by_id(ref_doc_id)

        if len(row_ids) > 0:
            self._collection.delete(row_ids)

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): holds the query embedding, the
                similarity_top_k limit, and optional metadata filters.

        Returns:
            VectorStoreQueryResult: matched nodes, similarities, and ids.
        """
        filters = None
        if query.filters is not None:
            filters = self._to_vecs_filters(query.filters)

        results = self._collection.query(
            data=query.query_embedding,
            limit=query.similarity_top_k,
            filters=filters,
            include_value=True,
            include_metadata=True,
        )

        similarities = []
        ids = []
        nodes = []
        # NOTE: each result row is an (id, distance, metadata) tuple — the
        # first element is used as the node id below. (A previous comment
        # incorrectly described the shape as (vector, distance, metadata).)
        for id_, distance, metadata in results:
            # Text was stored alongside the metadata in `add`; pull it out
            # before reconstructing the node.
            text = metadata.pop("text", None)

            try:
                node = metadata_dict_to_node(metadata)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(
                    metadata
                )
                node = TextNode(
                    id_=id_,
                    text=text,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )

            nodes.append(node)
            # NOTE(review): maps distance into [0, 1), but a LARGER distance
            # yields a LARGER value — looks inverted for a similarity score.
            # Kept as-is to preserve existing ranking behavior; confirm
            # against the distance measure used by the collection.
            similarities.append(1.0 - math.exp(-distance))
            ids.append(id_)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
195

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.