llama-index

Форк
0
157 строк · 4.6 Кб
1
import math
2
from typing import Any, List
3

4
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
5
from llama_index.legacy.vector_stores.types import (
6
    MetadataFilters,
7
    VectorStore,
8
    VectorStoreQuery,
9
    VectorStoreQueryResult,
10
)
11
from llama_index.legacy.vector_stores.utils import (
12
    legacy_metadata_dict_to_node,
13
    metadata_dict_to_node,
14
    node_to_metadata_dict,
15
)
16

17

18
def _to_metal_filters(standard_filters: MetadataFilters) -> list:
    """Convert generic metadata filters into Metal's filter payload.

    Args:
        standard_filters: Filter container; only the entries returned by its
            ``legacy_filters()`` method are read, each contributing its
            ``key`` and ``value`` attributes.

    Returns:
        One ``{"field": key, "value": value}`` dict per filter, in the order
        produced by ``legacy_filters()``. Empty when there are no filters.
    """
    # Named ``entry`` rather than ``filter`` so the builtin is not shadowed.
    return [
        {"field": entry.key, "value": entry.value}
        for entry in standard_filters.legacy_filters()
    ]
28

29

30
class MetalVectorStore(VectorStore):
    """Vector store that reads from and writes to a Metal index.

    Nodes are written with ``Metal.index`` and retrieved with
    ``Metal.search``. Node text is sent to Metal under the ``"text"``
    metadata key, and read back from each search hit's ``"text"`` field.
    """

    def __init__(
        self,
        api_key: str,
        client_id: str,
        index_id: str,
    ):
        """Create the Metal client and set the store's capability flags.

        Args:
            api_key: Passed as the first argument to ``Metal``.
            client_id: Passed as the second argument to ``Metal``.
            index_id: Passed as the third argument to ``Metal``.

        Raises:
            ImportError: If the ``metal_sdk`` package cannot be imported.
        """
        import_err_msg = (
            "`metal_sdk` package not found, please run `pip install metal_sdk`"
        )
        try:
            import metal_sdk  # noqa
        except ImportError as err:
            # Chain explicitly so the original import failure stays visible
            # as the cause of the friendlier error.
            raise ImportError(import_err_msg) from err
        from metal_sdk.metal import Metal

        self.api_key = api_key
        self.client_id = client_id
        self.index_id = index_id

        self.metal_client = Metal(api_key, client_id, index_id)
        self.stores_text = True
        self.flat_metadata = False
        self.is_embedding_query = True

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Search the Metal index with the query embedding.

        Args:
            query: Supplies ``query_embedding``, ``similarity_top_k`` (used as
                the search ``limit``) and optional ``filters``.
            **kwargs: May carry Metal-specific ``filters``, but only when
                ``query.filters`` is ``None``.

        Returns:
            A ``VectorStoreQueryResult`` with one node, similarity and id per
            item in the response's ``"data"`` list, in response order.

        Raises:
            ValueError: If filters are given both on ``query`` and in
                ``kwargs``.
        """
        if query.filters is not None:
            if "filters" in kwargs:
                raise ValueError(
                    "Cannot specify filter via both query and kwargs. "
                    "Use kwargs only for metal specific items that are "
                    "not supported via the generic query interface."
                )
            filters = _to_metal_filters(query.filters)
        else:
            # Caller-supplied Metal filters are forwarded untouched; the
            # fallback when none are given is an empty dict.
            filters = kwargs.get("filters", {})

        payload = {
            "embedding": query.query_embedding,  # Query Embedding
            "filters": filters,  # Metadata Filters
        }
        response = self.metal_client.search(payload, limit=query.similarity_top_k)

        nodes = []
        ids = []
        similarities = []

        for item in response["data"]:
            text = item["text"]
            id_ = item["id"]

            # load additional Node data
            try:
                node = metadata_dict_to_node(item["metadata"])
                node.text = text
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility.
                # Any failure in the current-format path falls back to
                # rebuilding a TextNode from the legacy metadata layout.
                metadata, node_info, relationships = legacy_metadata_dict_to_node(
                    item["metadata"]
                )

                node = TextNode(
                    text=text,
                    id_=id_,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )

            nodes.append(node)
            ids.append(id_)

            # Maps ``dist`` through 1 - e^(-dist): 0 at dist == 0 and
            # approaching 1 as dist grows.
            # NOTE(review): the score increases with ``dist`` - confirm
            # whether Metal's ``dist`` is a similarity or a distance.
            similarity_score = 1.0 - math.exp(-item["dist"])
            similarities.append(similarity_score)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    @property
    def client(self) -> Any:
        """Return Metal client."""
        return self.metal_client

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        Each node is sent to Metal in its own ``index`` call, with its
        embedding, its ``node_id`` as the record id, and a metadata dict that
        holds the node text under ``"text"`` plus the output of
        ``node_to_metadata_dict``.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings.
            **add_kwargs: Accepted for interface compatibility; not used.

        Returns:
            The ``node_id`` of every node, in input order.

        Raises:
            ValueError: If the Metal client is not set.
        """
        if not self.metal_client:
            raise ValueError("metal_client not initialized")

        ids = []
        for node in nodes:
            ids.append(node.node_id)

            metadata = {}
            # Text is stored without any metadata mixed in; a falsy content
            # value is stored as "".
            metadata["text"] = node.get_content(metadata_mode=MetadataMode.NONE) or ""

            additional_metadata = node_to_metadata_dict(
                node, remove_text=True, flat_metadata=self.flat_metadata
            )
            # Applied after "text" is set, so a "text" key in the additional
            # metadata would overwrite it.
            metadata.update(additional_metadata)

            payload = {
                "embedding": node.get_embedding(),
                "metadata": metadata,
                "id": node.node_id,
            }

            self.metal_client.index(payload)

        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using ref_doc_id.

        The id is passed unchanged to the Metal client's ``deleteOne``.
        NOTE(review): ``add`` indexes records under ``node_id`` - confirm
        that ``deleteOne`` with a ref_doc_id removes the intended records.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
            **delete_kwargs: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If the Metal client is not set.
        """
        if not self.metal_client:
            raise ValueError("metal_client not initialized")

        self.metal_client.deleteOne(ref_doc_id)
158

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.