llama-index

Форк
0
181 строка · 6.0 Кб
1
"""Simple graph store index."""
2

3
import json
4
import logging
5
import os
6
from dataclasses import dataclass, field
7
from typing import Any, Dict, List, Optional
8

9
import fsspec
10
from dataclasses_json import DataClassJsonMixin
11

12
from llama_index.legacy.graph_stores.types import (
13
    DEFAULT_PERSIST_DIR,
14
    DEFAULT_PERSIST_FNAME,
15
    GraphStore,
16
)
17

18
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
19

20

21
@dataclass
class SimpleGraphStoreData(DataClassJsonMixin):
    """Simple Graph Store Data container.

    Args:
        graph_dict (Optional[dict]): dict mapping each subject to the list
            of its outgoing ``[relation, object]`` pairs.
    """

    # subject -> list of [relation, object] pairs
    graph_dict: Dict[str, List[List[str]]] = field(default_factory=dict)

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Get subjects' rel map in max depth.

        Args:
            subjs: subjects to expand; every subject in ``graph_dict`` is
                used when ``None``.
            depth: maximum number of hops to follow from each subject.
            limit: maximum total number of ``[subj, rel, obj]`` triplets in
                the returned map; also bounds the direct relations expanded
                per visited node.

        Returns:
            Dict mapping each included subject to its flattened list of
            ``[subj, rel, obj]`` triplets. Subjects after the one at which
            ``limit`` is reached are omitted.
        """
        if subjs is None:
            subjs = list(self.graph_dict.keys())
        # TBD, truncate the rel_map in a spread way, now just truncate based
        # on iteration order
        return_map: Dict[str, List[List[str]]] = {}
        rel_count = 0
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # repeated subject is expanded and counted once only.
        for subj in dict.fromkeys(subjs):
            rels = self._get_rel_map(subj, depth=depth, limit=limit)
            if rel_count + len(rels) > limit:
                # Budget exhausted: keep what still fits and stop, so the
                # remaining subjects are never expanded.
                return_map[subj] = rels[: limit - rel_count]
                break
            return_map[subj] = rels
            rel_count += len(rels)
        return return_map

    def _get_rel_map(
        self, subj: str, depth: int = 2, limit: int = 30
    ) -> List[List[str]]:
        """Get one subject's rel map in max depth.

        Args:
            subj: subject to start from.
            depth: remaining number of hops to follow; nothing is returned
                once it reaches zero (or below).
            limit: maximum number of direct relations expanded per node.

        Returns:
            Flattened list of ``[subj, rel, obj]`` triplets in depth-first
            order: each direct triplet is followed by the triplets reachable
            from its object.
        """
        # ``<=`` rather than ``==`` so a negative depth cannot recurse
        # without bound on a cyclic graph.
        if depth <= 0:
            return []
        rel_map: List[List[str]] = []
        for rel_count, (rel, obj) in enumerate(self.graph_dict.get(subj, [])):
            if rel_count >= limit:
                break
            rel_map.append([subj, rel, obj])
            # Forward ``limit`` so nested levels honour the caller's value
            # instead of silently falling back to the default.
            rel_map += self._get_rel_map(obj, depth=depth - 1, limit=limit)
        return rel_map
69

70

71
class SimpleGraphStore(GraphStore):
    """Simple Graph Store.

    In this graph store, triplets are stored within a simple, in-memory dictionary.

    Args:
        data (Optional[SimpleGraphStoreData]): container holding the
            triplets; an empty one is created when omitted. See
            SimpleGraphStoreData for more details.
        fs (Optional[fsspec.AbstractFileSystem]): filesystem used by
            ``persist`` when none is passed there; defaults to the local
            ``"file"`` filesystem.
    """

    def __init__(
        self,
        data: Optional[SimpleGraphStoreData] = None,
        fs: Optional[fsspec.AbstractFileSystem] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params.

        Extra keyword arguments are accepted and ignored.
        """
        self._data = data or SimpleGraphStoreData()
        self._fs = fs or fsspec.filesystem("file")

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> "SimpleGraphStore":
        """Load from persist dir.

        The store file is looked up as ``DEFAULT_PERSIST_FNAME`` inside
        ``persist_dir``.
        """
        persist_path = os.path.join(persist_dir, DEFAULT_PERSIST_FNAME)
        return cls.from_persist_path(persist_path, fs=fs)

    @property
    def client(self) -> None:
        """Get client.

        Not applicable for this store; always ``None``.
        """
        return None

    def get(self, subj: str) -> List[List[str]]:
        """Get triplets.

        Returns the ``[rel, obj]`` pairs stored for ``subj``, or an empty
        list when the subject is unknown.
        """
        return self._data.graph_dict.get(subj, [])

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Get depth-aware rel map (delegates to the data container)."""
        return self._data.get_rel_map(subjs=subjs, depth=depth, limit=limit)

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet, unless the same ``[rel, obj]`` pair already exists."""
        triplets = self._data.graph_dict.setdefault(subj, [])
        # Pairs are stored as lists, and a tuple never compares equal to a
        # list, so the membership test must use a list to detect duplicates.
        pair = [rel, obj]
        if pair not in triplets:
            triplets.append(pair)

    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete triplet.

        Does nothing when the triplet is absent. The subject key is dropped
        once its last pair has been removed.
        """
        triplets = self._data.graph_dict.get(subj)
        if triplets is None:
            return
        # Compare as a list: stored pairs are lists, and a tuple would never
        # match them.
        pair = [rel, obj]
        if pair in triplets:
            triplets.remove(pair)
            if not triplets:
                del self._data.graph_dict[subj]

    def persist(
        self,
        persist_path: str = os.path.join(DEFAULT_PERSIST_DIR, DEFAULT_PERSIST_FNAME),
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> None:
        """Persist the SimpleGraphStore as JSON at ``persist_path``.

        Args:
            persist_path: destination file; its parent directory is created
                when missing.
            fs: filesystem to write through; falls back to the one given at
                construction time.
        """
        fs = fs or self._fs
        dirpath = os.path.dirname(persist_path)
        # A bare file name has no directory component; there is nothing to
        # create in that case.
        if dirpath and not fs.exists(dirpath):
            fs.makedirs(dirpath)

        with fs.open(persist_path, "w") as f:
            json.dump(self._data.to_dict(), f)

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the Simple Graph store.

        Raises:
            NotImplementedError: always; this store has no schema support.
        """
        raise NotImplementedError("SimpleGraphStore does not support get_schema")

    # NOTE: ``param_map`` is never read or mutated here, so the shared
    # mutable default is harmless; it is kept for signature compatibility.
    def query(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any:
        """Query the Simple Graph store.

        Raises:
            NotImplementedError: always; this store has no query support.
        """
        raise NotImplementedError("SimpleGraphStore does not support query")

    @classmethod
    def from_persist_path(
        cls, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
    ) -> "SimpleGraphStore":
        """Create a SimpleGraphStore from a persisted JSON file.

        When ``persist_path`` does not exist, a warning is logged and a new
        empty store is returned instead of raising.
        """
        fs = fs or fsspec.filesystem("file")
        if not fs.exists(persist_path):
            logger.warning(
                f"No existing {__name__} found at {persist_path}. "
                "Initializing a new graph_store from scratch. "
            )
            return cls()

        logger.debug(f"Loading {__name__} from {persist_path}.")
        with fs.open(persist_path, "rb") as f:
            data_dict = json.load(f)
        data = SimpleGraphStoreData.from_dict(data_dict)
        return cls(data)

    @classmethod
    def from_dict(cls, save_dict: dict) -> "SimpleGraphStore":
        """Create a SimpleGraphStore from a dict produced by ``to_dict``."""
        data = SimpleGraphStoreData.from_dict(save_dict)
        return cls(data)

    def to_dict(self) -> dict:
        """Return the store's data as a plain dict."""
        return self._data.to_dict()
182

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.