llama-index

Форк
0
332 строки · 12.2 Кб
1
""" Knowledge Graph Query Engine."""
2

3
import logging
4
from typing import Any, Dict, List, Optional, Sequence
5

6
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
7
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
8
from llama_index.legacy.core.response.schema import RESPONSE_TYPE
9
from llama_index.legacy.graph_stores.registry import (
10
    GRAPH_STORE_CLASS_TO_GRAPH_STORE_TYPE,
11
    GraphStoreType,
12
)
13
from llama_index.legacy.prompts.base import (
14
    BasePromptTemplate,
15
    PromptTemplate,
16
    PromptType,
17
)
18
from llama_index.legacy.prompts.mixin import PromptDictType, PromptMixinType
19
from llama_index.legacy.response_synthesizers import (
20
    BaseSynthesizer,
21
    get_response_synthesizer,
22
)
23
from llama_index.legacy.schema import NodeWithScore, QueryBundle, TextNode
24
from llama_index.legacy.service_context import ServiceContext
25
from llama_index.legacy.storage.storage_context import StorageContext
26
from llama_index.legacy.utils import print_text
27

28
# Module-level logger; the query engine below emits the generated graph query
# and the raw graph store response through it at DEBUG level.
logger = logging.getLogger(__name__)
29

30
# Default prompt asking the LLM to turn a natural-language question into a
# NebulaGraph query. The template has two placeholders:
#   {schema}    - the graph schema the query must restrict itself to
#   {query_str} - the user's natural-language question
# The text also spells out the NebulaGraph dialect differences it wants the
# model to respect (`==` for comparison, label-qualified property access).
DEFAULT_NEBULAGRAPH_NL2CYPHER_PROMPT_TMPL = """
Generate NebulaGraph query from natural language.
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
---
{schema}
---
Note: NebulaGraph speaks a dialect of Cypher, comparing to standard Cypher:

1. it uses double equals sign for comparison: `==` rather than `=`
2. it needs explicit label specification when referring to node properties, i.e.
v is a variable of a node, and we know its label is Foo, v.`foo`.name is correct
while v.name is not.

For example, see this diff between standard and NebulaGraph Cypher dialect:
```diff
< MATCH (p:person)-[:directed]->(m:movie) WHERE m.name = 'The Godfather'
< RETURN p.name;
---
> MATCH (p:`person`)-[:directed]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather'
> RETURN p.`person`.`name`;
```

Question: {query_str}

NebulaGraph Cypher dialect query:
"""
# Template object wrapping the text above, tagged as a text-to-graph-query prompt.
DEFAULT_NEBULAGRAPH_NL2CYPHER_PROMPT = PromptTemplate(
    DEFAULT_NEBULAGRAPH_NL2CYPHER_PROMPT_TMPL,
    prompt_type=PromptType.TEXT_TO_GRAPH_QUERY,
)
63

64
# Default prompt asking the LLM to write a Cypher statement for a Neo4j graph
# store. Placeholders:
#   {schema}    - the graph schema the statement must restrict itself to
#   {query_str} - the user's natural-language question
# The instructions tell the model to output only the Cypher statement.
DEFAULT_NEO4J_NL2CYPHER_PROMPT_TMPL = (
    "Task:Generate Cypher statement to query a graph database.\n"
    "Instructions:\n"
    "Use only the provided relationship types and properties in the schema.\n"
    "Do not use any other relationship types or properties that are not provided.\n"
    "Schema:\n"
    "{schema}\n"
    "Note: Do not include any explanations or apologies in your responses.\n"
    "Do not respond to any questions that might ask anything else than for you "
    "to construct a Cypher statement. \n"
    "Do not include any text except the generated Cypher statement.\n"
    "\n"
    "The question is:\n"
    "{query_str}\n"
)

# Template object wrapping the text above, tagged as a text-to-graph-query prompt.
DEFAULT_NEO4J_NL2CYPHER_PROMPT = PromptTemplate(
    DEFAULT_NEO4J_NL2CYPHER_PROMPT_TMPL,
    prompt_type=PromptType.TEXT_TO_GRAPH_QUERY,
)
85

