llama-index

Форк
0
396 строк · 14.1 Кб
1
from typing import Any, Dict, List, Optional
2

3
from llama_index.legacy.schema import BaseNode, MetadataMode
4
from llama_index.legacy.vector_stores.types import (
5
    VectorStore,
6
    VectorStoreQuery,
7
    VectorStoreQueryResult,
8
)
9
from llama_index.legacy.vector_stores.utils import (
10
    metadata_dict_to_node,
11
    node_to_metadata_dict,
12
)
13

14

15
def check_if_not_null(props: List[str], values: List[Any]) -> None:
    """Ensure every value is truthy, raising on the first one that is not.

    ``props`` and ``values`` are walked in lockstep (``zip`` semantics, so
    iteration stops at the shorter list).

    Args:
        props: Parameter names, used only to build the error message.
        values: The values to validate, positionally matching ``props``.

    Raises:
        ValueError: If a value is falsy (``None``, empty string, ``0``, ...).
    """
    for name, candidate in zip(props, values):
        if candidate:
            continue
        raise ValueError(f"Parameter `{name}` must not be None or empty string")
20

21

22
def sort_by_index_name(
    lst: List[Dict[str, Any]], index_name: str
) -> List[Dict[str, Any]]:
    """Return a new list with the entries named ``index_name`` moved first.

    The sort is stable, so the relative order of the remaining entries (and of
    multiple matching entries) is preserved. The input list is not modified.

    An entry matches when its ``"name"`` value equals ``index_name``; that is
    the key present in the rows this file passes in (its ``SHOW INDEXES``
    queries end with ``RETURN name, ...``). The ``"index_name"`` key is also
    accepted so that dictionaries using that key keep sorting as before.

    Args:
        lst: Index description dictionaries.
        index_name: Name of the index that should come first.

    Returns:
        A new, reordered list containing the same dictionaries.
    """

    def is_other_index(record: Dict[str, Any]) -> bool:
        # False sorts before True, so matching records float to the front.
        return index_name not in (record.get("name"), record.get("index_name"))

    return sorted(lst, key=is_other_index)
27

28

29
def clean_params(params: List[BaseNode]) -> List[Dict[str, Any]]:
    """Convert BaseNode objects into dictionaries to be imported into Neo4j.

    Args:
        params: Nodes to convert.

    Returns:
        One dictionary per node with the keys ``"text"``, ``"embedding"``,
        ``"id"`` and ``"metadata"``.
    """
    rows: List[Dict[str, Any]] = []
    for record in params:
        text = record.get_content(metadata_mode=MetadataMode.NONE)
        embedding = record.get_embedding()
        node_id = record.node_id
        metadata = node_to_metadata_dict(record, remove_text=True, flat_metadata=False)
        # Remove redundant metadata information. ``pop`` with a default keeps
        # this from raising KeyError should either key be absent.
        for redundant_key in ("document_id", "doc_id"):
            metadata.pop(redundant_key, None)
        rows.append(
            {
                "text": text,
                "embedding": embedding,
                "id": node_id,
                "metadata": metadata,
            }
        )
    return rows
44

45

46
def _get_search_index_query(hybrid: bool) -> str:
47
    if not hybrid:
48
        return (
49
            "CALL db.index.vector.queryNodes($index, $k, $embedding) YIELD node, score "
50
        )
51
    return (
52
        "CALL { "
53
        "CALL db.index.vector.queryNodes($index, $k, $embedding) "
54
        "YIELD node, score "
55
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
56
        "UNWIND nodes AS n "
57
        # We use 0 as min
58
        "RETURN n.node AS node, (n.score / max) AS score UNION "
59
        "CALL db.index.fulltext.queryNodes($keyword_index, $query, {limit: $k}) "
60
        "YIELD node, score "
61
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
62
        "UNWIND nodes AS n "
63
        # We use 0 as min
64
        "RETURN n.node AS node, (n.score / max) AS score "
65
        "} "
66
        # dedup
67
        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k "
68
    )
69

70

71
def remove_lucene_chars(text: Optional[str]) -> Optional[str]:
    """Replace Lucene special characters with spaces.

    Args:
        text: The raw query text, or ``None``.

    Returns:
        ``None`` when ``text`` is ``None`` or empty; otherwise the text with
        each special character replaced by a single space and leading/trailing
        whitespace stripped (which may leave an empty string).
    """
    if not text:
        return None
    lucene_specials = '+-&|!(){}[]^"~*?:\\'
    # One translation pass maps every special character to a space.
    to_space = str.maketrans(dict.fromkeys(lucene_specials, " "))
    return text.translate(to_space).strip()
