llama-index

Форк
0
149 строк · 5.1 Кб
1
"""DynamoDB vector store index."""
2

3
from __future__ import annotations
4

5
from logging import getLogger
6
from typing import Any, Dict, List, cast
7

8
from llama_index.legacy.indices.query.embedding_utils import (
9
    get_top_k_embeddings,
10
    get_top_k_embeddings_learner,
11
)
12
from llama_index.legacy.schema import BaseNode
13
from llama_index.legacy.storage.kvstore.dynamodb_kvstore import DynamoDBKVStore
14
from llama_index.legacy.vector_stores.types import (
15
    VectorStore,
16
    VectorStoreQuery,
17
    VectorStoreQueryMode,
18
    VectorStoreQueryResult,
19
)
20

21
# Namespace prefix used for collection names when the caller passes none
# (or an empty string) to DynamoDBVectorStore.
DEFAULT_NAMESPACE = "vector_store"

# Query modes that DynamoDBVectorStore.query routes to
# get_top_k_embeddings_learner instead of get_top_k_embeddings.
LEARNER_MODES = {
    VectorStoreQueryMode.LOGISTIC_REGRESSION,
    VectorStoreQueryMode.LINEAR_REGRESSION,
    VectorStoreQueryMode.SVM,
}

logger = getLogger(__name__)
30

31

32
class DynamoDBVectorStore(VectorStore):
    """DynamoDB Vector Store.

    In this vector store, embeddings are stored within a DynamoDB table
    (through a ``DynamoDBKVStore``). This class was implemented with
    reference to SimpleVectorStore.

    Two collections are written to the key-value store, both keyed by node id:

    * ``"{namespace}/embedding"`` holds ``{"value": <node embedding>}``.
    * ``"{namespace}/text_id_to_doc_id"`` holds ``{"value": <ref_doc_id>}``.

    Node text itself is not persisted here (``stores_text`` is False).

    Args:
        dynamodb_kvstore (DynamoDBKVStore): data store
        namespace (Optional[str]): prefix for the collection names; falls back
            to ``DEFAULT_NAMESPACE`` when None or empty.
    """

    stores_text: bool = False

    def __init__(
        self, dynamodb_kvstore: DynamoDBKVStore, namespace: str | None = None
    ) -> None:
        """Initialize params."""
        self._kvstore = dynamodb_kvstore
        namespace = namespace or DEFAULT_NAMESPACE
        self._collection_embedding = f"{namespace}/embedding"
        self._collection_text_id_to_doc_id = f"{namespace}/text_id_to_doc_id"
        # Every item this class writes is a one-field dict under this key.
        self._key_value = "value"

    @classmethod
    def from_table_name(
        cls, table_name: str, namespace: str | None = None
    ) -> DynamoDBVectorStore:
        """Load from DynamoDB table name.

        Builds a ``DynamoDBKVStore`` with ``DynamoDBKVStore.from_table_name``
        and wraps it in a new vector store.
        """
        dynamodb_kvstore = DynamoDBKVStore.from_table_name(table_name=table_name)
        return cls(dynamodb_kvstore=dynamodb_kvstore, namespace=namespace)

    @property
    def client(self) -> None:
        """Get client.

        This store exposes no underlying client object; always returns None.
        """
        return

    def get(self, text_id: str) -> List[float]:
        """Get the embedding stored for a node.

        Args:
            text_id: Id of the node whose embedding was saved by ``add``.

        Returns:
            The stored embedding vector.
        """
        item = self._kvstore.get(key=text_id, collection=self._collection_embedding)
        item = cast(Dict[str, List[float]], item)
        return item[self._key_value]

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        For each node, writes its embedding and its ``ref_doc_id`` (the latter
        is what ``delete`` matches on), both keyed by the node id.

        Args:
            nodes: Nodes to store; ``node.get_embedding()`` is called on each,
                so each is expected to carry an embedding.

        Returns:
            The node ids, in the same order as ``nodes``.
        """
        added_ids = []
        for node in nodes:
            self._kvstore.put(
                key=node.node_id,
                val={self._key_value: node.get_embedding()},
                collection=self._collection_embedding,
            )
            self._kvstore.put(
                key=node.node_id,
                val={self._key_value: node.ref_doc_id},
                collection=self._collection_text_id_to_doc_id,
            )
            added_ids.append(node.node_id)
        return added_ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete all nodes that belong to ``ref_doc_id``.

        Reads the whole node-id -> doc-id collection to find matching nodes,
        then removes each match from both collections.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        doc_id_by_node = self._kvstore.get_all(
            collection=self._collection_text_id_to_doc_id
        )
        text_ids_to_delete = {
            text_id
            for text_id, item in doc_id_by_node.items()
            if ref_doc_id == item[self._key_value]
        }

        for text_id in text_ids_to_delete:
            self._kvstore.delete(key=text_id, collection=self._collection_embedding)
            self._kvstore.delete(
                key=text_id, collection=self._collection_text_id_to_doc_id
            )

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Get nodes for response.

        Loads every stored embedding, optionally keeps only those whose ids
        are in ``query.node_ids``, and ranks them against
        ``query.query_embedding`` with ``get_top_k_embeddings`` (DEFAULT mode)
        or ``get_top_k_embeddings_learner`` (modes in ``LEARNER_MODES``).

        Args:
            query: The vector store query. Metadata ``filters`` are not
                supported.

        Returns:
            VectorStoreQueryResult with the top ``similarities`` and ``ids``.

        Raises:
            ValueError: If ``query.filters`` is set or ``query.mode`` is not
                supported.
        """
        if query.filters is not None:
            raise ValueError(
                "Metadata filters not implemented for DynamoDBVectorStore yet."
            )

        # Reject unsupported modes before reading the whole embedding
        # collection below, so a bad request does not pay for a full read.
        use_learner = query.mode in LEARNER_MODES
        if not use_learner and query.mode != VectorStoreQueryMode.DEFAULT:
            raise ValueError(f"Invalid query mode: {query.mode}")

        # TODO: consolidate with get_query_text_embedding_similarities
        stored = self._kvstore.get_all(collection=self._collection_embedding)

        # A missing or empty node_ids list means "no restriction".
        allowed_ids = set(query.node_ids) if query.node_ids else None
        node_ids: List[str] = []
        embeddings: List[List[float]] = []
        for node_id, item in stored.items():
            if allowed_ids is not None and node_id not in allowed_ids:
                continue
            node_ids.append(node_id)
            embeddings.append(item[self._key_value])

        query_embedding = cast(List[float], query.query_embedding)
        top_k_fn = (
            get_top_k_embeddings_learner if use_learner else get_top_k_embeddings
        )
        top_similarities, top_ids = top_k_fn(
            query_embedding=query_embedding,
            embeddings=embeddings,
            similarity_top_k=query.similarity_top_k,
            embedding_ids=node_ids,
        )

        return VectorStoreQueryResult(similarities=top_similarities, ids=top_ids)
150

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.