llama-index

Форк
0
198 строк · 7.4 Кб
1
from typing import Dict, List, Optional, Tuple, Union
2

3
from llama_index.legacy.callbacks.base import CallbackManager
4
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
5
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
6
from llama_index.legacy.core.base_retriever import BaseRetriever
7
from llama_index.legacy.schema import (
8
    BaseNode,
9
    IndexNode,
10
    NodeWithScore,
11
    QueryBundle,
12
    TextNode,
13
)
14
from llama_index.legacy.utils import print_text
15

16
# Default template for the text of the synthetic node built from a query
# engine's answer; filled with ``query_str`` and ``response`` via str.format.
DEFAULT_QUERY_RESPONSE_TMPL = (
    "Query: {query_str}\n"
    "Response: {response}"
)

# What a query id can resolve to in the recursive graph: a retriever to
# recurse into, a query engine whose answer is wrapped into a synthetic node,
# or a node that is returned as-is.
RQN_TYPE = Union[
    BaseRetriever,
    BaseQueryEngine,
    BaseNode,
]


class RecursiveRetriever(BaseRetriever):
    """Recursive retriever.

    This retriever will recursively explore links from nodes to other
    retrievers/query engines.

    For any retrieved nodes, if any of the nodes are IndexNodes,
    then it will explore the linked retriever/query engine, and query that.

    Args:
        root_id (str): The root id of the query graph. Must be a key of
            ``retriever_dict``.
        retriever_dict (Dict[str, BaseRetriever]): A dictionary
            of id to retrievers.
        query_engine_dict (Optional[Dict[str, BaseQueryEngine]]): A dictionary of
            id to query engines. Its keys must not overlap with those of
            ``retriever_dict``.
        node_dict (Optional[Dict[str, BaseNode]]): A dictionary of id to nodes.
            An id present here resolves to the node itself and takes
            precedence over ``retriever_dict`` and ``query_engine_dict``.
        callback_manager (Optional[CallbackManager]): Forwarded to
            ``BaseRetriever.__init__``; its RETRIEVE events wrap every
            sub-retriever call.
        query_response_tmpl (Optional[str]): Template with ``{query_str}`` and
            ``{response}`` placeholders used to turn a query engine response
            into node text. Falls back to ``DEFAULT_QUERY_RESPONSE_TMPL``.
        verbose (bool): Whether to print traversal progress.

    Raises:
        ValueError: If ``root_id`` is not a key of ``retriever_dict``, or if
            ``retriever_dict`` and ``query_engine_dict`` share an id.

    """

    def __init__(
        self,
        root_id: str,
        retriever_dict: Dict[str, BaseRetriever],
        query_engine_dict: Optional[Dict[str, BaseQueryEngine]] = None,
        node_dict: Optional[Dict[str, BaseNode]] = None,
        callback_manager: Optional[CallbackManager] = None,
        query_response_tmpl: Optional[str] = None,
        verbose: bool = False,
    ) -> None:
        """Init params."""
        self._root_id = root_id
        # The traversal always starts by calling ``retrieve`` on the root
        # object, so it has to be a retriever.
        if root_id not in retriever_dict:
            raise ValueError(
                f"Root id {root_id} not in retriever_dict, it must be a retriever."
            )
        self._retriever_dict = retriever_dict
        self._query_engine_dict = query_engine_dict or {}
        self._node_dict = node_dict or {}
        # make sure keys don't overlap
        if set(self._retriever_dict.keys()) & set(self._query_engine_dict.keys()):
            raise ValueError("Retriever and query engine ids must not overlap.")

        self._query_response_tmpl = query_response_tmpl or DEFAULT_QUERY_RESPONSE_TMPL
        super().__init__(callback_manager, verbose=verbose)

    def _query_retrieved_nodes(
        self, query_bundle: QueryBundle, nodes_with_score: List[NodeWithScore]
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Query for retrieved nodes.

        If node is an IndexNode, then recursively query the retriever/query engine.
        If node is a TextNode, then simply return the node.

        Only the first IndexNode for each ``index_id`` is followed. Later
        IndexNodes with the same ``index_id`` are dropped, so each linked
        object is queried at most once per call.

        Returns:
            A ``(nodes_to_add, additional_nodes)`` tuple: the nodes produced by
            this level of the traversal, and the source nodes collected from
            any query engines reached through it.

        """
        nodes_to_add: List[NodeWithScore] = []
        additional_nodes: List[NodeWithScore] = []
        visited_ids = set()

        # dedup index nodes that reference same index id
        new_nodes_with_score = []
        for node_with_score in nodes_with_score:
            node = node_with_score.node
            if isinstance(node, IndexNode):
                if node.index_id not in visited_ids:
                    visited_ids.add(node.index_id)
                    new_nodes_with_score.append(node_with_score)
            else:
                new_nodes_with_score.append(node_with_score)

        nodes_with_score = new_nodes_with_score

        # recursively retrieve
        for node_with_score in nodes_with_score:
            node = node_with_score.node
            if isinstance(node, IndexNode):
                if self._verbose:
                    print_text(
                        "Retrieved node with id, entering: " f"{node.index_id}\n",
                        color="pink",
                    )
                # The retrieval score of the IndexNode is propagated as the
                # similarity of whatever the linked object yields.
                cur_retrieved_nodes, cur_additional_nodes = self._retrieve_rec(
                    query_bundle,
                    query_id=node.index_id,
                    cur_similarity=node_with_score.score,
                )
            else:
                assert isinstance(node, TextNode)
                if self._verbose:
                    print_text(
                        "Retrieving text node: " f"{node.get_content()}\n",
                        color="pink",
                    )
                cur_retrieved_nodes = [node_with_score]
                cur_additional_nodes = []
            nodes_to_add.extend(cur_retrieved_nodes)
            additional_nodes.extend(cur_additional_nodes)

        return nodes_to_add, additional_nodes

    def _get_object(self, query_id: str) -> RQN_TYPE:
        """Fetch the node, retriever or query engine registered under ``query_id``.

        Lookup order is ``node_dict``, then ``retriever_dict``, then
        ``query_engine_dict``.

        Raises:
            ValueError: If ``query_id`` is in none of the three dictionaries.

        """
        node = self._node_dict.get(query_id, None)
        if node is not None:
            return node
        retriever = self._retriever_dict.get(query_id, None)
        if retriever is not None:
            return retriever
        query_engine = self._query_engine_dict.get(query_id, None)
        if query_engine is not None:
            return query_engine
        raise ValueError(
            f"Query id {query_id} not found in any of `node_dict`, "
            "`retriever_dict` or `query_engine_dict`."
        )

    def _retrieve_rec(
        self,
        query_bundle: QueryBundle,
        query_id: Optional[str] = None,
        cur_similarity: Optional[float] = None,
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Query recursively.

        Args:
            query_bundle: The query to run at every level of the graph.
            query_id: Id of the object to query; defaults to the root id.
            cur_similarity: Score assigned to nodes produced directly from a
                node or query engine at this level. Defaults to 1.0 only when
                not given; an explicit 0.0 is kept.

        Returns:
            A ``(nodes_to_add, additional_nodes)`` tuple.

        Raises:
            ValueError: If ``query_id`` cannot be resolved, or resolves to an
                object that is not a node, retriever or query engine.

        """
        query_id = query_id or self._root_id
        # ``or`` would also replace a legitimate score of 0.0, so test for None.
        if cur_similarity is None:
            cur_similarity = 1.0
        # Log after resolving the default, so the root call prints its real id.
        if self._verbose:
            print_text(
                f"Retrieving with query id {query_id}: {query_bundle.query_str}\n",
                color="blue",
            )

        obj = self._get_object(query_id)
        if isinstance(obj, BaseNode):
            nodes_to_add = [NodeWithScore(node=obj, score=cur_similarity)]
            additional_nodes: List[NodeWithScore] = []
        elif isinstance(obj, BaseRetriever):
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as event:
                nodes = obj.retrieve(query_bundle)
                event.on_end(payload={EventPayload.NODES: nodes})

            nodes_to_add, additional_nodes = self._query_retrieved_nodes(
                query_bundle, nodes
            )

        elif isinstance(obj, BaseQueryEngine):
            sub_resp = obj.query(query_bundle)
            if self._verbose:
                print_text(
                    f"Got response: {sub_resp!s}\n",
                    color="green",
                )
            # format with both the query and the response
            node_text = self._query_response_tmpl.format(
                query_str=query_bundle.query_str, response=str(sub_resp)
            )
            node = TextNode(text=node_text)
            nodes_to_add = [NodeWithScore(node=node, score=cur_similarity)]
            # The engine's own sources are surfaced separately via retrieve_all.
            additional_nodes = sub_resp.source_nodes
        else:
            raise ValueError("Must be a retriever or query engine.")

        return nodes_to_add, additional_nodes

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve from the root, discarding query-engine source nodes."""
        retrieved_nodes, _ = self._retrieve_rec(query_bundle, query_id=None)
        return retrieved_nodes

    def retrieve_all(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Retrieve all nodes.

        Unlike default `retrieve` method, this also fetches additional sources.

        Returns:
            A ``(retrieved_nodes, additional_nodes)`` tuple, where
            ``additional_nodes`` are the source nodes of any query engines
            reached during the traversal.

        """
        return self._retrieve_rec(query_bundle, query_id=None)
199

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.