# llama-index
# 322 lines · 11.0 KB
1"""Simple vector store index."""
2
3import json4import logging5import os6from dataclasses import dataclass, field7from typing import Any, Callable, Dict, List, Mapping, Optional, cast8
9import fsspec10from dataclasses_json import DataClassJsonMixin11
12from llama_index.legacy.indices.query.embedding_utils import (13get_top_k_embeddings,14get_top_k_embeddings_learner,15get_top_k_mmr_embeddings,16)
17from llama_index.legacy.schema import BaseNode18from llama_index.legacy.utils import concat_dirs19from llama_index.legacy.vector_stores.types import (20DEFAULT_PERSIST_DIR,21DEFAULT_PERSIST_FNAME,22MetadataFilters,23VectorStore,24VectorStoreQuery,25VectorStoreQueryMode,26VectorStoreQueryResult,27)
28from llama_index.legacy.vector_stores.utils import node_to_metadata_dict29
logger = logging.getLogger(__name__)

# Query modes served by a learned ranker (get_top_k_embeddings_learner)
# rather than plain embedding similarity.
LEARNER_MODES = {
    VectorStoreQueryMode.SVM,
    VectorStoreQueryMode.LINEAR_REGRESSION,
    VectorStoreQueryMode.LOGISTIC_REGRESSION,
}

# Maximal-marginal-relevance mode, served by get_top_k_mmr_embeddings.
MMR_MODE = VectorStoreQueryMode.MMR

# Separator between a namespace and the persist filename, and the store key
# used when no namespace is given (see from_namespaced_persist_dir).
NAMESPACE_SEP = "__"
DEFAULT_VECTOR_STORE = "default"
43
44def _build_metadata_filter_fn(45metadata_lookup_fn: Callable[[str], Mapping[str, Any]],46metadata_filters: Optional[MetadataFilters] = None,47) -> Callable[[str], bool]:48"""Build metadata filter function."""49filter_list = metadata_filters.legacy_filters() if metadata_filters else []50if not filter_list:51return lambda _: True52
53def filter_fn(node_id: str) -> bool:54metadata = metadata_lookup_fn(node_id)55for filter_ in filter_list:56metadata_value = metadata.get(filter_.key, None)57if metadata_value is None:58return False59elif isinstance(metadata_value, list):60if filter_.value not in metadata_value:61return False62elif isinstance(metadata_value, (int, float, str, bool)):63if metadata_value != filter_.value:64return False65return True66
67return filter_fn68
69
@dataclass
class SimpleVectorStoreData(DataClassJsonMixin):
    """Simple Vector Store Data container.

    Args:
        embedding_dict (Optional[dict]): dict mapping node_ids to embeddings.
        text_id_to_ref_doc_id (Optional[dict]):
            dict mapping text_ids/node_ids to ref_doc_ids.
        metadata_dict (Optional[dict]):
            dict mapping node_ids to flattened node metadata, used for
            metadata filtering at query time.

    """

    # node_id -> embedding vector
    embedding_dict: Dict[str, List[float]] = field(default_factory=dict)
    # node_id -> ref_doc_id (the literal string "None" when the node has none)
    text_id_to_ref_doc_id: Dict[str, str] = field(default_factory=dict)
    # node_id -> metadata dict; may be empty for stores persisted before
    # metadata support was added
    metadata_dict: Dict[str, Any] = field(default_factory=dict)
