llama-index

Форк
0
321 строка · 10.9 Кб
1
"""MyScale vector store.
2

3
An index that is built on top of an existing MyScale cluster.
4

5
"""
6

7
import json
8
import logging
9
from typing import Any, Dict, List, Optional, cast
10

11
from llama_index.legacy.readers.myscale import (
12
    MyScaleSettings,
13
    escape_str,
14
    format_list_to_string,
15
)
16
from llama_index.legacy.schema import (
17
    BaseNode,
18
    MetadataMode,
19
    NodeRelationship,
20
    RelatedNodeInfo,
21
    TextNode,
22
)
23
from llama_index.legacy.service_context import ServiceContext
24
from llama_index.legacy.utils import iter_batch
25
from llama_index.legacy.vector_stores.types import (
26
    VectorStore,
27
    VectorStoreQuery,
28
    VectorStoreQueryMode,
29
    VectorStoreQueryResult,
30
)
31

32
logger = logging.getLogger(__name__)
33

34

35
class MyScaleVectorStore(VectorStore):
    """MyScale Vector Store.

    In this vector store, embeddings and docs are stored within an existing
    MyScale cluster.

    During query time, the index uses MyScale to query for the top
    k most similar nodes.

    Args:
        myscale_client (httpclient): clickhouse-connect httpclient of
            an existing MyScale cluster.
        table (str, optional): The name of the MyScale table
            where data will be stored. Defaults to "llama_index".
        database (str, optional): The name of the MyScale database
            where data will be stored. Defaults to "default".
        index_type (str, optional): The type of the MyScale vector index.
            Defaults to "MSTG".
        metric (str, optional): The metric type of the MyScale vector index.
            Defaults to "cosine".
        batch_size (int, optional): the size of documents to insert. Defaults to 32.
        index_params (dict, optional): The index parameters for MyScale.
            Defaults to None.
        search_params (dict, optional): The search parameters for a MyScale query.
            Defaults to None.
        service_context (ServiceContext, optional): Vector store service context.
            Defaults to None

    """

    stores_text: bool = True
    _index_existed: bool = False
    # Name of the JSON column that stores node metadata; referenced when
    # building metadata filter conditions.
    metadata_column: str = "metadata"
    # Over-fetch multipliers for the vector stage of a hybrid query: the
    # smaller the requested top-k, the more candidates the text re-ranking
    # stage needs to choose from.
    AMPLIFY_RATIO_LE5 = 100
    AMPLIFY_RATIO_GT5 = 20
    AMPLIFY_RATIO_GT50 = 10

    def __init__(
        self,
        myscale_client: Optional[Any] = None,
        table: str = "llama_index",
        database: str = "default",
        index_type: str = "MSTG",
        metric: str = "cosine",
        batch_size: int = 32,
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        service_context: Optional[ServiceContext] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        import_err_msg = """
            `clickhouse_connect` package not found,
            please run `pip install clickhouse-connect`
        """
        try:
            from clickhouse_connect.driver.httpclient import HttpClient
        except ImportError:
            raise ImportError(import_err_msg)

        if myscale_client is None:
            raise ValueError("Missing MyScale client!")

        self._client = cast(HttpClient, myscale_client)
        self.config = MyScaleSettings(
            table=table,
            database=database,
            index_type=index_type,
            metric=metric,
            batch_size=batch_size,
            index_params=index_params,
            search_params=search_params,
            **kwargs,
        )

        # Schema column name -> SQL type and the function that extracts the
        # column value from a BaseNode when building INSERT statements.
        self.column_config: Dict = {
            "id": {"type": "String", "extract_func": lambda x: x.node_id},
            "doc_id": {"type": "String", "extract_func": lambda x: x.ref_doc_id},
            "text": {
                "type": "String",
                "extract_func": lambda x: escape_str(
                    x.get_content(metadata_mode=MetadataMode.NONE) or ""
                ),
            },
            "vector": {
                "type": "Array(Float32)",
                "extract_func": lambda x: format_list_to_string(x.get_embedding()),
            },
            "node_info": {
                "type": "JSON",
                "extract_func": lambda x: json.dumps(x.node_info),
            },
            "metadata": {
                "type": "JSON",
                "extract_func": lambda x: json.dumps(x.metadata),
            },
        }

        # If a service context is supplied, probe its embedding model once to
        # discover the vector dimension and create the table eagerly.
        if service_context is not None:
            service_context = cast(ServiceContext, service_context)
            dimension = len(
                service_context.embed_model.get_query_embedding("try this out")
            )
            self._create_index(dimension)

    @property
    def client(self) -> Any:
        """Get client."""
        return self._client

    def _create_index(self, dimension: int) -> None:
        """Create the backing table (with vector index) if it does not exist.

        Args:
            dimension (int): length of the embedding vectors; enforced via a
                CHECK constraint on the `vector` column.
        """
        index_params = (
            ", " + ",".join([f"'{k}={v}'" for k, v in self.config.index_params.items()])
            if self.config.index_params
            else ""
        )
        schema_ = f"""
            CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
                {",".join([f'{k} {v["type"]}' for k, v in self.column_config.items()])},
                CONSTRAINT vector_length CHECK length(vector) = {dimension},
                VECTOR INDEX {self.config.table}_index vector TYPE
                {self.config.index_type}('metric_type={self.config.metric}'{index_params})
            ) ENGINE = MergeTree ORDER BY id
            """
        self.dim = dimension
        # JSON columns require the experimental Object type in ClickHouse.
        self._client.command("SET allow_experimental_object_type=1")
        self._client.command(schema_)
        self._index_existed = True

    def _build_insert_statement(
        self,
        values: List[BaseNode],
    ) -> str:
        """Build a single multi-row INSERT statement for the given nodes."""
        _data = []
        for item in values:
            item_value_str = ",".join(
                [
                    f"'{column['extract_func'](item)}'"
                    for column in self.column_config.values()
                ]
            )
            _data.append(f"({item_value_str})")

        return f"""
                INSERT INTO TABLE
                    {self.config.database}.{self.config.table}({",".join(self.column_config.keys())})
                VALUES
                    {','.join(_data)}
                """

    def _build_hybrid_search_statement(
        self, stage_one_sql: str, query_str: str, similarity_top_k: int
    ) -> str:
        """Wrap a vector-search SQL stage with a text re-ranking stage.

        The outer query re-orders the over-fetched vector candidates by
        (1) how many distinct query terms the text matches, then
        (2) the log-scaled total number of matches, and truncates to the
        requested top-k.
        """
        terms_pattern = [f"(?i){x}" for x in query_str.split(" ")]
        column_keys = self.column_config.keys()
        return (
            f"SELECT {','.join(filter(lambda k: k != 'vector', column_keys))}, "
            f"dist FROM ({stage_one_sql}) tempt "
            f"ORDER BY length(multiMatchAllIndices(text, {terms_pattern})) "
            f"AS distance1 DESC, "
            f"log(1 + countMatches(text, '(?i)({query_str.replace(' ', '|')})')) "
            f"AS distance2 DESC limit {similarity_top_k}"
        )

    def _append_meta_filter_condition(
        self, where_str: Optional[str], exact_match_filter: list
    ) -> str:
        """AND the exact-match metadata filters onto an existing WHERE clause.

        Args:
            where_str: existing condition (e.g. a doc_id restriction) or None.
            exact_match_filter: list of filter items with `.key` / `.value`.

        Returns:
            The combined WHERE condition string.
        """
        filter_str = " AND ".join(
            f"JSONExtractString(toJSONString("
            f"{self.metadata_column}), '{filter_item.key}') "
            f"= '{filter_item.value}'"
            for filter_item in exact_match_filter
        )
        if where_str is None:
            where_str = filter_str
        else:
            # BUGFIX: previously `where_str = " AND " + filter_str`, which
            # discarded the existing condition and emitted a dangling AND.
            where_str = where_str + " AND " + filter_str
        return where_str

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            List of node ids that were inserted.
        """
        if not nodes:
            return []

        # Lazily create the table on first insert, inferring the vector
        # dimension from the first node's embedding.
        if not self._index_existed:
            self._create_index(len(nodes[0].get_embedding()))

        for result_batch in iter_batch(nodes, self.config.batch_size):
            insert_statement = self._build_insert_statement(values=result_batch)
            self._client.command(insert_statement)

        return [result.node_id for result in nodes]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        # NOTE(review): ref_doc_id is interpolated directly into SQL; safe
        # only as long as doc ids are internally generated, not user input.
        self._client.command(
            f"DELETE FROM {self.config.database}.{self.config.table} "
            f"where doc_id='{ref_doc_id}'"
        )

    def drop(self) -> None:
        """Drop MyScale Index and table."""
        self._client.command(
            f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}"
        )

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query

        Returns:
            VectorStoreQueryResult with nodes, similarities, and ids.
        """
        query_embedding = cast(List[float], query.query_embedding)
        where_str = (
            f"doc_id in {format_list_to_string(query.doc_ids)}"
            if query.doc_ids
            else None
        )
        if query.filters is not None and len(query.filters.legacy_filters()) > 0:
            where_str = self._append_meta_filter_condition(
                where_str, query.filters.legacy_filters()
            )

        # build query sql
        query_statement = self.config.build_query_statement(
            query_embed=query_embedding,
            where_str=where_str,
            limit=query.similarity_top_k,
        )
        if query.mode == VectorStoreQueryMode.HYBRID and query.query_str is not None:
            # BUGFIX: the previous `5 < k < 50` / `k > 50` pair left k == 50
            # at the LE5 ratio (100); the elif ladder gives a consistent
            # banding of <=5 / 6..50 / >50.
            if query.similarity_top_k > 50:
                amplify_ratio = self.AMPLIFY_RATIO_GT50
            elif query.similarity_top_k > 5:
                amplify_ratio = self.AMPLIFY_RATIO_GT5
            else:
                amplify_ratio = self.AMPLIFY_RATIO_LE5
            query_statement = self._build_hybrid_search_statement(
                self.config.build_query_statement(
                    query_embed=query_embedding,
                    where_str=where_str,
                    limit=query.similarity_top_k * amplify_ratio,
                ),
                query.query_str,
                query.similarity_top_k,
            )
            logger.debug(f"hybrid query_statement={query_statement}")
        nodes = []
        ids = []
        similarities = []
        for r in self._client.query(query_statement).named_results():
            start_char_idx = None
            end_char_idx = None

            if isinstance(r["node_info"], dict):
                start_char_idx = r["node_info"].get("start", None)
                end_char_idx = r["node_info"].get("end", None)
            node = TextNode(
                id_=r["id"],
                text=r["text"],
                metadata=r["metadata"],
                start_char_idx=start_char_idx,
                end_char_idx=end_char_idx,
                relationships={
                    NodeRelationship.SOURCE: RelatedNodeInfo(node_id=r["id"])
                },
            )

            nodes.append(node)
            similarities.append(r["dist"])
            ids.append(r["id"])
        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
322

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.