llama-index

Форк
0
257 строк · 9.3 Кб
1
"""Neo4j graph store index."""
2

3
import logging
4
from typing import Any, Dict, List, Optional
5

6
from llama_index.legacy.graph_stores.types import GraphStore
7

8
logger = logging.getLogger(__name__)

# Cypher: for every node label, collect its non-relationship properties
# (name + type) as reported by APOC's ``apoc.meta.data()`` procedure.
# Each row is returned as ``{labels: <label>, properties: [...]}`` under the
# column name ``output``.
node_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output

"""

# Cypher: same as above but for relationship element types; each row is
# returned as ``{type: <relationship type>, properties: [...]}`` under the
# column name ``output``.
rel_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {type: nodeLabels, properties: properties} AS output
"""

# Cypher: list the relationship patterns between node labels. Each row is
# returned as ``{start: <label>, type: <relationship>, end: <other label>}``
# under the column name ``output`` (one row per entry of ``other``).
rel_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE type = "RELATIONSHIP" AND elementType = "node"
UNWIND other AS other_node
RETURN {start: label, type: property, end: toString(other_node)} AS output
"""
34

35

36
class Neo4jGraphStore(GraphStore):
    """Graph store that keeps triplets in a Neo4j database.

    A triplet ``(subj, rel, obj)`` is stored as two nodes labelled
    ``node_label`` (identified by their ``id`` property) connected by a
    relationship whose type is derived from ``rel``.
    """

    def __init__(
        self,
        username: str,
        password: str,
        url: str,
        database: str = "neo4j",
        node_label: str = "Entity",
        **kwargs: Any,
    ) -> None:
        """Connect to Neo4j, load the schema and ensure the id constraint.

        Args:
            username: User name used for driver authentication.
            password: Password used for driver authentication.
            url: Connection URL passed to ``neo4j.GraphDatabase.driver``.
            database: Name of the database sessions are opened against.
            node_label: Label applied to every node managed by this store.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ImportError: If the ``neo4j`` package is not installed.
            ValueError: If the server is unreachable, authentication fails,
                or the APOC schema queries raise a client error.
        """
        try:
            import neo4j
        except ImportError:
            raise ImportError("Please install neo4j: pip install neo4j")
        self.node_label = node_label
        self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
        self._database = database
        self.schema = ""
        self.structured_schema: Dict[str, Any] = {}
        # Verify connection
        try:
            self._driver.verify_connectivity()
        except neo4j.exceptions.ServiceUnavailable:
            raise ValueError(
                "Could not connect to Neo4j database. "
                "Please ensure that the url is correct"
            )
        except neo4j.exceptions.AuthError:
            raise ValueError(
                "Could not connect to Neo4j database. "
                "Please ensure that the username and password are correct"
            )
        # Set schema
        try:
            self.refresh_schema()
        except neo4j.exceptions.ClientError:
            raise ValueError(
                "Could not use APOC procedures. "
                "Please ensure the APOC plugin is installed in Neo4j and that "
                "'apoc.meta.data()' is allowed in Neo4j configuration "
            )
        # Create constraint for faster insert and retrieval
        try:  # Using Neo4j 5
            self.query(
                """
                CREATE CONSTRAINT IF NOT EXISTS FOR (n:%s) REQUIRE n.id IS UNIQUE;
                """
                % (self.node_label)
            )
        except Exception:  # Using Neo4j <5
            self.query(
                """
                CREATE CONSTRAINT IF NOT EXISTS ON (n:%s) ASSERT n.id IS UNIQUE;
                """
                % (self.node_label)
            )

    @property
    def client(self) -> Any:
        """Return the underlying Neo4j driver object."""
        return self._driver

    def get(self, subj: str) -> List[List[str]]:
        """Get triplets.

        Args:
            subj: ``id`` of the subject node.

        Returns:
            One ``[relationship type, object id]`` entry per outgoing
            relationship from the subject to another ``node_label`` node.
        """
        query = """
            MATCH (n1:%s)-[r]->(n2:%s)
            WHERE n1.id = $subj
            RETURN type(r), n2.id;
        """

        prepared_statement = query % (self.node_label, self.node_label)

        with self._driver.session(database=self._database) as session:
            data = session.run(prepared_statement, {"subj": subj})
            return [record.values() for record in data]

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Get flat rel map.

        Args:
            subjs: Subject ids to expand. ``None`` or an empty list yields an
                empty map.
            depth: Maximum number of hops followed from each subject.
            limit: Value placed in the query's ``LIMIT`` clause.

        Returns:
            Mapping from subject id to its collected flattened relation paths.
        """
        # The flat means for multi-hop relation path, we could get
        # knowledge like: subj -> rel -> obj -> rel -> obj -> rel -> obj.
        # This type of knowledge is useful for some tasks.
        # +-------------+------------------------------------+
        # | subj        | flattened_rels                     |
        # +-------------+------------------------------------+
        # | "player101" | [95, "player125", 2002, "team204"] |
        # | "player100" | [1997, "team204"]                  |
        # ...
        # +-------------+------------------------------------+

        rel_map: Dict[Any, List[Any]] = {}
        if not subjs:
            # unlike simple graph_store, we don't do get_all here
            return rel_map

        # depth and limit are interpolated into the statement text; the
        # subject ids are passed as a query parameter.
        query = (
            f"""MATCH p=(n1:{self.node_label})-[*1..{depth}]->() """
            f"""{"WHERE n1.id IN $subjs" if subjs else ""} """
            "UNWIND relationships(p) AS rel "
            "WITH n1.id AS subj, p, apoc.coll.flatten(apoc.coll.toSet("
            "collect([type(rel), endNode(rel).id]))) AS flattened_rels "
            f"RETURN subj, collect(flattened_rels) AS flattened_rels LIMIT {limit}"
        )

        for record in self.query(query, {"subjs": subjs}):
            rel_map[record["subj"]] = record["flattened_rels"]
        return rel_map

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet.

        The relationship type written to the database is ``rel`` with spaces
        replaced by underscores and upper-cased.

        Args:
            subj: ``id`` of the subject node.
            rel: Relation name.
            obj: ``id`` of the object node.
        """
        query = """
            MERGE (n1:`%s` {id:$subj})
            MERGE (n2:`%s` {id:$obj})
            MERGE (n1)-[:`%s`]->(n2)
        """

        prepared_statement = query % (
            self.node_label,
            self.node_label,
            rel.replace(" ", "_").upper(),
        )

        with self._driver.session(database=self._database) as session:
            session.run(prepared_statement, {"subj": subj, "obj": obj})

    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete triplet.

        Removes the ``rel`` relationship from ``subj`` to ``obj`` and then
        deletes either endpoint node that no longer has any relationship.

        Args:
            subj: ``id`` of the subject node.
            rel: Relationship type. It is inserted into the statement
                verbatim, so it must be the type as stored in the database
                (``upsert_triplet`` stores it upper-cased with underscores).
            obj: ``id`` of the object node.
        """

        def delete_rel(subj: str, obj: str, rel: str) -> None:
            with self._driver.session(database=self._database) as session:
                session.run(
                    (
                        "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.id = $subj AND n2.id"
                        " = $obj DELETE r"
                    ).format(self.node_label, rel, self.node_label),
                    {"subj": subj, "obj": obj},
                )

        def delete_entity(entity: str) -> None:
            with self._driver.session(database=self._database) as session:
                session.run(
                    "MATCH (n:%s) WHERE n.id = $entity DELETE n" % self.node_label,
                    {"entity": entity},
                )

        def check_edges(entity: str) -> bool:
            with self._driver.session(database=self._database) as session:
                is_exists_result = session.run(
                    "MATCH (n1:%s)--() WHERE n1.id = $entity RETURN count(*)"
                    % (self.node_label),
                    {"entity": entity},
                )
                # ``count(*)`` yields a row even when nothing matches, so the
                # mere presence of a row says nothing; look at the count.
                return any(record[0] > 0 for record in is_exists_result)

        delete_rel(subj, obj, rel)
        if not check_edges(subj):
            delete_entity(subj)
        if not check_edges(obj):
            delete_entity(obj)

    def refresh_schema(self) -> None:
        """Refresh the Neo4j graph schema information.

        Runs the three module-level APOC metadata queries and stores the
        result both as ``self.structured_schema`` (dict) and as
        ``self.schema`` (human-readable text).
        """
        node_properties = [el["output"] for el in self.query(node_properties_query)]
        rel_properties = [el["output"] for el in self.query(rel_properties_query)]
        relationships = [el["output"] for el in self.query(rel_query)]

        self.structured_schema = {
            "node_props": {el["labels"]: el["properties"] for el in node_properties},
            "rel_props": {el["type"]: el["properties"] for el in rel_properties},
            "relationships": relationships,
        }

        # Format node properties
        formatted_node_props = []
        for el in node_properties:
            props_str = ", ".join(
                [f"{prop['property']}: {prop['type']}" for prop in el["properties"]]
            )
            formatted_node_props.append(f"{el['labels']} {{{props_str}}}")

        # Format relationship properties
        formatted_rel_props = []
        for el in rel_properties:
            props_str = ", ".join(
                [f"{prop['property']}: {prop['type']}" for prop in el["properties"]]
            )
            formatted_rel_props.append(f"{el['type']} {{{props_str}}}")

        # Format relationships
        formatted_rels = [
            f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" for el in relationships
        ]

        self.schema = "\n".join(
            [
                "Node properties are the following:",
                ",".join(formatted_node_props),
                "Relationship properties are the following:",
                ",".join(formatted_rel_props),
                "The relationships are the following:",
                ",".join(formatted_rels),
            ]
        )

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the Neo4jGraph store.

        Args:
            refresh: When true, re-query the database even if a schema string
                is already cached.

        Returns:
            The textual schema description.
        """
        if self.schema and not refresh:
            return self.schema
        self.refresh_schema()
        logger.debug("get_schema() schema:\n%s", self.schema)
        return self.schema

    def query(self, query: str, param_map: Optional[Dict[str, Any]] = None) -> Any:
        """Run a Cypher statement and return its rows.

        Args:
            query: Cypher statement text.
            param_map: Query parameters; ``None`` means no parameters.

        Returns:
            A list with one dict (``record.data()``) per result record.
        """
        # A fresh dict per call avoids sharing one mutable default object
        # between invocations.
        parameters = {} if param_map is None else param_map
        with self._driver.session(database=self._database) as session:
            result = session.run(query, parameters)
            return [d.data() for d in result]
258

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.