# llama-index
#
# Fork: 0
# 547 lines · 19.1 KB
1
"""Tencent Vector store index.

A vector store index backed by Tencent Cloud Vector Database.

"""
6

7
import json
from typing import Any, Dict, List, Optional, Union

from llama_index.legacy.schema import (
    BaseNode,
    NodeRelationship,
    RelatedNodeInfo,
    TextNode,
)
from llama_index.legacy.vector_stores.types import (
    VectorStore,
    VectorStoreQuery,
    VectorStoreQueryResult,
)
from llama_index.legacy.vector_stores.utils import DEFAULT_DOC_ID_KEY, DEFAULT_TEXT_KEY
22

23
# --- Connection / naming defaults -------------------------------------------
DEFAULT_USERNAME = "root"
DEFAULT_DATABASE_NAME = "llama_default_database"
DEFAULT_COLLECTION_NAME = "llama_default_collection"
DEFAULT_COLLECTION_DESC = "Collection for llama index"
DEFAULT_TIMEOUT: int = 30  # passed as `timeout` to the client and to upserts

# --- Collection layout defaults ---------------------------------------------
DEFAULT_SHARD = 1
DEFAULT_REPLICAS = 2
DEFAULT_INDEX_TYPE = "HNSW"
DEFAULT_METRIC_TYPE = "COSINE"

# --- Vector-index tuning (used when `vector_params` omits a key) -------------
DEFAULT_HNSW_M = 16
DEFAULT_HNSW_EF = 200
DEFAULT_IVF_NLIST = 128
DEFAULT_IVF_PQ_M = 16

# --- Document field names ---------------------------------------------------
FIELD_ID: str = "id"
FIELD_VECTOR: str = "vector"
FIELD_METADATA: str = "metadata"

# --- Read consistency: option name and accepted values ----------------------
READ_CONSISTENCY = "read_consistency"
READ_STRONG_CONSISTENCY = "strongConsistency"
READ_EVENTUAL_CONSISTENCY = "eventualConsistency"
# Human-readable list of the accepted values, used in error messages.
READ_CONSISTENCY_VALUES = str([READ_STRONG_CONSISTENCY, READ_EVENTUAL_CONSISTENCY])

# --- Error-message templates (filled via str.format) -------------------------
VALUE_NONE_ERROR = "Parameter `{}` can not be None."
VALUE_RANGE_ERROR = "The value of parameter `{}` must be within {}."
NOT_SUPPORT_INDEX_TYPE_ERROR = (
    "Unsupported index type: `{}`, supported index types are {}"
)
NOT_SUPPORT_METRIC_TYPE_ERROR = (
    "Unsupported metric type: `{}`, supported metric types are {}"
)
56

57

58
def _try_import() -> None:
59
    try:
60
        import tcvectordb  # noqa
61
    except ImportError:
62
        raise ImportError(
63
            "`tcvectordb` package not found, please run `pip install tcvectordb`"
64
        )
65

66

67
class FilterField:
    """A scalar metadata field that Tencent VectorDB indexes for filtering.

    Args:
        name: Metadata key. When a node's ``metadata[name]`` passes
            :meth:`match_value`, the store copies it into a document field of
            the same name so it can be used in filter expressions.
        data_type: Tencent VectorDB field type, e.g. ``"string"`` (default) or
            ``"uint64"``. ``None`` is treated as ``"string"``.
    """

    name: str
    data_type: str = "string"

    def __init__(self, name: str, data_type: str = "string"):
        self.name = name
        self.data_type = "string" if data_type is None else data_type

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name={self.name!r}, data_type={self.data_type!r})"

    def match_value(self, value: Any) -> bool:
        """Return True if ``value`` can be stored in a field of this type.

        ``uint64`` fields accept only ints in ``[0, 2**64)``. ``bool`` is
        rejected even though it subclasses ``int``, because it is not a
        numeric field value. Every other type is treated as a string field
        and accepts only ``str``.
        """
        if self.data_type == "uint64":
            return (
                isinstance(value, int)
                and not isinstance(value, bool)
                and 0 <= value < 2**64
            )
        return isinstance(value, str)

    def to_vdb_filter(self) -> Any:
        """Build the ``tcvectordb`` FILTER index definition for this field."""
        from tcvectordb.model.enum import FieldType, IndexType
        from tcvectordb.model.index import FilterIndex

        return FilterIndex(
            name=self.name,
            field_type=FieldType(self.data_type),
            index_type=IndexType.FILTER,
        )
