"""Simple vector store index."""

import json
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Mapping, Optional, cast

import fsspec
from dataclasses_json import DataClassJsonMixin

from llama_index.legacy.indices.query.embedding_utils import (
    get_top_k_embeddings,
    get_top_k_embeddings_learner,
    get_top_k_mmr_embeddings,
)
from llama_index.legacy.schema import BaseNode
from llama_index.legacy.utils import concat_dirs
from llama_index.legacy.vector_stores.types import (
    DEFAULT_PERSIST_DIR,
    DEFAULT_PERSIST_FNAME,
    MetadataFilters,
    VectorStore,
    VectorStoreQuery,
    VectorStoreQueryMode,
    VectorStoreQueryResult,
)
from llama_index.legacy.vector_stores.utils import node_to_metadata_dict

logger = logging.getLogger(__name__)

LEARNER_MODES = {
    VectorStoreQueryMode.SVM,
    VectorStoreQueryMode.LINEAR_REGRESSION,
    VectorStoreQueryMode.LOGISTIC_REGRESSION,
}

MMR_MODE = VectorStoreQueryMode.MMR

NAMESPACE_SEP = "__"
DEFAULT_VECTOR_STORE = "default"


def _build_metadata_filter_fn(
    metadata_lookup_fn: Callable[[str], Mapping[str, Any]],
    metadata_filters: Optional[MetadataFilters] = None,
) -> Callable[[str], bool]:
    """Build metadata filter function."""
    filter_list = metadata_filters.legacy_filters() if metadata_filters else []
    if not filter_list:
        return lambda _: True

    def filter_fn(node_id: str) -> bool:
        metadata = metadata_lookup_fn(node_id)
        for filter_ in filter_list:
            metadata_value = metadata.get(filter_.key, None)
            if metadata_value is None:
                return False
            elif isinstance(metadata_value, list):
                if filter_.value not in metadata_value:
                    return False
            elif isinstance(metadata_value, (int, float, str, bool)):
                if metadata_value != filter_.value:
                    return False
        return True

    return filter_fn
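

# Illustrative sketch of how the filter function built above behaves for a toy
# metadata lookup. This helper is hypothetical (not part of the store's API) and
# assumes ExactMatchFilter is importable from llama_index.legacy.vector_stores.types.
def _example_metadata_filter_usage() -> None:
    from llama_index.legacy.vector_stores.types import ExactMatchFilter

    metadata_by_id = {
        "node-1": {"author": "alice"},
        "node-2": {"author": "bob"},
    }
    filters = MetadataFilters(filters=[ExactMatchFilter(key="author", value="alice")])
    filter_fn = _build_metadata_filter_fn(
        lambda node_id: metadata_by_id[node_id], filters
    )
    # Only the node whose metadata matches the filter passes.
    assert filter_fn("node-1") is True
    assert filter_fn("node-2") is False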


@dataclass
class SimpleVectorStoreData(DataClassJsonMixin):
    """Simple Vector Store Data container.

    Args:
        embedding_dict (Optional[dict]): dict mapping node_ids to embeddings.
        text_id_to_ref_doc_id (Optional[dict]):
            dict mapping text_ids/node_ids to ref_doc_ids.
        metadata_dict (Optional[dict]):
            dict mapping node_ids to metadata dicts.

    """

    embedding_dict: Dict[str, List[float]] = field(default_factory=dict)
    text_id_to_ref_doc_id: Dict[str, str] = field(default_factory=dict)
    metadata_dict: Dict[str, Any] = field(default_factory=dict)


class SimpleVectorStore(VectorStore):
    """Simple Vector Store.

    In this vector store, embeddings are stored within a simple, in-memory dictionary.

    Args:
        data (Optional[SimpleVectorStoreData]): data container holding the
            embeddings, doc_ids, and metadata. See SimpleVectorStoreData
            for more details.
    """

    stores_text: bool = False

    def __init__(
        self,
        data: Optional[SimpleVectorStoreData] = None,
        fs: Optional[fsspec.AbstractFileSystem] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._data = data or SimpleVectorStoreData()
        self._fs = fs or fsspec.filesystem("file")

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        namespace: Optional[str] = None,
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> "SimpleVectorStore":
        """Load from persist dir."""
        if namespace:
            persist_fname = f"{namespace}{NAMESPACE_SEP}{DEFAULT_PERSIST_FNAME}"
        else:
            persist_fname = DEFAULT_PERSIST_FNAME

        if fs is not None:
            persist_path = concat_dirs(persist_dir, persist_fname)
        else:
            persist_path = os.path.join(persist_dir, persist_fname)
        return cls.from_persist_path(persist_path, fs=fs)

    @classmethod
    def from_namespaced_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> Dict[str, VectorStore]:
        """Load from namespaced persist dir."""
        listing_fn = os.listdir if fs is None else fs.listdir

        vector_stores: Dict[str, VectorStore] = {}

        try:
            for fname in listing_fn(persist_dir):
                if fname.endswith(DEFAULT_PERSIST_FNAME):
                    namespace = fname.split(NAMESPACE_SEP)[0]

                    # handle backwards compatibility with stores that were persisted
                    # without a namespace prefix in the filename
                    if namespace == DEFAULT_PERSIST_FNAME:
                        vector_stores[DEFAULT_VECTOR_STORE] = cls.from_persist_dir(
                            persist_dir=persist_dir, fs=fs
                        )
                    else:
                        vector_stores[namespace] = cls.from_persist_dir(
                            persist_dir=persist_dir, namespace=namespace, fs=fs
                        )
        except Exception:
            # failed to listdir, so assume there is only one store
            try:
                vector_stores[DEFAULT_VECTOR_STORE] = cls.from_persist_dir(
                    persist_dir=persist_dir, fs=fs, namespace=DEFAULT_VECTOR_STORE
                )
            except Exception:
                # no namespace backwards compat
                vector_stores[DEFAULT_VECTOR_STORE] = cls.from_persist_dir(
                    persist_dir=persist_dir, fs=fs
                )

        return vector_stores
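
    # For example (assuming the legacy default persist filename), a persist_dir
    # might contain "vector_store.json" for the default store and
    # "<namespace>__vector_store.json" for a namespaced store; the text before
    # NAMESPACE_SEP ("__") becomes the key in the returned dict.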

    @property
    def client(self) -> None:
        """Get client."""
        return

    def get(self, text_id: str) -> List[float]:
        """Get embedding."""
        return self._data.embedding_dict[text_id]

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index."""
        for node in nodes:
            self._data.embedding_dict[node.node_id] = node.get_embedding()
            self._data.text_id_to_ref_doc_id[node.node_id] = node.ref_doc_id or "None"

            metadata = node_to_metadata_dict(
                node, remove_text=True, flat_metadata=False
            )
            metadata.pop("_node_content", None)
            self._data.metadata_dict[node.node_id] = metadata
        return [node.node_id for node in nodes]
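
    # Illustrative sketch of adding pre-embedded nodes (hypothetical usage,
    # assuming TextNode from llama_index.legacy.schema and an embedding computed
    # elsewhere):
    #
    #     store = SimpleVectorStore()
    #     node = TextNode(
    #         text="hello", embedding=[0.1, 0.2, 0.3], metadata={"author": "alice"}
    #     )
    #     store.add([node])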

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        text_ids_to_delete = set()
        for text_id, ref_doc_id_ in self._data.text_id_to_ref_doc_id.items():
            if ref_doc_id == ref_doc_id_:
                text_ids_to_delete.add(text_id)

        for text_id in text_ids_to_delete:
            del self._data.embedding_dict[text_id]
            del self._data.text_id_to_ref_doc_id[text_id]
            # Handle metadata_dict not being present in stores that were persisted
            # without metadata, or, not being present for nodes stored
            # prior to metadata functionality.
            if self._data.metadata_dict is not None:
                self._data.metadata_dict.pop(text_id, None)

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Get nodes for response."""
        # Prevent metadata filtering on stores that were persisted without metadata.
        if (
            query.filters is not None
            and self._data.embedding_dict
            and not self._data.metadata_dict
        ):
            raise ValueError(
                "Cannot filter stores that were persisted without metadata. "
                "Please rebuild the store with metadata to enable filtering."
            )
        # Prefilter nodes based on the query filter and node ID restrictions.
        query_filter_fn = _build_metadata_filter_fn(
            lambda node_id: self._data.metadata_dict[node_id], query.filters
        )

        if query.node_ids is not None:
            available_ids = set(query.node_ids)

            def node_filter_fn(node_id: str) -> bool:
                return node_id in available_ids

        else:

            def node_filter_fn(node_id: str) -> bool:
                return True

        node_ids = []
        embeddings = []
        # TODO: consolidate with get_query_text_embedding_similarities
        for node_id, embedding in self._data.embedding_dict.items():
            if node_filter_fn(node_id) and query_filter_fn(node_id):
                node_ids.append(node_id)
                embeddings.append(embedding)

        query_embedding = cast(List[float], query.query_embedding)

        if query.mode in LEARNER_MODES:
            top_similarities, top_ids = get_top_k_embeddings_learner(
                query_embedding,
                embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=node_ids,
            )
        elif query.mode == MMR_MODE:
            mmr_threshold = kwargs.get("mmr_threshold", None)
            top_similarities, top_ids = get_top_k_mmr_embeddings(
                query_embedding,
                embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=node_ids,
                mmr_threshold=mmr_threshold,
            )
        elif query.mode == VectorStoreQueryMode.DEFAULT:
            top_similarities, top_ids = get_top_k_embeddings(
                query_embedding,
                embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=node_ids,
            )
        else:
            raise ValueError(f"Invalid query mode: {query.mode}")

        return VectorStoreQueryResult(similarities=top_similarities, ids=top_ids)
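
    # Illustrative sketch of a query call (hypothetical values; assumes the query
    # embedding has the same dimensionality as the stored embeddings):
    #
    #     result = store.query(
    #         VectorStoreQuery(query_embedding=[0.1, 0.2, 0.3], similarity_top_k=2)
    #     )
    #     # result.ids and result.similarities are aligned lists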

    def persist(
        self,
        persist_path: str = os.path.join(DEFAULT_PERSIST_DIR, DEFAULT_PERSIST_FNAME),
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> None:
        """Persist the SimpleVectorStore to a file at persist_path."""
        fs = fs or self._fs
        dirpath = os.path.dirname(persist_path)
        if not fs.exists(dirpath):
            fs.makedirs(dirpath)

        with fs.open(persist_path, "w") as f:
            json.dump(self._data.to_dict(), f)

    @classmethod
    def from_persist_path(
        cls, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
    ) -> "SimpleVectorStore":
        """Create a SimpleVectorStore from a persist path."""
        fs = fs or fsspec.filesystem("file")
        if not fs.exists(persist_path):
            raise ValueError(
                f"No existing {__name__} found at {persist_path}, skipping load."
            )

        logger.debug(f"Loading {__name__} from {persist_path}.")
        with fs.open(persist_path, "rb") as f:
            data_dict = json.load(f)
            data = SimpleVectorStoreData.from_dict(data_dict)
        return cls(data)

    @classmethod
    def from_dict(cls, save_dict: dict) -> "SimpleVectorStore":
        data = SimpleVectorStoreData.from_dict(save_dict)
        return cls(data)

    def to_dict(self) -> dict:
        return self._data.to_dict()
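

# Illustrative end-to-end usage sketch (hypothetical example). It assumes TextNode
# is available from llama_index.legacy.schema and uses toy 3-dimensional embeddings
# in place of real model output.
if __name__ == "__main__":
    from llama_index.legacy.schema import TextNode

    store = SimpleVectorStore()
    nodes = [
        TextNode(text="alpha", embedding=[1.0, 0.0, 0.0], metadata={"topic": "a"}),
        TextNode(text="beta", embedding=[0.0, 1.0, 0.0], metadata={"topic": "b"}),
    ]
    store.add(nodes)

    # Nearest-neighbor query against the in-memory embeddings.
    result = store.query(
        VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1)
    )
    print(result.ids, result.similarities)

    # Persist to disk and load the store back.
    store.persist("./storage_example/vector_store.json")
    reloaded = SimpleVectorStore.from_persist_path(
        "./storage_example/vector_store.json"
    )
    print(len(reloaded.to_dict()["embedding_dict"]))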