llama-index

Форк
0
257 строк · 9.0 Кб
1
import json
import logging
from typing import Any, List, Optional, Sequence

from sqlalchemy.pool import QueuePool

from llama_index.legacy.schema import BaseNode, MetadataMode

# NOTE: BaseNode was previously imported a second time from
# vector_stores.types, shadowing the schema import; one import suffices.
from llama_index.legacy.vector_stores.types import (
    VectorStore,
    VectorStoreQuery,
    VectorStoreQueryResult,
)
from llama_index.legacy.vector_stores.utils import (
    metadata_dict_to_node,
    node_to_metadata_dict,
)

logger = logging.getLogger(__name__)
20

21

22
class SingleStoreVectorStore(VectorStore):
    """SingleStore vector store.

    This vector store stores embeddings within a SingleStore database table.

    During query time, the index uses SingleStore to query for the top
    k most similar nodes.

    Args:
        table_name (str, optional): Specifies the name of the table in use.
            Defaults to "embeddings".
        content_field (str, optional): Specifies the field to store the content.
            Defaults to "content".
        metadata_field (str, optional): Specifies the field to store metadata.
            Defaults to "metadata".
        vector_field (str, optional): Specifies the field to store the vector.
            Defaults to "vector".

        Following arguments pertain to the connection pool:

        pool_size (int, optional): Determines the number of active connections in
            the pool. Defaults to 5.
        max_overflow (int, optional): Determines the maximum number of connections
            allowed beyond the pool_size. Defaults to 10.
        timeout (float, optional): Specifies the maximum wait time in seconds for
            establishing a connection. Defaults to 30.

        Following arguments pertain to the connection:

        host (str, optional): Specifies the hostname, IP address, or URL for the
            database connection. The default scheme is "mysql".
        user (str, optional): Database username.
        password (str, optional): Database password.
        port (int, optional): Database port. Defaults to 3306 for non-HTTP
            connections, 80 for HTTP connections, and 443 for HTTPS connections.
        database (str, optional): Database name.

    """

    stores_text: bool = True
    flat_metadata: bool = True

    def __init__(
        self,
        table_name: str = "embeddings",
        content_field: str = "content",
        metadata_field: str = "metadata",
        vector_field: str = "vector",
        pool_size: int = 5,
        max_overflow: int = 10,
        timeout: float = 30,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        self.table_name = table_name
        self.content_field = content_field
        self.metadata_field = metadata_field
        self.vector_field = vector_field
        self.pool_size = pool_size
        self.max_overflow = max_overflow
        self.timeout = timeout

        # All remaining kwargs are forwarded verbatim to singlestoredb.connect().
        self.connection_kwargs = kwargs
        self.connection_pool = QueuePool(
            self._get_connection,
            pool_size=self.pool_size,
            max_overflow=self.max_overflow,
            timeout=self.timeout,
        )

        self._create_table()

    @property
    def client(self) -> Any:
        """Return a SingleStoreDB client.

        NOTE: this opens a brand-new connection on every access (it bypasses
        the pool); the caller is responsible for closing it.
        """
        return self._get_connection()

    @classmethod
    def class_name(cls) -> str:
        """Return the class name identifier."""
        return "SingleStoreVectorStore"

    def _get_connection(self) -> Any:
        """Open a new database connection from the stored connection kwargs.

        Raises:
            ImportError: If the ``singlestoredb`` package is not installed.
        """
        try:
            import singlestoredb as s2
        except ImportError:
            raise ImportError(
                "Could not import singlestoredb python package. "
                "Please install it with `pip install singlestoredb`."
            )
        return s2.connect(**self.connection_kwargs)

    def _create_table(self) -> None:
        """Create the embeddings table if it does not already exist."""
        conn = self.connection_pool.connect()
        try:
            cur = conn.cursor()
            try:
                cur.execute(
                    f"""CREATE TABLE IF NOT EXISTS {self.table_name}
                    ({self.content_field} TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci,
                    {self.vector_field} BLOB, {self.metadata_field} JSON);"""
                )
            finally:
                cur.close()
        finally:
            conn.close()

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        Args:
            nodes (List[BaseNode]): List of nodes with embeddings.

        Returns:
            List[str]: The node ids of the inserted nodes.
        """
        conn = self.connection_pool.connect()
        cursor = conn.cursor()
        try:
            for node in nodes:
                embedding = node.get_embedding()
                metadata = node_to_metadata_dict(
                    node, remove_text=True, flat_metadata=self.flat_metadata
                )
                # Name the target columns explicitly so the insert does not
                # depend on the physical column order of a pre-existing table.
                cursor.execute(
                    "INSERT INTO {} ({}, {}, {}) "
                    "VALUES (%s, JSON_ARRAY_PACK(%s), %s)".format(
                        self.table_name,
                        self.content_field,
                        self.vector_field,
                        self.metadata_field,
                    ),
                    (
                        node.get_content(metadata_mode=MetadataMode.NONE) or "",
                        "[{}]".format(",".join(map(str, embedding))),
                        json.dumps(metadata),
                    ),
                )
        finally:
            cursor.close()
            conn.close()
        return [node.node_id for node in nodes]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        conn = self.connection_pool.connect()
        cursor = conn.cursor()
        try:
            # Use the configured metadata column (the original hard-coded
            # "metadata", which broke custom metadata_field names). The stored
            # JSON value is a quoted string, hence the surrounding quotes on
            # the parameter.
            cursor.execute(
                f"DELETE FROM {self.table_name} WHERE "
                f"JSON_EXTRACT_JSON({self.metadata_field}, 'ref_doc_id') = %s",
                ('"' + ref_doc_id + '"',),
            )
        finally:
            cursor.close()
            conn.close()

    def query(
        self, query: VectorStoreQuery, filter: Optional[dict] = None, **kwargs: Any
    ) -> VectorStoreQueryResult:
        """
        Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): Contains query_embedding and
                similarity_top_k attributes.
            filter (Optional[dict]): A dictionary of metadata fields and values
                to filter by; nested dicts address nested JSON keys.
                Defaults to None.

        Returns:
            VectorStoreQueryResult: Contains nodes, similarities, and ids
                attributes.
        """
        query_embedding = query.query_embedding
        similarity_top_k = query.similarity_top_k
        where_clause: str = ""
        where_clause_values: List[Any] = []

        if filter:
            arguments: List[str] = []

            def build_where_clause(
                where_clause_values: List[Any],
                sub_filter: dict,
                prefix_args: Optional[List[str]] = None,
            ) -> None:
                # Recursively flatten nested filter dicts into
                # JSON_EXTRACT(<metadata>, 'a', 'b', ...) = <json> terms.
                prefix_args = prefix_args or []
                for key, value in sub_filter.items():
                    if isinstance(value, dict):
                        build_where_clause(
                            where_clause_values, value, [*prefix_args, key]
                        )
                    else:
                        placeholders = ", ".join(["%s"] * (len(prefix_args) + 1))
                        # BUG FIX: the original passed {self.metadata_field} —
                        # a set literal — into str.format, generating invalid
                        # SQL such as JSON_EXTRACT({'metadata'}, ...). Pass the
                        # bare field name instead.
                        arguments.append(
                            "JSON_EXTRACT({}, {}) = %s".format(
                                self.metadata_field, placeholders
                            )
                        )
                        where_clause_values += [*prefix_args, key]
                        where_clause_values.append(json.dumps(value))

            build_where_clause(where_clause_values, filter)
            where_clause = "WHERE " + " AND ".join(arguments)

        results: Sequence[Any] = []
        if query_embedding:
            formatted_vector = "[{}]".format(",".join(map(str, query_embedding)))
            logger.debug("vector field: %s", formatted_vector)
            logger.debug("similarity_top_k: %s", similarity_top_k)
            # Check out a connection only when a query is actually issued; the
            # original acquired one unconditionally at the top of this method
            # and leaked it whenever query_embedding was empty/None.
            conn = self.connection_pool.connect()
            try:
                cur = conn.cursor()
                try:
                    cur.execute(
                        f"SELECT {self.content_field}, {self.metadata_field}, "
                        f"DOT_PRODUCT({self.vector_field}, "
                        "JSON_ARRAY_PACK(%s)) as similarity_score "
                        f"FROM {self.table_name} {where_clause} "
                        f"ORDER BY similarity_score DESC LIMIT {similarity_top_k}",
                        (formatted_vector, *where_clause_values),
                    )
                    results = cur.fetchall()
                finally:
                    cur.close()
            finally:
                conn.close()

        nodes = []
        similarities = []
        ids = []
        for text, metadata, similarity_score in results:
            node = metadata_dict_to_node(metadata)
            node.set_content(text)
            nodes.append(node)
            similarities.append(similarity_score)
            ids.append(node.node_id)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
258

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.