90

91

92
class CollectionParams:
    r"""Tencent vector DB Collection params.

    See the following documentation for details:
    https://cloud.tencent.com/document/product/1709/95826.

    Args:
        dimension (int): The dimension of vector.
        collection_name (str): Name of the collection.
            Default value is "llama_default_collection".
        collection_description (str): Description of the collection.
            Default value is "Collection for llama index".
        shard (int): The number of shards in the collection.
        replicas (int): The number of replicas in the collection.
        index_type (Optional[str]): HNSW, IVF_FLAT, IVF_PQ, IVF_SQ4, IVF_SQ8,
            IVF_SQ16... Default value is "HNSW"
        metric_type (Optional[str]): L2, COSINE, IP. Default value is "COSINE"
        drop_exists (Optional[bool]): Delete the existing Collection. Default value is False.
        vector_params (Optional[Dict]):
          if HNSW set parameters: `M` and `efConstruction`, for example `{'M': 16, 'efConstruction': 200}`
          if IVF_FLAT, IVF_SQ4, IVF_SQ8 or IVF_SQ16 set parameter: `nlist`
          if IVF_PQ set parameters: `M` and `nlist`
          default is HNSW
        filter_fields (Optional[List[FilterField]]): Set the fields for filtering.
          The list is copied, so later changes to the caller's list have no effect.
          for example: [FilterField(name='author'), FilterField(name='age', data_type='uint64')]
          This can be used when calling the query method:
             store.add([
                TextNode(..., metadata={'age': 23, 'author': 'name1'})
             ])
             ...
             query = VectorStoreQuery(...)
             store.query(query, filter="age > 20 and age < 40 and author in (\"name1\", \"name2\")")
    """

    def __init__(
        self,
        dimension: int,
        collection_name: str = DEFAULT_COLLECTION_NAME,
        collection_description: str = DEFAULT_COLLECTION_DESC,
        shard: int = DEFAULT_SHARD,
        replicas: int = DEFAULT_REPLICAS,
        index_type: str = DEFAULT_INDEX_TYPE,
        metric_type: str = DEFAULT_METRIC_TYPE,
        drop_exists: Optional[bool] = False,
        vector_params: Optional[Dict] = None,
        filter_fields: Optional[List[FilterField]] = None,
    ):
        self.collection_name = collection_name
        self.collection_description = collection_description
        self.dimension = dimension
        self.shard = shard
        self.replicas = replicas
        self.index_type = index_type
        self.metric_type = metric_type
        self.vector_params = vector_params
        self.drop_exists = drop_exists
        # Own a private copy: None-default avoids a shared mutable default, and
        # copying keeps external mutation of the caller's list from leaking in.
        self.filter_fields = list(filter_fields) if filter_fields else []
143

144