86
# Default query-synthesis prompt per graph store type. The query engine looks
# up its graph store's type here when the caller does not pass an explicit
# `graph_query_synthesis_prompt`; only the types listed below have a default.
DEFAULT_NL2GRAPH_PROMPT_MAP = {
    GraphStoreType.NEBULA: DEFAULT_NEBULAGRAPH_NL2CYPHER_PROMPT,
    GraphStoreType.NEO4J: DEFAULT_NEO4J_NL2CYPHER_PROMPT,
}
90

91
# Default prompt used to turn a graph query result into context for answering
# the original question. Placeholders:
#   {query_str}       - the user's original natural-language question
#   {kg_query_str}    - the graph store query generated for that question
#   {kg_response_str} - the graph store's response to that query
DEFAULT_KG_RESPONSE_ANSWER_PROMPT_TMPL = """
The original question is given below.
This question has been translated into a Graph Database query.
Both the Graph query and the response are given below.
Given the Graph Query response, synthesise a response to the original question.

Original question: {query_str}
Graph query: {kg_query_str}
Graph response: {kg_response_str}
Response:
"""

# Template object wrapping the text above, tagged as a question-answer prompt.
DEFAULT_KG_RESPONSE_ANSWER_PROMPT = PromptTemplate(
    DEFAULT_KG_RESPONSE_ANSWER_PROMPT_TMPL,
    prompt_type=PromptType.QUESTION_ANSWER,
)
107

108

