llama-index

Форк
0
229 строк · 8.6 Кб
1
"""Kùzu graph store index."""
2

3
from typing import Any, Dict, List, Optional
4

5
from llama_index.legacy.graph_stores.types import GraphStore
6

7

8
class KuzuGraphStore(GraphStore):
    """Graph store that keeps (subject, predicate, object) triplets in Kùzu.

    Entities live in a single node table keyed by a string ``ID`` column;
    triplets are rows of a single relationship table carrying a string
    ``predicate`` property.  Both tables are created on construction when
    they are not already present in the database.
    """

    def __init__(
        self,
        database: Any,
        node_table_name: str = "entity",
        rel_table_name: str = "links",
        **kwargs: Any,
    ) -> None:
        """Open a connection on ``database`` and ensure the schema exists.

        Args:
            database: Object handed to ``kuzu.Connection`` (``from_persist_dir``
                passes a ``kuzu.Database``).
            node_table_name: Name of the node table holding entities.
            rel_table_name: Name of the relationship table holding triplets.
            **kwargs: Accepted but not used by this constructor.

        Raises:
            ImportError: If the ``kuzu`` package cannot be imported.
        """
        try:
            import kuzu
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("Please install kuzu: pip install kuzu") from err
        self.database = database
        self.connection = kuzu.Connection(database)
        self.node_table_name = node_table_name
        self.rel_table_name = rel_table_name
        self.init_schema()

    def init_schema(self) -> None:
        """Create the node and relationship tables if they do not exist."""
        # NOTE(review): both lookups use underscore-prefixed connection
        # helpers; confirm they exist in the kuzu version being targeted.
        node_tables = self.connection._get_node_table_names()
        if self.node_table_name not in node_tables:
            self.connection.execute(
                "CREATE NODE TABLE %s (ID STRING, PRIMARY KEY(ID))"
                % self.node_table_name
            )
        # The relationship lookup yields mappings; only their "name" is needed.
        rel_tables = [
            rel_table["name"]
            for rel_table in self.connection._get_rel_table_names()
        ]
        if self.rel_table_name not in rel_tables:
            self.connection.execute(
                "CREATE REL TABLE {} (FROM {} TO {}, predicate STRING)".format(
                    self.rel_table_name, self.node_table_name, self.node_table_name
                )
            )

    @property
    def client(self) -> Any:
        """Return the underlying Kùzu connection."""
        return self.connection

    def get(self, subj: str) -> List[List[str]]:
        """Return the outgoing triplets of ``subj``.

        Args:
            subj: ID of the subject node.

        Returns:
            One ``[predicate, object_id]`` pair per outgoing relationship;
            an empty list when the query yields no rows.
        """
        query = (
            "MATCH (n1:{})-[r:{}]->(n2:{}) "
            "WHERE n1.ID = $subj "
            "RETURN r.predicate, n2.ID;"
        ).format(self.node_table_name, self.rel_table_name, self.node_table_name)
        prepared_statement = self.connection.prepare(query)
        query_result = self.connection.execute(prepared_statement, [("subj", subj)])
        retval = []
        while query_result.has_next():
            row = query_result.get_next()
            retval.append([row[0], row[1]])
        return retval

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Return paths of length 1..``depth`` starting at the given subjects.

        Args:
            subjs: Subject IDs to start from.  ``None`` applies no filter;
                an empty list matches nothing and yields an empty map.
            depth: Upper bound of the variable-length relationship pattern.
            limit: Value of the query's ``LIMIT`` clause (maximum result rows).

        Returns:
            Mapping from subject ID to a list of paths.  Each path is a flat
            list ``[node_id, predicate, node_id, predicate, ..., object_id]``.
        """
        retval: Dict[str, List[List[str]]] = {}

        subj_list = None if subjs is None else list(subjs)
        if subj_list is not None and not subj_list:
            # No subjects requested: nothing can match.  (Building the query
            # here used to fail because no WHERE clause was ever defined.)
            return retval

        rel_wildcard = "r:%s*1..%d" % (self.rel_table_name, depth)
        match_clause = "MATCH (n1:{})-[{}]->(n2:{})".format(
            self.node_table_name,
            rel_wildcard,
            self.node_table_name,
        )
        return_clause = "RETURN n1, r, n2 LIMIT %d" % limit

        if subj_list is None:
            where_clause = ""
            params = []
        else:
            # One positional parameter ($0, $1, ...) per requested subject.
            where_clause = "WHERE " + " OR ".join(
                "n1.ID = $%d" % i for i in range(len(subj_list))
            )
            params = [(str(i), curr_subj) for i, curr_subj in enumerate(subj_list)]

        query = f"{match_clause} {where_clause} {return_clause}"
        prepared_statement = self.connection.prepare(query)
        if subj_list is None:
            query_result = self.connection.execute(prepared_statement)
        else:
            query_result = self.connection.execute(prepared_statement, params)

        while query_result.has_next():
            row = query_result.get_next()
            subj, recursive_rel, obj = row[0], row[1], row[2]

            # Relationships reference their source by internal (table, offset)
            # id, so build a lookup from that id to the user-facing ID.  The
            # two endpoints are added explicitly alongside the nodes listed in
            # the recursive relationship.
            nodes_map = {
                (subj["_id"]["table"], subj["_id"]["offset"]): subj["ID"],
                (obj["_id"]["table"], obj["_id"]["offset"]): obj["ID"],
            }
            for node in recursive_rel["_nodes"]:
                nodes_map[(node["_id"]["table"], node["_id"]["offset"])] = node["ID"]

            curr_path = []
            for rel in recursive_rel["_rels"]:
                src_key = (rel["_src"]["table"], rel["_src"]["offset"])
                curr_path.append(nodes_map[src_key])
                curr_path.append(rel["predicate"])
            # Add the last node
            curr_path.append(obj["ID"])

            retval.setdefault(subj["ID"], []).append(curr_path)
        return retval

    def _entity_exists(self, entity: str) -> bool:
        """Return True if a node whose ID equals ``entity`` exists."""
        result = self.connection.execute(
            "MATCH (n:%s) WHERE n.ID = $entity RETURN n.ID" % self.node_table_name,
            [("entity", entity)],
        )
        return result.has_next()

    def _create_entity(self, entity: str) -> None:
        """Create a node whose ID is ``entity``."""
        self.connection.execute(
            "CREATE (n:%s {ID: $entity})" % self.node_table_name,
            [("entity", entity)],
        )

    def _rel_exists(self, subj: str, obj: str, rel: str) -> bool:
        """Return True if ``subj -[rel]-> obj`` is already stored."""
        result = self.connection.execute(
            (
                "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID = "
                "$obj AND r.predicate = $pred RETURN r.predicate"
            ).format(self.node_table_name, self.rel_table_name, self.node_table_name),
            [("subj", subj), ("obj", obj), ("pred", rel)],
        )
        return result.has_next()

    def _create_rel(self, subj: str, obj: str, rel: str) -> None:
        """Create the relationship ``subj -[rel]-> obj`` between matched nodes."""
        self.connection.execute(
            (
                "MATCH (n1:{}), (n2:{}) WHERE n1.ID = $subj AND n2.ID = $obj "
                "CREATE (n1)-[r:{} {{predicate: $pred}}]->(n2)"
            ).format(self.node_table_name, self.node_table_name, self.rel_table_name),
            [("subj", subj), ("obj", obj), ("pred", rel)],
        )

    def _delete_rel(self, subj: str, obj: str, rel: str) -> None:
        """Delete relationships ``subj -[rel]-> obj``."""
        self.connection.execute(
            (
                "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID"
                " = $obj AND r.predicate = $pred DELETE r"
            ).format(self.node_table_name, self.rel_table_name, self.node_table_name),
            [("subj", subj), ("obj", obj), ("pred", rel)],
        )

    def _delete_entity(self, entity: str) -> None:
        """Delete the node whose ID equals ``entity``."""
        self.connection.execute(
            "MATCH (n:%s) WHERE n.ID = $entity DELETE n" % self.node_table_name,
            [("entity", entity)],
        )

    def _has_edges(self, entity: str) -> bool:
        """Return True if ``entity`` has a relationship in either direction."""
        result = self.connection.execute(
            "MATCH (n1:{})-[r:{}]-(n2:{}) WHERE n2.ID = $entity RETURN r.predicate".format(
                self.node_table_name, self.rel_table_name, self.node_table_name
            ),
            [("entity", entity)],
        )
        return result.has_next()

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add the triplet ``subj -[rel]-> obj`` unless it is already stored.

        Missing subject/object nodes are created first.  The relationship is
        only looked up when both nodes already existed, because a node created
        by this call cannot have relationships yet.

        Args:
            subj: ID of the subject node.
            rel: Predicate stored on the relationship.
            obj: ID of the object node.
        """
        is_self_loop = subj == obj
        is_subj_exists = self._entity_exists(subj)
        is_obj_exists = is_subj_exists if is_self_loop else self._entity_exists(obj)

        if not is_subj_exists:
            self._create_entity(subj)
        # For a self-loop the node was just created above; creating it a
        # second time would collide with the ID primary key.
        if not is_obj_exists and not is_self_loop:
            self._create_entity(obj)

        if is_subj_exists and is_obj_exists and self._rel_exists(subj, obj, rel):
            return

        self._create_rel(subj, obj, rel)

    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete the triplet and any endpoint left without relationships.

        Args:
            subj: ID of the subject node.
            rel: Predicate of the relationship to delete.
            obj: ID of the object node.
        """
        self._delete_rel(subj, obj, rel)
        # Nodes are removed only once nothing in the relationship table
        # touches them any more.
        if not self._has_edges(subj):
            self._delete_entity(subj)
        # A self-loop has a single endpoint, which was handled just above.
        if obj != subj and not self._has_edges(obj):
            self._delete_entity(obj)

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str,
        node_table_name: str = "entity",
        rel_table_name: str = "links",
    ) -> "KuzuGraphStore":
        """Build a store on the Kùzu database located at ``persist_dir``.

        Args:
            persist_dir: Path passed to ``kuzu.Database``.
            node_table_name: Name of the node table holding entities.
            rel_table_name: Name of the relationship table holding triplets.

        Returns:
            A graph store connected to that database.

        Raises:
            ImportError: If the ``kuzu`` package cannot be imported.
        """
        try:
            import kuzu
        except ImportError as err:
            raise ImportError("Please install kuzu: pip install kuzu") from err
        database = kuzu.Database(persist_dir)
        return cls(database, node_table_name, rel_table_name)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "KuzuGraphStore":
        """Initialize graph store from configuration dictionary.

        Args:
            config_dict: Configuration dictionary; its items are forwarded to
                the constructor as keyword arguments.

        Returns:
            Graph store.
        """
        return cls(**config_dict)
230

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.