145
class TencentVectorDB(VectorStore):
    """Tencent Vector Store.

    In this vector store, embeddings and docs are stored within a Collection.
    If the Collection does not exist, it will be automatically created.

    In order to use this you need to have a database instance.
    See the following documentation for details:
    https://cloud.tencent.com/document/product/1709/94951

    Args:
        url (str): url of Tencent vector database
        key (str): The Api-Key for Tencent vector database
        username (Optional[str]): The username for Tencent vector database. Default value is "root"
        database_name (Optional[str]): Database to use; created if it does not exist.
            Default value is "llama_default_database"
        read_consistency (Optional[str]): "strongConsistency" or "eventualConsistency".
            Default value is "eventualConsistency"
        collection_params (Optional[CollectionParams]): The collection parameters for vector database
        batch_size (int): Number of documents buffered before each upsert in `add`.
            Default value is 512

    """

    stores_text: bool = True
    # Class-level default only. `_init_filter_fields` assigns a fresh list per
    # instance, so filter fields are never shared between instances.
    filter_fields: List[FilterField] = []

    def __init__(
        self,
        url: str,
        key: str,
        username: str = DEFAULT_USERNAME,
        database_name: str = DEFAULT_DATABASE_NAME,
        read_consistency: str = READ_EVENTUAL_CONSISTENCY,
        collection_params: CollectionParams = CollectionParams(dimension=1536),
        batch_size: int = 512,
        **kwargs: Any,
    ):
        """Init params."""
        self._init_client(url, username, key, read_consistency)
        self._create_database_if_not_exists(database_name)
        self._create_collection(database_name, collection_params)
        self._init_filter_fields()
        self.batch_size = batch_size

    def _init_filter_fields(self) -> None:
        """Populate `self.filter_fields` from the collection's scalar indexes.

        The primary key, the ref-doc-id index and the vector index are skipped
        because `add` writes those fields itself.
        """
        # Assign a new list instead of appending to the class attribute, which
        # would leak fields across every TencentVectorDB instance.
        self.filter_fields = []
        reserved = {FIELD_ID, DEFAULT_DOC_ID_KEY, FIELD_VECTOR}
        # assumes `indexes` entries are dicts with "fieldName"/"fieldType" keys
        # (as read below) — TODO confirm against the tcvectordb collection model
        for field in vars(self.collection).get("indexes", []):
            if field["fieldName"] not in reserved:
                self.filter_fields.append(
                    FilterField(name=field["fieldName"], data_type=field["fieldType"])
                )

    @classmethod
    def class_name(cls) -> str:
        return "TencentVectorDB"

    @classmethod
    def from_params(
        cls,
        url: str,
        key: str,
        username: str = DEFAULT_USERNAME,
        database_name: str = DEFAULT_DATABASE_NAME,
        read_consistency: str = READ_EVENTUAL_CONSISTENCY,
        collection_params: CollectionParams = CollectionParams(dimension=1536),
        batch_size: int = 512,
        **kwargs: Any,
    ) -> "TencentVectorDB":
        """Construct the store after checking that `tcvectordb` is installed."""
        _try_import()
        return cls(
            url=url,
            username=username,
            key=key,
            database_name=database_name,
            read_consistency=read_consistency,
            collection_params=collection_params,
            batch_size=batch_size,
            **kwargs,
        )

    def _init_client(
        self, url: str, username: str, key: str, read_consistency: str
    ) -> None:
        """Create the `tcvectordb.VectorDBClient` stored on `self.tencent_client`.

        Raises:
            ValueError: If `read_consistency` is None or is not a value accepted
                by `tcvectordb.model.enum.ReadConsistency`.
        """
        import tcvectordb
        from tcvectordb.model.enum import ReadConsistency

        if read_consistency is None:
            # VALUE_NONE_ERROR takes a single placeholder; VALUE_RANGE_ERROR
            # needs two and would fail with IndexError here.
            raise ValueError(VALUE_NONE_ERROR.format(READ_CONSISTENCY))

        try:
            v_read_consistency = ReadConsistency(read_consistency)
        except ValueError as err:
            raise ValueError(
                VALUE_RANGE_ERROR.format(READ_CONSISTENCY, READ_CONSISTENCY_VALUES)
            ) from err

        self.tencent_client = tcvectordb.VectorDBClient(
            url=url,
            username=username,
            key=key,
            read_consistency=v_read_consistency,
            timeout=DEFAULT_TIMEOUT,
        )

    def _create_database_if_not_exists(self, database_name: str) -> None:
        """Bind `self.database`, creating the database if it is not listed."""
        existing = {db.database_name for db in self.tencent_client.list_databases()}

        if database_name in existing:
            self.database = self.tencent_client.database(database_name)
        else:
            self.database = self.tencent_client.create_database(database_name)

    def _create_collection(
        self, database_name: str, collection_params: CollectionParams
    ) -> None:
        """Bind `self.collection`, creating or re-creating it when needed.

        An existing collection is reused unless `collection_params.drop_exists`
        is set, in which case it is dropped and created again.

        Raises:
            ValueError: If `collection_params` is None.
        """
        import tcvectordb

        # Validate first: the lines below dereference `collection_params`.
        if collection_params is None:
            raise ValueError(VALUE_NONE_ERROR.format("collection_params"))

        collection_name: str = self._compute_collection_name(
            database_name, collection_params
        )
        collection_description = collection_params.collection_description

        # Keep the try narrow: only the lookup should fall back to creation.
        # Errors from drop/re-create must propagate rather than trigger a
        # second create attempt.
        try:
            self.collection = self.database.describe_collection(collection_name)
        except tcvectordb.exceptions.VectorDBException:
            # presumably raised when the collection does not exist — TODO confirm
            self._create_collection_in_db(
                collection_name, collection_description, collection_params
            )
            return

        if collection_params.drop_exists:
            self.database.drop_collection(collection_name)
            self._create_collection_in_db(
                collection_name, collection_description, collection_params
            )

    @staticmethod
    def _compute_collection_name(
        database_name: str, collection_params: CollectionParams
    ) -> str:
        """Return the collection name to use.

        A non-default database paired with the default collection name gets a
        database-prefixed collection name; otherwise the configured name is used.
        """
        if (
            database_name == DEFAULT_DATABASE_NAME
            or collection_params.collection_name != DEFAULT_COLLECTION_NAME
        ):
            return collection_params.collection_name
        return database_name + "_" + DEFAULT_COLLECTION_NAME

    def _create_collection_in_db(
        self,
        collection_name: str,
        collection_description: str,
        collection_params: CollectionParams,
    ) -> None:
        """Create the collection with id, doc-id, vector and filter indexes."""
        from tcvectordb.model.enum import FieldType, IndexType
        from tcvectordb.model.index import FilterIndex, Index, VectorIndex

        index_type = self._get_index_type(collection_params.index_type)
        metric_type = self._get_metric_type(collection_params.metric_type)
        index_param = self._get_index_params(index_type, collection_params)
        index = Index(
            FilterIndex(
                name=FIELD_ID,
                field_type=FieldType.String,
                index_type=IndexType.PRIMARY_KEY,
            ),
            FilterIndex(
                name=DEFAULT_DOC_ID_KEY,
                field_type=FieldType.String,
                index_type=IndexType.FILTER,
            ),
            VectorIndex(
                name=FIELD_VECTOR,
                dimension=collection_params.dimension,
                index_type=index_type,
                metric_type=metric_type,
                params=index_param,
            ),
        )
        for field in collection_params.filter_fields:
            index.add(field.to_vdb_filter())

        self.collection = self.database.create_collection(
            name=collection_name,
            shard=collection_params.shard,
            replicas=collection_params.replicas,
            description=collection_description,
            index=index,
        )

    @staticmethod
    def _get_index_params(index_type: Any, collection_params: CollectionParams) -> Any:
        """Build the tcvectordb params object for `index_type`.

        Missing keys in `collection_params.vector_params` fall back to the
        module defaults. Returns None for index types without tunable params.
        """
        from tcvectordb.model.enum import IndexType
        from tcvectordb.model.index import (
            HNSWParams,
            IVFFLATParams,
            IVFPQParams,
            IVFSQ4Params,
            IVFSQ8Params,
            IVFSQ16Params,
        )

        vector_params = (
            {}
            if collection_params.vector_params is None
            else collection_params.vector_params
        )

        if index_type == IndexType.HNSW:
            return HNSWParams(
                m=vector_params.get("M", DEFAULT_HNSW_M),
                efconstruction=vector_params.get("efConstruction", DEFAULT_HNSW_EF),
            )
        elif index_type == IndexType.IVF_FLAT:
            return IVFFLATParams(nlist=vector_params.get("nlist", DEFAULT_IVF_NLIST))
        elif index_type == IndexType.IVF_PQ:
            return IVFPQParams(
                m=vector_params.get("M", DEFAULT_IVF_PQ_M),
                nlist=vector_params.get("nlist", DEFAULT_IVF_NLIST),
            )
        elif index_type == IndexType.IVF_SQ4:
            return IVFSQ4Params(nlist=vector_params.get("nlist", DEFAULT_IVF_NLIST))
        elif index_type == IndexType.IVF_SQ8:
            return IVFSQ8Params(nlist=vector_params.get("nlist", DEFAULT_IVF_NLIST))
        elif index_type == IndexType.IVF_SQ16:
            return IVFSQ16Params(nlist=vector_params.get("nlist", DEFAULT_IVF_NLIST))
        return None

    @staticmethod
    def _get_index_type(index_type_value: str) -> Any:
        """Convert an index-type string to `IndexType` (empty -> "HNSW").

        Raises:
            ValueError: If the value is not an `IndexType` value.
        """
        from tcvectordb.model.enum import IndexType

        index_type_value = index_type_value or DEFAULT_INDEX_TYPE
        try:
            return IndexType(index_type_value)
        except ValueError:
            support_index_types = [d.value for d in IndexType.__members__.values()]
            raise ValueError(
                NOT_SUPPORT_INDEX_TYPE_ERROR.format(
                    index_type_value, support_index_types
                )
            )

    @staticmethod
    def _get_metric_type(metric_type_value: str) -> Any:
        """Convert a metric-type string (case-insensitive) to `MetricType`.

        Empty values fall back to "COSINE". The fallback is a plain string so
        that `.upper()` is always valid.

        Raises:
            ValueError: If the value is not a `MetricType` value.
        """
        from tcvectordb.model.enum import MetricType

        metric_type_value = metric_type_value or DEFAULT_METRIC_TYPE
        try:
            return MetricType(metric_type_value.upper())
        except ValueError:
            support_metric_types = [d.value for d in MetricType.__members__.values()]
            raise ValueError(
                NOT_SUPPORT_METRIC_TYPE_ERROR.format(
                    metric_type_value, support_metric_types
                )
            )

    @property
    def client(self) -> Any:
        """Get client."""
        return self.tencent_client

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Each node becomes one document carrying the node id, embedding,
        ref_doc_id, JSON-encoded metadata, text, and any metadata values that
        match a configured filter field. Documents are upserted in batches of
        `self.batch_size`.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            The ids of the added nodes, in input order.
        """
        from tcvectordb.model.document import Document

        ids = []
        entries = []
        for node in nodes:
            document = Document(id=node.node_id, vector=node.get_embedding())
            if node.ref_doc_id is not None:
                document.__dict__[DEFAULT_DOC_ID_KEY] = node.ref_doc_id
            if node.metadata is not None:
                document.__dict__[FIELD_METADATA] = json.dumps(node.metadata)
                # Copy type-compatible metadata values into their own fields so
                # they can be used in server-side filter expressions.
                for field in self.filter_fields:
                    v = node.metadata.get(field.name)
                    if field.match_value(v):
                        document.__dict__[field.name] = v
            if isinstance(node, TextNode) and node.text is not None:
                document.__dict__[DEFAULT_TEXT_KEY] = node.text

            entries.append(document)
            ids.append(node.node_id)

            if len(entries) >= self.batch_size:
                self.collection.upsert(
                    documents=entries, build_index=True, timeout=DEFAULT_TIMEOUT
                )
                entries = []

        if len(entries) > 0:
            self.collection.upsert(
                documents=entries, build_index=True, timeout=DEFAULT_TIMEOUT
            )

        return ids

    def delete(self, ref_doc_id: Union[str, List[str]], **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id or ids.

        Args:
            ref_doc_id (Union[str, List[str]]): The doc_id of the document to
                delete, or a list of doc_ids. None or empty input is a no-op.

        """
        if ref_doc_id is None or len(ref_doc_id) == 0:
            return

        from tcvectordb.model.document import Filter

        delete_ids = ref_doc_id if isinstance(ref_doc_id, list) else [ref_doc_id]
        self.collection.delete(filter=Filter(Filter.In(DEFAULT_DOC_ID_KEY, delete_ids)))

    def query_by_ids(self, ids: List[str]) -> List[Dict]:
        """Fetch documents by primary id. Empty or None `ids` returns [] without a request."""
        if not ids:
            return []
        return self.collection.query(document_ids=ids, limit=len(ids))

    def truncate(self) -> None:
        """Remove all documents from the current collection."""
        self.database.truncate_collection(self.collection.collection_name)

    def describe_collection(self) -> Any:
        """Return the server-side description of the current collection."""
        return self.database.describe_collection(self.collection.collection_name)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): contains
                query_embedding (List[float]): query embedding
                similarity_top_k (int): top k most similar nodes
                doc_ids (Optional[List[str]]): filter by doc_id
                filters (Optional[MetadataFilters]): filter result
            kwargs.filter (Optional[str|Filter]):

            if `filter` in kwargs:
               using filter: `age > 20 and author in (...) and ...`
            elif query.filters:
               using filter: " and ".join([f'{f.key} = "{f.value}"' for f in query.filters.filters])
            elif query.doc_ids:
               using filter: `doc_id in (query.doc_ids)`
        """
        search_filter = self._to_vdb_filter(query, **kwargs)
        results = self.collection.search(
            vectors=[query.query_embedding],
            limit=query.similarity_top_k,
            retrieve_vector=True,
            output_fields=query.output_fields,
            filter=search_filter,
        )
        if len(results) == 0:
            return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])

        nodes = []
        similarities = []
        ids = []
        # Only one query vector is sent, so only the first result list is used.
        for doc in results[0]:
            ids.append(doc.get(FIELD_ID))
            similarities.append(doc.get("score"))

            meta_str = doc.get(FIELD_METADATA)
            meta = {} if meta_str is None else json.loads(meta_str)
            doc_id = doc.get(DEFAULT_DOC_ID_KEY)

            node = TextNode(
                id_=doc.get(FIELD_ID),
                text=doc.get(DEFAULT_TEXT_KEY),
                embedding=doc.get(FIELD_VECTOR),
                metadata=meta,
            )
            if doc_id is not None:
                node.relationships = {
                    NodeRelationship.SOURCE: RelatedNodeInfo(node_id=doc_id)
                }

            nodes.append(node)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    @staticmethod
    def _to_vdb_filter(query: VectorStoreQuery, **kwargs: Any) -> Any:
        """Build the search filter, or None when nothing applies.

        Priority: an explicit `filter` kwarg (str or `Filter`), then
        `query.filters` as an AND of equality clauses, then a non-empty
        `query.doc_ids` as a `doc_id in (...)` clause.
        """
        from tcvectordb.model.document import Filter

        if "filter" in kwargs:
            search_filter = kwargs["filter"]
            if isinstance(search_filter, Filter):
                return search_filter
            return Filter(search_filter)

        legacy_filters = (
            query.filters.legacy_filters() if query.filters is not None else []
        )
        if len(legacy_filters) > 0:
            return Filter(
                " and ".join(f'{f.key} = "{f.value}"' for f in legacy_filters)
            )

        # An empty doc_ids list would render as `doc_id in ()`; treat it as
        # "no doc-id restriction", matching the `elif query.doc_ids` contract.
        if query.doc_ids:
            return Filter(Filter.In(DEFAULT_DOC_ID_KEY, query.doc_ids))

        return None
548

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to JSC "SberTech" processing your personal data in order to improve our website and the GitVerse service and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.