llama-index

Форк
0
94 строки · 2.7 Кб
1
import logging
2
from typing import TYPE_CHECKING, Any, List
3

4
from llama_index.legacy.bridge.pydantic import PrivateAttr
5
from llama_index.legacy.schema import BaseNode, MetadataMode
6
from llama_index.legacy.vector_stores.types import (
7
    BasePydanticVectorStore,
8
    VectorStoreQuery,
9
    VectorStoreQueryResult,
10
)
11
from llama_index.legacy.vector_stores.utils import (
12
    metadata_dict_to_node,
13
    node_to_metadata_dict,
14
)
15

16
logger = logging.getLogger(__name__)
17
import_err_msg = (
18
    '`pgvecto_rs.sdk` package not found, please run `pip install "pgvecto_rs[sdk]"`'
19
)
20

21
if TYPE_CHECKING:
22
    from pgvecto_rs.sdk import PGVectoRs
23

24

25
class PGVectoRsStore(BasePydanticVectorStore):
26
    stores_text = True
27

28
    _client: "PGVectoRs" = PrivateAttr()
29

30
    def __init__(self, client: "PGVectoRs") -> None:
31
        try:
32
            from pgvecto_rs.sdk import PGVectoRs
33
        except ImportError:
34
            raise ImportError(import_err_msg)
35
        self._client: PGVectoRs = client
36
        super().__init__()
37

38
    @classmethod
39
    def class_name(cls) -> str:
40
        return "PGVectoRsStore"
41

42
    @property
43
    def client(self) -> Any:
44
        return self._client
45

46
    def add(
47
        self,
48
        nodes: List[BaseNode],
49
    ) -> List[str]:
50
        from pgvecto_rs.sdk import Record
51

52
        records = [
53
            Record(
54
                id=node.id_,
55
                text=node.get_content(metadata_mode=MetadataMode.NONE),
56
                meta=node_to_metadata_dict(node, remove_text=True),
57
                embedding=node.get_embedding(),
58
            )
59
            for node in nodes
60
        ]
61

62
        self._client.insert(records)
63
        return [node.id_ for node in nodes]
64

65
    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
66
        from pgvecto_rs.sdk.filters import meta_contains
67

68
        self._client.delete(meta_contains({"ref_doc_id": ref_doc_id}))
69

70
    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
71
        from pgvecto_rs.sdk.filters import meta_contains
72

73
        results = self._client.search(
74
            embedding=query.query_embedding,
75
            top_k=query.similarity_top_k,
76
            filter=(
77
                meta_contains(
78
                    {pair.key: pair.value for pair in query.filters.legacy_filters()}
79
                )
80
                if query.filters is not None
81
                else None
82
            ),
83
        )
84

85
        nodes = [
86
            metadata_dict_to_node(record.meta, text=record.text)
87
            for record, _ in results
88
        ]
89

90
        return VectorStoreQueryResult(
91
            nodes=nodes,
92
            similarities=[score for _, score in results],
93
            ids=[str(record.id) for record, _ in results],
94
        )
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.