109
class KnowledgeGraphQueryEngine(BaseQueryEngine):
    """Knowledge graph query engine.

    Answers a natural-language question in three steps: the LLM translates
    the question into a graph store query, the query is run against the graph
    store, and the response synthesizer produces the final answer from the
    graph store's response.

    Args:
        service_context (Optional[ServiceContext]): A service context to use.
            Defaults to ``ServiceContext.from_defaults()``.
        storage_context (Optional[StorageContext]): A storage context to use.
            Required in practice, and must carry a graph store.
        graph_query_synthesis_prompt (Optional[BasePromptTemplate]): Prompt
            used to generate the graph store query. Defaults to the entry of
            ``DEFAULT_NL2GRAPH_PROMPT_MAP`` for the graph store's type.
        graph_response_answer_prompt (Optional[BasePromptTemplate]): Prompt
            used to format the graph response into node text. Defaults to
            ``DEFAULT_KG_RESPONSE_ANSWER_PROMPT``.
        refresh_schema (bool): Whether to refresh the schema.
        verbose (bool): Whether to print intermediate results.
        response_synthesizer (Optional[BaseSynthesizer]):
            A BaseSynthesizer object.
        **kwargs: Additional keyword arguments (accepted, not used here).

    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        storage_context: Optional[StorageContext] = None,
        graph_query_synthesis_prompt: Optional[BasePromptTemplate] = None,
        graph_response_answer_prompt: Optional[BasePromptTemplate] = None,
        refresh_schema: bool = False,
        verbose: bool = False,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        **kwargs: Any,
    ):
        # Ensure that we have a graph store
        assert storage_context is not None, "Must provide a storage context."
        assert (
            storage_context.graph_store is not None
        ), "Must provide a graph store in the storage context."
        self._storage_context = storage_context
        self.graph_store = storage_context.graph_store

        self._service_context = service_context or ServiceContext.from_defaults()

        # Get Graph Store Type; raises KeyError for a graph store class that
        # is not present in the registry mapping.
        self._graph_store_type = GRAPH_STORE_CLASS_TO_GRAPH_STORE_TYPE[
            self.graph_store.__class__
        ]

        # Get Graph schema
        self._graph_schema = self.graph_store.get_schema(refresh=refresh_schema)

        # Get graph store query synthesis prompt; the default map is only
        # consulted when no explicit prompt was supplied.
        self._graph_query_synthesis_prompt = (
            graph_query_synthesis_prompt
            or DEFAULT_NL2GRAPH_PROMPT_MAP[self._graph_store_type]
        )

        self._graph_response_answer_prompt = (
            graph_response_answer_prompt or DEFAULT_KG_RESPONSE_ANSWER_PROMPT
        )
        self._verbose = verbose
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            callback_manager=self._service_context.callback_manager,
            service_context=self._service_context,
        )

        super().__init__(self._service_context.callback_manager)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {
            "graph_query_synthesis_prompt": self._graph_query_synthesis_prompt,
            "graph_response_answer_prompt": self._graph_response_answer_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "graph_query_synthesis_prompt" in prompts:
            self._graph_query_synthesis_prompt = prompts["graph_query_synthesis_prompt"]
        if "graph_response_answer_prompt" in prompts:
            self._graph_response_answer_prompt = prompts["graph_response_answer_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"response_synthesizer": self._response_synthesizer}

    def generate_query(self, query_str: str) -> str:
        """Generate a graph store query from a natural-language query string.

        Args:
            query_str (str): The natural-language question.

        Returns:
            str: The LLM's prediction for the query synthesis prompt, filled
                with ``query_str`` and the graph schema.
        """
        graph_store_query: str = self._service_context.llm.predict(
            self._graph_query_synthesis_prompt,
            query_str=query_str,
            schema=self._graph_schema,
        )

        return graph_store_query

    async def agenerate_query(self, query_str: str) -> str:
        """Asynchronously generate a graph store query from a query string.

        Args:
            query_str (str): The natural-language question.

        Returns:
            str: The LLM's prediction for the query synthesis prompt, filled
                with ``query_str`` and the graph schema.
        """
        graph_store_query: str = await self._service_context.llm.apredict(
            self._graph_query_synthesis_prompt,
            query_str=query_str,
            schema=self._graph_schema,
        )

        return graph_store_query

    def _report(self, label: str, value: Any) -> None:
        """Print an intermediate result when verbose, and log it at DEBUG."""
        if self._verbose:
            print_text(f"{label}:\n{value}\n", color="yellow")
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug("%s:\n%s", label, value)

    def _run_graph_query(self, graph_store_query: str) -> Any:
        """Run a query against the graph store inside a RETRIEVE event.

        Args:
            graph_store_query (str): The generated graph store query.

        Returns:
            Any: Whatever ``self.graph_store.query`` returns.
        """
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: graph_store_query},
        ) as retrieve_event:
            # Get the graph store response
            graph_store_response = self.graph_store.query(query=graph_store_query)
            self._report("Graph Store Response", graph_store_response)

            retrieve_event.on_end(payload={EventPayload.RESPONSE: graph_store_response})

        return graph_store_response

    def _build_nodes(
        self,
        query_bundle: QueryBundle,
        graph_store_query: str,
        graph_store_response: Any,
    ) -> List[NodeWithScore]:
        """Wrap the graph query and its response into a single scored node.

        The node text is the graph response answer prompt filled with the
        question, the generated query and the graph store response; the same
        pieces plus the graph schema are attached as metadata.
        """
        retrieved_graph_context: str = self._graph_response_answer_prompt.format(
            query_str=query_bundle.query_str,
            kg_query_str=graph_store_query,
            kg_response_str=graph_store_response,
        )

        node = NodeWithScore(
            node=TextNode(
                text=retrieved_graph_context,
                metadata={
                    "query_str": query_bundle.query_str,
                    "graph_store_query": graph_store_query,
                    "graph_store_response": graph_store_response,
                    "graph_schema": self._graph_schema,
                },
            ),
            # The score belongs to the NodeWithScore wrapper; it was previously
            # passed to TextNode, which left the wrapper's score unset.
            score=1.0,
        )
        return [node]

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Get nodes for response.

        Generates the graph store query, runs it, and returns a one-element
        list holding the node built from the query and its response.
        """
        graph_store_query = self.generate_query(query_bundle.query_str)
        self._report("Graph Store Query", graph_store_query)

        graph_store_response = self._run_graph_query(graph_store_query)
        return self._build_nodes(query_bundle, graph_store_query, graph_store_response)

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Query the graph store."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes: List[NodeWithScore] = self._retrieve(query_bundle)

            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
            )

            if self._verbose:
                print_text(f"Final Response: {response}\n", color="green")

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Get nodes for response, generating the graph query asynchronously.

        Only the LLM call is awaited; the graph store query itself goes
        through the same synchronous call as ``_retrieve``.
        """
        graph_store_query = await self.agenerate_query(query_bundle.query_str)
        self._report("Graph Store Query", graph_store_query)

        # TBD: This is a blocking call. We need to make it async.
        graph_store_response = self._run_graph_query(graph_store_query)
        return self._build_nodes(query_bundle, graph_store_query, graph_store_response)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Query the graph store."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes = await self._aretrieve(query_bundle)
            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
            )

            if self._verbose:
                print_text(f"Final Response: {response}\n", color="green")

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response
333

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.