llama-index
321 lines · 10.9 KB
1"""MyScale vector store.
2
3An index that is built on top of an existing MyScale cluster.
4
5"""
6
7import json8import logging9from typing import Any, Dict, List, Optional, cast10
11from llama_index.legacy.readers.myscale import (12MyScaleSettings,13escape_str,14format_list_to_string,15)
16from llama_index.legacy.schema import (17BaseNode,18MetadataMode,19NodeRelationship,20RelatedNodeInfo,21TextNode,22)
23from llama_index.legacy.service_context import ServiceContext24from llama_index.legacy.utils import iter_batch25from llama_index.legacy.vector_stores.types import (26VectorStore,27VectorStoreQuery,28VectorStoreQueryMode,29VectorStoreQueryResult,30)
31
# Module-level logger named after this module (standard `logging` convention).
logger = logging.getLogger(__name__)
34
class MyScaleVectorStore(VectorStore):
    """MyScale Vector Store.

    In this vector store, embeddings and docs are stored within an existing
    MyScale cluster.

    During query time, the index uses MyScale to query for the top
    k most similar nodes.

    Args:
        myscale_client (httpclient): clickhouse-connect httpclient of
            an existing MyScale cluster.
        table (str, optional): The name of the MyScale table
            where data will be stored. Defaults to "llama_index".
        database (str, optional): The name of the MyScale database
            where data will be stored. Defaults to "default".
        index_type (str, optional): The type of the MyScale vector index.
            Defaults to "MSTG".
        metric (str, optional): The metric type of the MyScale vector index.
            Defaults to "cosine".
        batch_size (int, optional): the size of documents to insert. Defaults to 32.
        index_params (dict, optional): The index parameters for MyScale.
            Defaults to None.
        search_params (dict, optional): The search parameters for a MyScale query.
            Defaults to None.
        service_context (ServiceContext, optional): Vector store service context.
            Defaults to None

    """

    stores_text: bool = True
    _index_existed: bool = False
    metadata_column: str = "metadata"
    # Over-fetch factors for hybrid search: the vector stage retrieves
    # similarity_top_k * ratio candidate rows so the text re-ranking stage has
    # enough rows to choose from. Smaller k gets a larger amplification.
    AMPLIFY_RATIO_LE5 = 100  # k <= 5
    AMPLIFY_RATIO_GT5 = 20  # 5 < k <= 50
    AMPLIFY_RATIO_GT50 = 10  # k > 50

    def __init__(
        self,
        myscale_client: Optional[Any] = None,
        table: str = "llama_index",
        database: str = "default",
        index_type: str = "MSTG",
        metric: str = "cosine",
        batch_size: int = 32,
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        service_context: Optional[ServiceContext] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params.

        Raises:
            ImportError: if ``clickhouse_connect`` is not installed.
            ValueError: if ``myscale_client`` is not provided.
        """
        import_err_msg = """
            `clickhouse_connect` package not found,
            please run `pip install clickhouse-connect`
        """
        try:
            from clickhouse_connect.driver.httpclient import HttpClient
        except ImportError as err:
            # Chain the original error so the root cause stays visible.
            raise ImportError(import_err_msg) from err

        if myscale_client is None:
            raise ValueError("Missing MyScale client!")

        self._client = cast(HttpClient, myscale_client)
        self.config = MyScaleSettings(
            table=table,
            database=database,
            index_type=index_type,
            metric=metric,
            batch_size=batch_size,
            index_params=index_params,
            search_params=search_params,
            **kwargs,
        )

        # Schema: column name -> ClickHouse type plus the function that
        # extracts/serializes that column's value from a BaseNode.
        self.column_config: Dict = {
            "id": {"type": "String", "extract_func": lambda x: x.node_id},
            "doc_id": {"type": "String", "extract_func": lambda x: x.ref_doc_id},
            "text": {
                "type": "String",
                "extract_func": lambda x: escape_str(
                    x.get_content(metadata_mode=MetadataMode.NONE) or ""
                ),
            },
            "vector": {
                "type": "Array(Float32)",
                "extract_func": lambda x: format_list_to_string(x.get_embedding()),
            },
            "node_info": {
                "type": "JSON",
                "extract_func": lambda x: json.dumps(x.node_info),
            },
            "metadata": {
                "type": "JSON",
                "extract_func": lambda x: json.dumps(x.metadata),
            },
        }

        # When a service context is supplied, probe its embed model once to
        # learn the embedding dimension and create the table/index eagerly.
        if service_context is not None:
            service_context = cast(ServiceContext, service_context)
            dimension = len(
                service_context.embed_model.get_query_embedding("try this out")
            )
            self._create_index(dimension)

    @property
    def client(self) -> Any:
        """Get client."""
        return self._client

    def _create_index(self, dimension: int) -> None:
        """Create the backing table and vector index for ``dimension``-wide vectors.

        Idempotent (``CREATE TABLE IF NOT EXISTS``); sets ``_index_existed``.
        """
        index_params = (
            ", " + ",".join([f"'{k}={v}'" for k, v in self.config.index_params.items()])
            if self.config.index_params
            else ""
        )
        schema_ = f"""
            CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
                {",".join([f'{k} {v["type"]}' for k, v in self.column_config.items()])},
                CONSTRAINT vector_length CHECK length(vector) = {dimension},
                VECTOR INDEX {self.config.table}_index vector TYPE
                {self.config.index_type}('metric_type={self.config.metric}'{index_params})
            ) ENGINE = MergeTree ORDER BY id
        """
        self.dim = dimension
        # JSON columns require the experimental object type in ClickHouse.
        self._client.command("SET allow_experimental_object_type=1")
        self._client.command(schema_)
        self._index_existed = True

    def _build_insert_statement(
        self,
        values: List[BaseNode],
    ) -> str:
        """Render an INSERT statement for ``values`` using ``column_config``.

        NOTE(review): values are single-quoted f-strings; only ``text`` goes
        through ``escape_str`` — assumes the other columns never contain quotes.
        """
        _data = []
        for item in values:
            item_value_str = ",".join(
                [
                    f"'{column['extract_func'](item)}'"
                    for column in self.column_config.values()
                ]
            )
            _data.append(f"({item_value_str})")

        return f"""
                INSERT INTO TABLE
                    {self.config.database}.{self.config.table}({",".join(self.column_config.keys())})
                VALUES
                    {','.join(_data)}
                """

    def _build_hybrid_search_statement(
        self, stage_one_sql: str, query_str: str, similarity_top_k: int
    ) -> str:
        """Wrap a vector-search SQL stage with a text-relevance re-ranking stage.

        Candidates from ``stage_one_sql`` are re-ordered by (1) how many query
        terms match the text and (2) the log of total match counts, then cut
        to ``similarity_top_k``.
        """
        terms_pattern = [f"(?i){x}" for x in query_str.split(" ")]
        column_keys = self.column_config.keys()
        return (
            f"SELECT {','.join(filter(lambda k: k != 'vector', column_keys))}, "
            f"dist FROM ({stage_one_sql}) tempt "
            f"ORDER BY length(multiMatchAllIndices(text, {terms_pattern})) "
            f"AS distance1 DESC, "
            f"log(1 + countMatches(text, '(?i)({query_str.replace(' ', '|')})')) "
            f"AS distance2 DESC limit {similarity_top_k}"
        )

    def _append_meta_filter_condition(
        self, where_str: Optional[str], exact_match_filter: list
    ) -> str:
        """AND exact-match metadata filters onto an optional WHERE condition.

        Returns the combined condition string.
        """
        filter_str = " AND ".join(
            f"JSONExtractString(toJSONString("
            f"{self.metadata_column}), '{filter_item.key}') "
            f"= '{filter_item.value}'"
            for filter_item in exact_match_filter
        )
        if where_str is None:
            where_str = filter_str
        else:
            # BUGFIX: previously the pre-existing condition (e.g. the doc_id
            # restriction built by `query`) was dropped and replaced by a
            # dangling " AND <filters>"; combine both conditions instead.
            where_str = f"{where_str} AND {filter_str}"
        return where_str

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            List[str]: node ids of the inserted nodes.
        """
        if not nodes:
            return []

        # Lazily create the table/index, sized from the first embedding.
        if not self._index_existed:
            self._create_index(len(nodes[0].get_embedding()))

        for result_batch in iter_batch(nodes, self.config.batch_size):
            insert_statement = self._build_insert_statement(values=result_batch)
            self._client.command(insert_statement)

        return [result.node_id for result in nodes]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        # NOTE(review): ref_doc_id is interpolated unescaped into SQL; assumes
        # trusted doc ids — consider escaping if ids can contain quotes.
        self._client.command(
            f"DELETE FROM {self.config.database}.{self.config.table} "
            f"where doc_id='{ref_doc_id}'"
        )

    def drop(self) -> None:
        """Drop MyScale Index and table."""
        self._client.command(
            f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}"
        )

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query

        Returns:
            VectorStoreQueryResult: matched nodes, similarities and ids.
        """
        query_embedding = cast(List[float], query.query_embedding)
        where_str = (
            f"doc_id in {format_list_to_string(query.doc_ids)}"
            if query.doc_ids
            else None
        )
        if query.filters is not None and len(query.filters.legacy_filters()) > 0:
            where_str = self._append_meta_filter_condition(
                where_str, query.filters.legacy_filters()
            )

        # build query sql
        query_statement = self.config.build_query_statement(
            query_embed=query_embedding,
            where_str=where_str,
            limit=query.similarity_top_k,
        )
        if query.mode == VectorStoreQueryMode.HYBRID and query.query_str is not None:
            # Pick the over-fetch factor for the vector stage.
            # BUGFIX: the original tiers (`5 < k < 50`, `k > 50`) left k == 50
            # at the k<=5 ratio; use an inclusive upper bound instead.
            amplify_ratio = self.AMPLIFY_RATIO_LE5
            if 5 < query.similarity_top_k <= 50:
                amplify_ratio = self.AMPLIFY_RATIO_GT5
            elif query.similarity_top_k > 50:
                amplify_ratio = self.AMPLIFY_RATIO_GT50
            query_statement = self._build_hybrid_search_statement(
                self.config.build_query_statement(
                    query_embed=query_embedding,
                    where_str=where_str,
                    limit=query.similarity_top_k * amplify_ratio,
                ),
                query.query_str,
                query.similarity_top_k,
            )
            logger.debug(f"hybrid query_statement={query_statement}")

        nodes = []
        ids = []
        similarities = []
        for r in self._client.query(query_statement).named_results():
            start_char_idx = None
            end_char_idx = None

            if isinstance(r["node_info"], dict):
                start_char_idx = r["node_info"].get("start", None)
                end_char_idx = r["node_info"].get("end", None)
            node = TextNode(
                id_=r["id"],
                text=r["text"],
                metadata=r["metadata"],
                start_char_idx=start_char_idx,
                end_char_idx=end_char_idx,
                relationships={
                    NodeRelationship.SOURCE: RelatedNodeInfo(node_id=r["id"])
                },
            )

            nodes.append(node)
            similarities.append(r["dist"])
            ids.append(r["id"])
        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)