99

100

101
class Neo4jVectorStore(VectorStore):
    """Vector store that keeps nodes and their embeddings in Neo4j.

    On construction the store connects to the database, checks that the
    server version is at least 5.11, and reuses or creates the vector index
    (and, when ``hybrid_search`` is enabled, the full-text index).
    """

    stores_text: bool = True
    flat_metadata = True

    def __init__(
        self,
        username: str,
        password: str,
        url: str,
        embedding_dimension: int,
        database: str = "neo4j",
        index_name: str = "vector",
        keyword_index_name: str = "keyword",
        node_label: str = "Chunk",
        embedding_node_property: str = "embedding",
        text_node_property: str = "text",
        distance_strategy: str = "cosine",
        hybrid_search: bool = False,
        retrieval_query: str = "",
        **kwargs: Any,
    ) -> None:
        """Connect to Neo4j and make sure the required indexes exist.

        Args:
            username: Neo4j user name.
            password: Neo4j password.
            url: Connection URL handed to the Neo4j driver.
            embedding_dimension: Dimension used when a new vector index is
                created; replaced by the existing index's dimension when one
                is found.
            database: Name of the database that sessions are opened against.
            index_name: Name of the vector index.
            keyword_index_name: Name of the full-text index (hybrid search).
            node_label: Label of the nodes that hold the chunks.
            embedding_node_property: Node property that stores the embedding.
            text_node_property: Node property that stores the text.
            distance_strategy: Either ``"cosine"`` or ``"euclidean"``.
            hybrid_search: Whether queries also use the full-text index.
            retrieval_query: Optional Cypher ``RETURN`` clause appended to the
                search query in place of the default one.
            **kwargs: Accepted for compatibility; not used here.

        Raises:
            ImportError: If the ``neo4j`` package is not installed.
            ValueError: On an unsupported ``distance_strategy``, a failed
                connection or authentication, an unsupported server version,
                an empty required parameter, or when the existing full-text
                index targets a different node label than the vector index.
        """
        try:
            import neo4j
        except ImportError as err:
            raise ImportError(
                "Could not import neo4j python package. "
                "Please install it with `pip install neo4j`."
            ) from err
        if distance_strategy not in ["cosine", "euclidean"]:
            raise ValueError("distance_strategy must be either 'euclidean' or 'cosine'")

        self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
        self._database = database

        # Verify connection
        try:
            self._driver.verify_connectivity()
        except neo4j.exceptions.ServiceUnavailable as err:
            raise ValueError(
                "Could not connect to Neo4j database. "
                "Please ensure that the url is correct"
            ) from err
        except neo4j.exceptions.AuthError as err:
            raise ValueError(
                "Could not connect to Neo4j database. "
                "Please ensure that the username and password are correct"
            ) from err

        # Verify if the version support vector index
        self._verify_version()

        # Verify that required values are not null
        check_if_not_null(
            [
                "index_name",
                "node_label",
                "embedding_node_property",
                "text_node_property",
            ],
            [index_name, node_label, embedding_node_property, text_node_property],
        )

        self.distance_strategy = distance_strategy
        self.index_name = index_name
        self.keyword_index_name = keyword_index_name
        self.hybrid_search = hybrid_search
        self.node_label = node_label
        self.embedding_node_property = embedding_node_property
        self.text_node_property = text_node_property
        self.retrieval_query = retrieval_query
        self.embedding_dimension = embedding_dimension

        index_already_exists = self.retrieve_existing_index()
        if not index_already_exists:
            self.create_new_index()
        if hybrid_search:
            fts_node_label = self.retrieve_existing_fts_index()
            if not fts_node_label:
                # The FTS index doesn't exist yet
                self.create_new_keyword_index()
            elif fts_node_label != self.node_label:
                # FTS and vector index must describe the same nodes
                raise ValueError(
                    "Vector and keyword index don't index the same node label"
                )

    def _verify_version(self) -> None:
        """
        Check if the connected Neo4j database version supports vector indexing.

        Queries the Neo4j database to retrieve its version and compares it
        against a target version (5.11.0) that is known to support vector
        indexing.

        Raises:
            ValueError: If the connected Neo4j version is older than 5.11.0.
        """
        version = self.database_query("CALL dbms.components()")[0]["versions"][0]
        # Versions may carry a "-suffix" and omit the patch number (the
        # original code handled this for strings containing "aura"): keep only
        # the dotted numeric part and pad it to (major, minor, patch).
        numeric_part = version.split("-")[0]
        version_tuple = tuple(map(int, numeric_part.split(".")))
        version_tuple += (0,) * (3 - len(version_tuple))

        target_version = (5, 11, 0)

        if version_tuple < target_version:
            raise ValueError(
                "Vector index is only supported in Neo4j version 5.11 or greater"
            )

    def create_new_index(self) -> None:
        """
        This method constructs a Cypher query and executes it
        to create a new vector index in Neo4j.

        The index name, node label, embedding property, dimension and
        similarity metric are taken from the instance attributes.
        """
        index_query = (
            "CALL db.index.vector.createNodeIndex("
            "$index_name,"
            "$node_label,"
            "$embedding_node_property,"
            "toInteger($embedding_dimension),"
            "$similarity_metric )"
        )

        parameters = {
            "index_name": self.index_name,
            "node_label": self.node_label,
            "embedding_node_property": self.embedding_node_property,
            "embedding_dimension": self.embedding_dimension,
            "similarity_metric": self.distance_strategy,
        }
        self.database_query(index_query, params=parameters)

    def retrieve_existing_index(self) -> bool:
        """
        Check if the vector index exists in the Neo4j database
        and adopt its configuration.

        This method queries the Neo4j database for vector indexes that either
        have the configured name or cover the configured node label and
        embedding property. When one is found, ``index_name``, ``node_label``,
        ``embedding_node_property`` and ``embedding_dimension`` are overwritten
        with the values reported for that index.

        Returns:
            bool: True if an existing index was found and adopted,
            False otherwise (the instance is then left unchanged).
        """
        index_information = self.database_query(
            "SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options "
            "WHERE type = 'VECTOR' AND (name = $index_name "
            "OR (labelsOrTypes[0] = $node_label AND "
            "properties[0] = $embedding_node_property)) "
            "RETURN name, labelsOrTypes, properties, options ",
            params={
                "index_name": self.index_name,
                "node_label": self.node_label,
                "embedding_node_property": self.embedding_node_property,
            },
        )
        # sort by index_name
        index_information = sort_by_index_name(index_information, self.index_name)
        try:
            # Read everything first so a malformed row cannot leave the
            # instance half-updated.
            existing = index_information[0]
            name = existing["name"]
            label = existing["labelsOrTypes"][0]
            embedding_property = existing["properties"][0]
            dimension = existing["options"]["indexConfig"]["vector.dimensions"]
        except IndexError:
            return False

        self.index_name = name
        self.node_label = label
        self.embedding_node_property = embedding_property
        self.embedding_dimension = dimension
        return True

    def retrieve_existing_fts_index(self) -> Optional[str]:
        """Check if the fulltext index exists in the Neo4j database.

        This method queries the Neo4j database for full-text indexes that
        either have the configured keyword index name or cover exactly the
        configured node label and text property. When one is found,
        ``keyword_index_name`` and ``text_node_property`` are overwritten with
        the values reported for that index.

        Returns:
            Optional[str]: The first node label of the existing index, or
            None if no matching index was found.
        """
        index_information = self.database_query(
            "SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options "
            "WHERE type = 'FULLTEXT' AND (name = $keyword_index_name "
            "OR (labelsOrTypes = [$node_label] AND "
            # `properties` is a list column, so it is compared with a list.
            "properties = [$text_node_property])) "
            "RETURN name, labelsOrTypes, properties, options ",
            params={
                "keyword_index_name": self.keyword_index_name,
                "node_label": self.node_label,
                "text_node_property": self.text_node_property,
            },
        )
        # sort by the keyword index name (not the vector index name)
        index_information = sort_by_index_name(
            index_information, self.keyword_index_name
        )
        try:
            existing = index_information[0]
            name = existing["name"]
            text_property = existing["properties"][0]
            label = existing["labelsOrTypes"][0]
        except IndexError:
            return None

        self.keyword_index_name = name
        self.text_node_property = text_property
        return label

    def create_new_keyword_index(
        self, text_node_properties: Optional[List[str]] = None
    ) -> None:
        """
        This method constructs a Cypher query and executes it
        to create a new full text index in Neo4j.

        Args:
            text_node_properties: Node properties to index. When omitted or
                empty, only ``self.text_node_property`` is indexed.
        """
        node_props = text_node_properties or [self.text_node_property]
        indexed_properties = ", ".join(f"n.`{prop}`" for prop in node_props)
        fts_index_query = (
            f"CREATE FULLTEXT INDEX {self.keyword_index_name} "
            f"FOR (n:`{self.node_label}`) ON EACH "
            f"[{indexed_properties}]"
        )
        self.database_query(fts_index_query)

    def database_query(
        self, query: str, params: Optional[dict] = None
    ) -> List[Dict[str, Any]]:
        """
        This method sends a Cypher query to the connected Neo4j database
        and returns the results as a list of dictionaries.

        Args:
            query (str): The Cypher query to execute.
            params (dict, optional): Dictionary of query parameters.
                Defaults to None, which is treated as an empty dictionary.

        Returns:
            List[Dict[str, Any]]: List of dictionaries containing the query results.

        Raises:
            ValueError: If Neo4j reports a Cypher syntax error.
        """
        from neo4j.exceptions import CypherSyntaxError

        params = params or {}
        with self._driver.session(database=self._database) as session:
            try:
                data = session.run(query, params)
                return [r.data() for r in data]
            except CypherSyntaxError as e:
                raise ValueError(f"Cypher Statement is not valid\n{e}") from e

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Insert or update the given nodes in Neo4j.

        Each node is merged on its ``id`` under ``self.node_label``; its
        embedding, text and metadata are then set on the merged graph node.

        Args:
            nodes: Nodes to store.
            **add_kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The ``node_id`` of every input node, in input order.
        """
        ids = [r.node_id for r in nodes]
        import_query = (
            "UNWIND $data AS row "
            "CALL { WITH row "
            f"MERGE (c:`{self.node_label}` {{id: row.id}}) "
            "WITH c, row "
            f"CALL db.create.setVectorProperty(c, "
            f"'{self.embedding_node_property}', row.embedding) "
            "YIELD node "
            f"SET c.`{self.text_node_property}` = row.text "
            "SET c += row.metadata } IN TRANSACTIONS OF 1000 ROWS"
        )

        self.database_query(
            import_query,
            params={"data": clean_params(nodes)},
        )

        return ids

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Run a similarity (or hybrid) search and return the matching nodes.

        The search prefix comes from ``_get_search_index_query`` and is
        followed by ``self.retrieval_query`` if set, otherwise by a default
        ``RETURN`` clause producing ``text``, ``score``, ``id`` and
        ``metadata`` columns.

        Args:
            query: Supplies ``similarity_top_k``, ``query_embedding`` and
                ``query_str`` (the latter is stripped of Lucene special
                characters before being sent).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A result holding the rebuilt nodes, their scores and their ids,
            in the order returned by the database.
        """
        default_retrieval = (
            f"RETURN node.`{self.text_node_property}` AS text, score, "
            "node.id AS id, "
            f"node {{.*, `{self.text_node_property}`: Null, "
            f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
        )

        retrieval_query = self.retrieval_query or default_retrieval
        read_query = _get_search_index_query(self.hybrid_search) + retrieval_query

        parameters = {
            "index": self.index_name,
            "k": query.similarity_top_k,
            "embedding": query.query_embedding,
            "keyword_index": self.keyword_index_name,
            "query": remove_lucene_chars(query.query_str),
        }

        results = self.database_query(read_query, params=parameters)

        nodes = []
        similarities = []
        ids = []
        for record in results:
            node = metadata_dict_to_node(record["metadata"])
            node.set_content(str(record["text"]))
            nodes.append(node)
            similarities.append(record["score"])
            ids.append(record["id"])

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete every node whose ``ref_doc_id`` property equals ``ref_doc_id``.

        Only nodes labelled ``self.node_label`` are matched; they are removed
        together with their relationships (``DETACH DELETE``).

        Args:
            ref_doc_id: Reference document id to match.
            **delete_kwargs: Accepted for interface compatibility; not used here.
        """
        self.database_query(
            f"MATCH (n:`{self.node_label}`) WHERE n.ref_doc_id = $id DETACH DELETE n",
            params={"id": ref_doc_id},
        )
397

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.