llama-index

Форк
0
123 строки · 4.5 Кб
1
from typing import Any, Dict, List, Optional, Tuple
2

3
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
4
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
5
from llama_index.legacy.core.response.schema import RESPONSE_TYPE
6
from llama_index.legacy.indices.composability.graph import ComposableGraph
7
from llama_index.legacy.schema import IndexNode, NodeWithScore, QueryBundle, TextNode
8

9

10
class ComposableGraphQueryEngine(BaseQueryEngine):
11
    """Composable graph query engine.
12

13
    This query engine can operate over a ComposableGraph.
14
    It can take in custom query engines for its sub-indices.
15

16
    Args:
17
        graph (ComposableGraph): A ComposableGraph object.
18
        custom_query_engines (Optional[Dict[str, BaseQueryEngine]]): A dictionary of
19
            custom query engines.
20
        recursive (bool): Whether to recursively query the graph.
21
        **kwargs: additional arguments to be passed to the underlying index query
22
            engine.
23

24
    """
25

26
    def __init__(
27
        self,
28
        graph: ComposableGraph,
29
        custom_query_engines: Optional[Dict[str, BaseQueryEngine]] = None,
30
        recursive: bool = True,
31
        **kwargs: Any
32
    ) -> None:
33
        """Init params."""
34
        self._graph = graph
35
        self._custom_query_engines = custom_query_engines or {}
36
        self._kwargs = kwargs
37

38
        # additional configs
39
        self._recursive = recursive
40
        callback_manager = self._graph.service_context.callback_manager
41
        super().__init__(callback_manager)
42

43
    def _get_prompt_modules(self) -> Dict[str, Any]:
44
        """Get prompt modules."""
45
        return {}
46

47
    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
48
        return self._query_index(query_bundle, index_id=None, level=0)
49

50
    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
51
        return self._query_index(query_bundle, index_id=None, level=0)
52

53
    def _query_index(
54
        self,
55
        query_bundle: QueryBundle,
56
        index_id: Optional[str] = None,
57
        level: int = 0,
58
    ) -> RESPONSE_TYPE:
59
        """Query a single index."""
60
        index_id = index_id or self._graph.root_id
61

62
        with self.callback_manager.event(
63
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
64
        ) as query_event:
65
            # get query engine
66
            if index_id in self._custom_query_engines:
67
                query_engine = self._custom_query_engines[index_id]
68
            else:
69
                query_engine = self._graph.get_index(index_id).as_query_engine(
70
                    **self._kwargs
71
                )
72

73
            with self.callback_manager.event(
74
                CBEventType.RETRIEVE,
75
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
76
            ) as retrieve_event:
77
                nodes = query_engine.retrieve(query_bundle)
78
                retrieve_event.on_end(payload={EventPayload.NODES: nodes})
79

80
            if self._recursive:
81
                # do recursion here
82
                nodes_for_synthesis = []
83
                additional_source_nodes = []
84
                for node_with_score in nodes:
85
                    node_with_score, source_nodes = self._fetch_recursive_nodes(
86
                        node_with_score, query_bundle, level
87
                    )
88
                    nodes_for_synthesis.append(node_with_score)
89
                    additional_source_nodes.extend(source_nodes)
90
                response = query_engine.synthesize(
91
                    query_bundle, nodes_for_synthesis, additional_source_nodes
92
                )
93
            else:
94
                response = query_engine.synthesize(query_bundle, nodes)
95

96
            query_event.on_end(payload={EventPayload.RESPONSE: response})
97

98
        return response
99

100
    def _fetch_recursive_nodes(
101
        self,
102
        node_with_score: NodeWithScore,
103
        query_bundle: QueryBundle,
104
        level: int,
105
    ) -> Tuple[NodeWithScore, List[NodeWithScore]]:
106
        """Fetch nodes.
107

108
        Uses existing node if it's not an index node.
109
        Otherwise fetch response from corresponding index.
110

111
        """
112
        if isinstance(node_with_score.node, IndexNode):
113
            index_node = node_with_score.node
114
            # recursive call
115
            response = self._query_index(query_bundle, index_node.index_id, level + 1)
116

117
            new_node = TextNode(text=str(response))
118
            new_node_with_score = NodeWithScore(
119
                node=new_node, score=node_with_score.score
120
            )
121
            return new_node_with_score, response.source_nodes
122
        else:
123
            return node_with_score, []
124

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.