llama-index
181 строка · 6.0 Кб
1"""Simple graph store index."""
2
3import json4import logging5import os6from dataclasses import dataclass, field7from typing import Any, Dict, List, Optional8
9import fsspec10from dataclasses_json import DataClassJsonMixin11
12from llama_index.legacy.graph_stores.types import (13DEFAULT_PERSIST_DIR,14DEFAULT_PERSIST_FNAME,15GraphStore,16)
17
# Module-scoped logger; SimpleGraphStore.from_persist_path reports through it
# whether a persisted store was found (warning) or loaded (debug).
logger = logging.getLogger(name=__name__)
20
@dataclass
class SimpleGraphStoreData(DataClassJsonMixin):
    """Simple Graph Store Data container.

    Args:
        graph_dict (Optional[dict]): dict mapping subject to a list of
            ``[relation, object]`` pairs, i.e. ``{subj: [[rel, obj], ...]}``.
    """

    # subject -> list of [relation, object] pairs
    graph_dict: Dict[str, List[List[str]]] = field(default_factory=dict)

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Get subjects' rel map in max depth.

        Args:
            subjs: Subjects to expand. Defaults to every subject in the graph.
            depth: Maximum number of hops to follow from each subject.
            limit: Maximum total number of triplets returned across all
                subjects.

        Returns:
            Dict mapping each visited subject to its ``[subj, rel, obj]``
            triplets in depth-first order. Subjects are filled in iteration
            order; the subject at which ``limit`` would be exceeded gets a
            truncated list and no later subjects are included.
        """
        if subjs is None:
            subjs = list(self.graph_dict.keys())
        # NOTE: truncation is purely by iteration order; a fairer "spread"
        # across subjects is still TBD.
        rel_count = 0
        return_map: Dict[str, List[List[str]]] = {}
        # ``dict.fromkeys`` drops repeated subjects while keeping first-seen
        # order, so each subject is expanded and counted exactly once.
        for subj in dict.fromkeys(subjs):
            # Expanded lazily: subjects past the truncation point are never
            # traversed, since their results would be discarded anyway.
            subj_rels = self._get_rel_map(subj, depth=depth, limit=limit)
            if rel_count + len(subj_rels) > limit:
                return_map[subj] = subj_rels[: limit - rel_count]
                break
            return_map[subj] = subj_rels
            rel_count += len(subj_rels)
        return return_map

    def _get_rel_map(
        self, subj: str, depth: int = 2, limit: int = 30
    ) -> List[List[str]]:
        """Get one subject's rel map in max depth.

        Depth-first expansion from ``subj``: each outgoing ``[rel, obj]``
        edge is emitted as ``[subj, rel, obj]`` and is immediately followed
        by the expansion of ``obj`` with one hop fewer.

        Args:
            subj: Subject to expand.
            depth: Remaining hops; ``0`` or less yields an empty list.
            limit: Maximum number of direct edges followed per subject, at
                every level of the expansion.

        Returns:
            List of ``[subj, rel, obj]`` triplets in depth-first order.
        """
        # ``<= 0`` rather than ``== 0``: a negative depth would otherwise
        # recurse without bound around any cycle in the graph.
        if depth <= 0:
            return []
        rel_map: List[List[str]] = []
        for edge_count, (rel, obj) in enumerate(self.graph_dict.get(subj, [])):
            if edge_count >= limit:
                break
            rel_map.append([subj, rel, obj])
            # Propagate ``limit`` so nested levels honour the caller's cap
            # instead of silently falling back to the default of 30.
            rel_map += self._get_rel_map(obj, depth=depth - 1, limit=limit)
        return rel_map
70
class SimpleGraphStore(GraphStore):
    """Simple Graph Store.

    In this graph store, triplets are stored within a simple, in-memory dictionary.

    Args:
        data (Optional[SimpleGraphStoreData]): container holding the
            triplets. See SimpleGraphStoreData for more details. Defaults to
            an empty container.
        fs (Optional[fsspec.AbstractFileSystem]): filesystem used by
            ``persist`` when none is passed to it. Defaults to the local
            filesystem.
    """

    def __init__(
        self,
        data: Optional[SimpleGraphStoreData] = None,
        fs: Optional[fsspec.AbstractFileSystem] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params.

        Extra keyword arguments are accepted and ignored.
        """
        self._data = data or SimpleGraphStoreData()
        self._fs = fs or fsspec.filesystem("file")

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> "SimpleGraphStore":
        """Load from ``DEFAULT_PERSIST_FNAME`` inside ``persist_dir``."""
        persist_path = os.path.join(persist_dir, DEFAULT_PERSIST_FNAME)
        return cls.from_persist_path(persist_path, fs=fs)

    @property
    def client(self) -> None:
        """Get client.

        Not applicable for this store; always returns None.
        """
        return

    def get(self, subj: str) -> List[List[str]]:
        """Get the ``[rel, obj]`` pairs stored for ``subj`` (``[]`` if none).

        The stored list itself is returned, not a copy.
        """
        return self._data.graph_dict.get(subj, [])

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Get depth-aware rel map (delegates to SimpleGraphStoreData)."""
        return self._data.get_rel_map(subjs=subjs, depth=depth, limit=limit)

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet ``(subj, rel, obj)`` unless it is already stored."""
        rel_objs = self._data.graph_dict.setdefault(subj, [])
        # Pairs are stored (and JSON-reloaded) as lists. Compare against a
        # list: a tuple never equals a list, so a tuple check always passed
        # and duplicates accumulated.
        pair = [rel, obj]
        if pair not in rel_objs:
            rel_objs.append(pair)

    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete triplet ``(subj, rel, obj)``; no-op if it is not stored.

        The subject key is dropped once its last pair is removed.
        """
        rel_objs = self._data.graph_dict.get(subj)
        if rel_objs is None:
            return
        # List comparison for the same reason as in ``upsert_triplet``; the
        # old tuple check never matched, so nothing was ever deleted.
        pair = [rel, obj]
        if pair in rel_objs:
            rel_objs.remove(pair)
            if not rel_objs:
                del self._data.graph_dict[subj]

    def persist(
        self,
        persist_path: str = os.path.join(DEFAULT_PERSIST_DIR, DEFAULT_PERSIST_FNAME),
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> None:
        """Persist the SimpleGraphStore as JSON to ``persist_path``.

        Args:
            persist_path: Destination file; parent directories are created.
            fs: Filesystem to write to; defaults to the store's own ``fs``.
        """
        fs = fs or self._fs
        dirpath = os.path.dirname(persist_path)
        # A bare filename yields dirpath == "": nothing to create, and
        # ``makedirs("")`` would fail.
        if dirpath and not fs.exists(dirpath):
            fs.makedirs(dirpath)

        with fs.open(persist_path, "w") as f:
            json.dump(self._data.to_dict(), f)

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the Simple Graph store (unsupported)."""
        raise NotImplementedError("SimpleGraphStore does not support get_schema")

    def query(
        self, query: str, param_map: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Query the Simple Graph store (unsupported)."""
        raise NotImplementedError("SimpleGraphStore does not support query")

    @classmethod
    def from_persist_path(
        cls, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
    ) -> "SimpleGraphStore":
        """Create a SimpleGraphStore from a persisted JSON file.

        Args:
            persist_path: JSON file previously written by ``persist``.
            fs: Filesystem to read from; defaults to the local filesystem.
                The returned store keeps it as its default for ``persist``.

        Returns:
            The loaded store, or an empty store (with a logged warning) when
            ``persist_path`` does not exist.
        """
        fs = fs or fsspec.filesystem("file")
        if not fs.exists(persist_path):
            logger.warning(
                f"No existing {__name__} found at {persist_path}. "
                "Initializing a new graph_store from scratch. "
            )
            return cls(fs=fs)

        logger.debug(f"Loading {__name__} from {persist_path}.")
        with fs.open(persist_path, "rb") as f:
            data_dict = json.load(f)
        data = SimpleGraphStoreData.from_dict(data_dict)
        # Keep the source filesystem so a later ``persist()`` writes back to it.
        return cls(data, fs=fs)

    @classmethod
    def from_dict(cls, save_dict: dict) -> "SimpleGraphStore":
        """Create a store from a dict produced by ``to_dict``."""
        data = SimpleGraphStoreData.from_dict(save_dict)
        return cls(data)

    def to_dict(self) -> dict:
        """Serialize the store's triplet data to a plain dict."""
        return self._data.to_dict()