llama-index
257 строк · 9.0 Кб
1import json2import logging3from typing import Any, List, Optional, Sequence4
5from sqlalchemy.pool import QueuePool6
7from llama_index.legacy.schema import BaseNode, MetadataMode8from llama_index.legacy.vector_stores.types import (9BaseNode,10VectorStore,11VectorStoreQuery,12VectorStoreQueryResult,13)
14from llama_index.legacy.vector_stores.utils import (15metadata_dict_to_node,16node_to_metadata_dict,17)
18
19logger = logging.getLogger(__name__)20
21
class SingleStoreVectorStore(VectorStore):
    """SingleStore vector store.

    This vector store stores embeddings within a SingleStore database table.

    During query time, the index uses SingleStore to query for the top
    k most similar nodes.

    Args:
        table_name (str, optional): Specifies the name of the table in use.
            Defaults to "embeddings".
        content_field (str, optional): Specifies the field to store the content.
            Defaults to "content".
        metadata_field (str, optional): Specifies the field to store metadata.
            Defaults to "metadata".
        vector_field (str, optional): Specifies the field to store the vector.
            Defaults to "vector".

        Following arguments pertain to the connection pool:

        pool_size (int, optional): Determines the number of active connections in
            the pool. Defaults to 5.
        max_overflow (int, optional): Determines the maximum number of connections
            allowed beyond the pool_size. Defaults to 10.
        timeout (float, optional): Specifies the maximum wait time in seconds for
            establishing a connection. Defaults to 30.

        Following arguments pertain to the connection:

        host (str, optional): Specifies the hostname, IP address, or URL for the
            database connection. The default scheme is "mysql".
        user (str, optional): Database username.
        password (str, optional): Database password.
        port (int, optional): Database port. Defaults to 3306 for non-HTTP
            connections, 80 for HTTP connections, and 443 for HTTPS connections.
        database (str, optional): Database name.

    """

    stores_text: bool = True
    flat_metadata: bool = True

    def __init__(
        self,
        table_name: str = "embeddings",
        content_field: str = "content",
        metadata_field: str = "metadata",
        vector_field: str = "vector",
        pool_size: int = 5,
        max_overflow: int = 10,
        timeout: float = 30,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        self.table_name = table_name
        self.content_field = content_field
        self.metadata_field = metadata_field
        self.vector_field = vector_field
        self.pool_size = pool_size
        self.max_overflow = max_overflow
        self.timeout = timeout

        # Remaining kwargs (host, user, password, port, database, ...) are
        # forwarded verbatim to singlestoredb.connect().
        self.connection_kwargs = kwargs
        self.connection_pool = QueuePool(
            self._get_connection,
            pool_size=self.pool_size,
            max_overflow=self.max_overflow,
            timeout=self.timeout,
        )

        self._create_table()

    @property
    def client(self) -> Any:
        """Return SingleStoreDB client."""
        return self._get_connection()

    @classmethod
    def class_name(cls) -> str:
        return "SingleStoreVectorStore"

    def _get_connection(self) -> Any:
        """Open a new raw SingleStoreDB connection (pool factory)."""
        try:
            import singlestoredb as s2
        except ImportError:
            raise ImportError(
                "Could not import singlestoredb python package. "
                "Please install it with `pip install singlestoredb`."
            )
        return s2.connect(**self.connection_kwargs)

    def _create_table(self) -> None:
        """Create the backing table if it does not already exist.

        NOTE: ``table_name`` / field names are interpolated into the DDL and
        must come from trusted configuration, not end-user input.
        """
        conn = self.connection_pool.connect()
        try:
            cur = conn.cursor()
            try:
                cur.execute(
                    f"""CREATE TABLE IF NOT EXISTS {self.table_name}
                    ({self.content_field} TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci,
                    {self.vector_field} BLOB, {self.metadata_field} JSON);"""
                )
            finally:
                cur.close()
        finally:
            conn.close()

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            List[str]: node ids of the inserted nodes.
        """
        conn = self.connection_pool.connect()
        cursor = conn.cursor()
        try:
            for node in nodes:
                embedding = node.get_embedding()
                metadata = node_to_metadata_dict(
                    node, remove_text=True, flat_metadata=self.flat_metadata
                )
                cursor.execute(
                    "INSERT INTO {} VALUES (%s, JSON_ARRAY_PACK(%s), %s)".format(
                        self.table_name
                    ),
                    (
                        node.get_content(metadata_mode=MetadataMode.NONE) or "",
                        # JSON_ARRAY_PACK expects a JSON-array string, e.g. "[0.1,0.2]".
                        "[{}]".format(",".join(map(str, embedding))),
                        json.dumps(metadata),
                    ),
                )
        finally:
            cursor.close()
            conn.close()
        return [node.node_id for node in nodes]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes with the given ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        conn = self.connection_pool.connect()
        cursor = conn.cursor()
        try:
            cursor.execute(
                # BUG FIX: use the configured metadata column rather than a
                # hard-coded "metadata", so deletes work with a custom
                # metadata_field (consistent with _create_table/query).
                f"DELETE FROM {self.table_name} WHERE JSON_EXTRACT_JSON({self.metadata_field}, 'ref_doc_id') = %s",
                # JSON_EXTRACT_JSON returns a JSON value, so compare against
                # the JSON-encoded (quoted) string.
                ('"' + ref_doc_id + '"',),
            )
        finally:
            cursor.close()
            conn.close()

    def query(
        self, query: VectorStoreQuery, filter: Optional[dict] = None, **kwargs: Any
    ) -> VectorStoreQueryResult:
        """
        Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): Contains query_embedding and similarity_top_k attributes.
            filter (Optional[dict]): A dictionary of metadata fields and values to filter by. Defaults to None.

        Returns:
            VectorStoreQueryResult: Contains nodes, similarities, and ids attributes.
        """
        query_embedding = query.query_embedding
        similarity_top_k = query.similarity_top_k
        conn = self.connection_pool.connect()
        where_clause: str = ""
        where_clause_values: List[Any] = []

        if filter:
            where_clause = "WHERE "
            arguments = []

            def build_where_clause(
                where_clause_values: List[Any],
                sub_filter: dict,
                prefix_args: Optional[List[str]] = None,
            ) -> None:
                # Recursively flatten nested filter dicts into
                # JSON_EXTRACT(metadata, 'a', 'b', ...) = <json value> clauses.
                prefix_args = prefix_args or []
                for key in sub_filter:
                    if isinstance(sub_filter[key], dict):
                        build_where_clause(
                            where_clause_values, sub_filter[key], [*prefix_args, key]
                        )
                    else:
                        arguments.append(
                            "JSON_EXTRACT({}, {}) = %s".format(
                                # BUG FIX: was "{self.metadata_field}" — a set
                                # literal — which rendered as {'metadata'} and
                                # produced invalid SQL.
                                self.metadata_field,
                                ", ".join(["%s"] * (len(prefix_args) + 1)),
                            )
                        )
                        where_clause_values += [*prefix_args, key]
                        where_clause_values.append(json.dumps(sub_filter[key]))

            build_where_clause(where_clause_values, filter)
            where_clause += " AND ".join(arguments)

        results: Sequence[Any] = []
        # BUG FIX: conn.close() previously ran only when query_embedding was
        # truthy, leaking the pooled connection otherwise. The finally now
        # covers every path.
        try:
            if query_embedding:
                cur = conn.cursor()
                formatted_vector = "[{}]".format(",".join(map(str, query_embedding)))
                try:
                    logger.debug("vector field: %s", formatted_vector)
                    logger.debug("similarity_top_k: %s", similarity_top_k)
                    cur.execute(
                        f"SELECT {self.content_field}, {self.metadata_field}, "
                        f"DOT_PRODUCT({self.vector_field}, "
                        "JSON_ARRAY_PACK(%s)) as similarity_score "
                        f"FROM {self.table_name} {where_clause} "
                        f"ORDER BY similarity_score DESC LIMIT {similarity_top_k}",
                        (formatted_vector, *tuple(where_clause_values)),
                    )
                    results = cur.fetchall()
                finally:
                    cur.close()
        finally:
            conn.close()

        nodes = []
        similarities = []
        ids = []
        for result in results:
            text, metadata, similarity_score = result
            node = metadata_dict_to_node(metadata)
            node.set_content(text)
            nodes.append(node)
            similarities.append(similarity_score)
            ids.append(node.node_id)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)