85
class SimpleVectorStore(VectorStore):
    """Simple Vector Store.

    In this vector store, embeddings are stored within a simple, in-memory dictionary.

    Args:
        data (Optional[SimpleVectorStoreData]): backing data container with
            the embeddings, ref_doc_ids, and metadata. A fresh empty
            container is created when omitted.
        fs (Optional[fsspec.AbstractFileSystem]): filesystem used by
            ``persist``; defaults to the local filesystem.
    """

    # This store keeps only embeddings/metadata, never node text.
    stores_text: bool = False

    def __init__(
        self,
        data: Optional[SimpleVectorStoreData] = None,
        fs: Optional[fsspec.AbstractFileSystem] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._data = data or SimpleVectorStoreData()
        self._fs = fs or fsspec.filesystem("file")

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        namespace: Optional[str] = None,
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> "SimpleVectorStore":
        """Load a store from a persist directory.

        Args:
            persist_dir (str): directory holding the persisted JSON file.
            namespace (Optional[str]): when given, the file is looked up as
                ``{namespace}__{DEFAULT_PERSIST_FNAME}``.
            fs (Optional[fsspec.AbstractFileSystem]): filesystem to read
                from; local filesystem path joining is used when omitted.
        """
        if namespace:
            persist_fname = f"{namespace}{NAMESPACE_SEP}{DEFAULT_PERSIST_FNAME}"
        else:
            persist_fname = DEFAULT_PERSIST_FNAME

        if fs is not None:
            # fsspec paths may not follow os.path conventions; use concat_dirs.
            persist_path = concat_dirs(persist_dir, persist_fname)
        else:
            persist_path = os.path.join(persist_dir, persist_fname)
        return cls.from_persist_path(persist_path, fs=fs)

    @classmethod
    def from_namespaced_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> Dict[str, VectorStore]:
        """Load all namespaced stores found in a persist directory.

        Returns a dict mapping each namespace to its loaded store; files
        without a namespace prefix are keyed as DEFAULT_VECTOR_STORE.
        """
        listing_fn = os.listdir if fs is None else fs.listdir

        vector_stores: Dict[str, VectorStore] = {}

        try:
            for fname in listing_fn(persist_dir):
                if fname.endswith(DEFAULT_PERSIST_FNAME):
                    namespace = fname.split(NAMESPACE_SEP)[0]

                    # handle backwards compatibility with stores that were persisted
                    # without a namespace: splitting such a filename yields the
                    # filename itself, so load it as the default store.
                    if namespace == DEFAULT_PERSIST_FNAME:
                        vector_stores[DEFAULT_VECTOR_STORE] = cls.from_persist_dir(
                            persist_dir=persist_dir, fs=fs
                        )
                    else:
                        vector_stores[namespace] = cls.from_persist_dir(
                            persist_dir=persist_dir, namespace=namespace, fs=fs
                        )
        except Exception:
            # failed to listdir, so assume there is only one store
            try:
                vector_stores[DEFAULT_VECTOR_STORE] = cls.from_persist_dir(
                    persist_dir=persist_dir, fs=fs, namespace=DEFAULT_VECTOR_STORE
                )
            except Exception:
                # no namespace backwards compat: retry without a namespace prefix
                vector_stores[DEFAULT_VECTOR_STORE] = cls.from_persist_dir(
                    persist_dir=persist_dir, fs=fs
                )

        return vector_stores

    @property
    def client(self) -> None:
        """Get client (always None: this store has no backing service)."""
        return

    def get(self, text_id: str) -> List[float]:
        """Get the stored embedding for a node id (raises KeyError if absent)."""
        return self._data.embedding_dict[text_id]

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index; returns the node ids that were added."""
        for node in nodes:
            self._data.embedding_dict[node.node_id] = node.get_embedding()
            # Missing ref_doc_id is stored as the literal string "None".
            self._data.text_id_to_ref_doc_id[node.node_id] = node.ref_doc_id or "None"

            metadata = node_to_metadata_dict(
                node, remove_text=True, flat_metadata=False
            )
            # The serialized node payload is not needed for filtering.
            metadata.pop("_node_content", None)
            self._data.metadata_dict[node.node_id] = metadata
        return [node.node_id for node in nodes]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        # Collect first, then delete, to avoid mutating the dict mid-iteration.
        text_ids_to_delete = set()
        for text_id, ref_doc_id_ in self._data.text_id_to_ref_doc_id.items():
            if ref_doc_id == ref_doc_id_:
                text_ids_to_delete.add(text_id)

        for text_id in text_ids_to_delete:
            del self._data.embedding_dict[text_id]
            del self._data.text_id_to_ref_doc_id[text_id]
            # Handle metadata_dict not being present in stores that were persisted
            # without metadata, or, not being present for nodes stored
            # prior to metadata functionality.
            if self._data.metadata_dict is not None:
                self._data.metadata_dict.pop(text_id, None)

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Get nodes for response.

        Supports learner modes (SVM / linear / logistic), MMR (with an
        optional ``mmr_threshold`` kwarg), and the default similarity mode;
        any other mode raises ValueError.
        """
        # Prevent metadata filtering on stores that were persisted without metadata.
        if (
            query.filters is not None
            and self._data.embedding_dict
            and not self._data.metadata_dict
        ):
            raise ValueError(
                "Cannot filter stores that were persisted without metadata. "
                "Please rebuild the store with metadata to enable filtering."
            )
        # Prefilter nodes based on the query filter and node ID restrictions.
        query_filter_fn = _build_metadata_filter_fn(
            lambda node_id: self._data.metadata_dict[node_id], query.filters
        )

        if query.node_ids is not None:
            available_ids = set(query.node_ids)

            def node_filter_fn(node_id: str) -> bool:
                return node_id in available_ids

        else:

            def node_filter_fn(node_id: str) -> bool:
                return True

        node_ids: List[str] = []
        embeddings: List[List[float]] = []
        # TODO: consolidate with get_query_text_embedding_similarities
        for node_id, embedding in self._data.embedding_dict.items():
            if node_filter_fn(node_id) and query_filter_fn(node_id):
                node_ids.append(node_id)
                embeddings.append(embedding)

        query_embedding = cast(List[float], query.query_embedding)

        if query.mode in LEARNER_MODES:
            top_similarities, top_ids = get_top_k_embeddings_learner(
                query_embedding,
                embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=node_ids,
            )
        elif query.mode == MMR_MODE:
            mmr_threshold = kwargs.get("mmr_threshold", None)
            top_similarities, top_ids = get_top_k_mmr_embeddings(
                query_embedding,
                embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=node_ids,
                mmr_threshold=mmr_threshold,
            )
        elif query.mode == VectorStoreQueryMode.DEFAULT:
            top_similarities, top_ids = get_top_k_embeddings(
                query_embedding,
                embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=node_ids,
            )
        else:
            raise ValueError(f"Invalid query mode: {query.mode}")

        return VectorStoreQueryResult(similarities=top_similarities, ids=top_ids)

    def persist(
        self,
        persist_path: str = os.path.join(DEFAULT_PERSIST_DIR, DEFAULT_PERSIST_FNAME),
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> None:
        """Persist the SimpleVectorStore to a directory.

        Args:
            persist_path (str): full path of the JSON file to write.
            fs (Optional[fsspec.AbstractFileSystem]): filesystem to write
                through; falls back to the one given at construction.
        """
        fs = fs or self._fs
        # NOTE(review): os.path.dirname is applied even to fsspec paths —
        # presumably fine for "/"-separated URLs; confirm for exotic filesystems.
        dirpath = os.path.dirname(persist_path)
        if not fs.exists(dirpath):
            fs.makedirs(dirpath)

        with fs.open(persist_path, "w") as f:
            json.dump(self._data.to_dict(), f)

    @classmethod
    def from_persist_path(
        cls, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
    ) -> "SimpleVectorStore":
        """Create a SimpleVectorStore from a persisted JSON file.

        Raises:
            ValueError: if no file exists at ``persist_path``.
        """
        fs = fs or fsspec.filesystem("file")
        if not fs.exists(persist_path):
            raise ValueError(
                f"No existing {__name__} found at {persist_path}, skipping load."
            )

        logger.debug(f"Loading {__name__} from {persist_path}.")
        with fs.open(persist_path, "rb") as f:
            data_dict = json.load(f)
        data = SimpleVectorStoreData.from_dict(data_dict)
        return cls(data)

    @classmethod
    def from_dict(cls, save_dict: dict) -> "SimpleVectorStore":
        """Build a store from a dict previously produced by ``to_dict``."""
        data = SimpleVectorStoreData.from_dict(save_dict)
        return cls(data)

    def to_dict(self) -> dict:
        """Serialize the backing data container to a plain dict."""
        return self._data.to_dict()