llama-index
198 строк · 7.4 Кб
1from typing import Dict, List, Optional, Tuple, Union2
3from llama_index.legacy.callbacks.base import CallbackManager4from llama_index.legacy.callbacks.schema import CBEventType, EventPayload5from llama_index.legacy.core.base_query_engine import BaseQueryEngine6from llama_index.legacy.core.base_retriever import BaseRetriever7from llama_index.legacy.schema import (8BaseNode,9IndexNode,10NodeWithScore,11QueryBundle,12TextNode,13)
14from llama_index.legacy.utils import print_text15
# Default template for rendering a sub-query-engine answer as text; it is
# filled with ``str.format`` using the ``query_str`` and ``response`` fields.
DEFAULT_QUERY_RESPONSE_TMPL = (
    "Query: {query_str}\n"
    "Response: {response}"
)

# Everything an id in the query graph may resolve to: a retriever,
# a query engine, or a concrete node.
RQN_TYPE = Union[
    BaseRetriever,
    BaseQueryEngine,
    BaseNode,
]

class RecursiveRetriever(BaseRetriever):
    """Recursive retriever.

    This retriever will recursively explore links from nodes to other
    retrievers/query engines.

    For any retrieved nodes, if any of the nodes are IndexNodes,
    then it will explore the linked retriever/query engine, and query that.

    Args:
        root_id (str): The root id of the query graph. Must be a key of
            ``retriever_dict``.
        retriever_dict (Dict[str, BaseRetriever]): A dictionary of id to
            retrievers.
        query_engine_dict (Optional[Dict[str, BaseQueryEngine]]): A dictionary
            of id to query engines. Its ids must not overlap with
            ``retriever_dict``.
        node_dict (Optional[Dict[str, BaseNode]]): A dictionary of id to nodes.
            An id found here takes precedence over the retriever and query
            engine dictionaries.
        callback_manager (Optional[CallbackManager]): Callback manager passed
            to the base retriever.
        query_response_tmpl (Optional[str]): Template used to turn a query
            engine response into a text node; it is formatted with the
            ``query_str`` and ``response`` fields. Defaults to
            ``DEFAULT_QUERY_RESPONSE_TMPL``.
        verbose (bool): Whether to print progress while recursing.

    """

    def __init__(
        self,
        root_id: str,
        retriever_dict: Dict[str, BaseRetriever],
        query_engine_dict: Optional[Dict[str, BaseQueryEngine]] = None,
        node_dict: Optional[Dict[str, BaseNode]] = None,
        callback_manager: Optional[CallbackManager] = None,
        query_response_tmpl: Optional[str] = None,
        verbose: bool = False,
    ) -> None:
        """Init params.

        Raises:
            ValueError: If ``root_id`` is not a key of ``retriever_dict``, or
                if ``retriever_dict`` and ``query_engine_dict`` share an id.
        """
        self._root_id = root_id
        if root_id not in retriever_dict:
            raise ValueError(
                f"Root id {root_id} not in retriever_dict, it must be a retriever."
            )
        self._retriever_dict = retriever_dict
        self._query_engine_dict = query_engine_dict or {}
        self._node_dict = node_dict or {}
        # Ids must resolve unambiguously: `_get_object` checks retrievers
        # before query engines, so a shared id would silently shadow one.
        if self._retriever_dict.keys() & self._query_engine_dict.keys():
            raise ValueError("Retriever and query engine ids must not overlap.")

        self._query_response_tmpl = query_response_tmpl or DEFAULT_QUERY_RESPONSE_TMPL
        super().__init__(callback_manager, verbose=verbose)

    def _query_retrieved_nodes(
        self, query_bundle: QueryBundle, nodes_with_score: List[NodeWithScore]
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Query for retrieved nodes.

        If node is an IndexNode, then recursively query the retriever/query engine.
        If node is a TextNode, then simply return the node.

        Args:
            query_bundle: The query forwarded to any linked objects.
            nodes_with_score: Nodes returned by a retriever.

        Returns:
            A ``(nodes_to_add, additional_nodes)`` tuple: the primary nodes
            gathered from this level and below, and the extra source nodes
            reported by any query engines reached along the way.
        """
        nodes_to_add: List[NodeWithScore] = []
        additional_nodes: List[NodeWithScore] = []

        # Dedup index nodes that reference the same index id, keeping the
        # first occurrence, so each linked object is entered at most once
        # per call. Non-index nodes are always kept, in their original order.
        visited_ids = set()
        deduped_nodes: List[NodeWithScore] = []
        for node_with_score in nodes_with_score:
            node = node_with_score.node
            if isinstance(node, IndexNode):
                if node.index_id in visited_ids:
                    continue
                visited_ids.add(node.index_id)
            deduped_nodes.append(node_with_score)

        # Recursively retrieve.
        for node_with_score in deduped_nodes:
            node = node_with_score.node
            if isinstance(node, IndexNode):
                if self._verbose:
                    print_text(
                        "Retrieved node with id, entering: " f"{node.index_id}\n",
                        color="pink",
                    )
                cur_retrieved_nodes, cur_additional_nodes = self._retrieve_rec(
                    query_bundle,
                    query_id=node.index_id,
                    cur_similarity=node_with_score.score,
                )
            else:
                # Internal invariant: only IndexNode / TextNode (and
                # subclasses) are expected here. Note `assert` is skipped
                # under `python -O`.
                assert isinstance(node, TextNode)
                if self._verbose:
                    print_text(
                        "Retrieving text node: " f"{node.get_content()}\n",
                        color="pink",
                    )
                cur_retrieved_nodes = [node_with_score]
                cur_additional_nodes = []
            nodes_to_add.extend(cur_retrieved_nodes)
            additional_nodes.extend(cur_additional_nodes)

        return nodes_to_add, additional_nodes

    def _get_object(self, query_id: str) -> RQN_TYPE:
        """Fetch the node, retriever or query engine registered under an id.

        Lookup order is ``node_dict``, then ``retriever_dict``, then
        ``query_engine_dict``.

        Raises:
            ValueError: If ``query_id`` is in none of the dictionaries.
        """
        node = self._node_dict.get(query_id, None)
        if node is not None:
            return node
        retriever = self._retriever_dict.get(query_id, None)
        if retriever is not None:
            return retriever
        query_engine = self._query_engine_dict.get(query_id, None)
        if query_engine is not None:
            return query_engine
        raise ValueError(
            f"Query id {query_id} not found in either `retriever_dict` "
            "or `query_engine_dict`."
        )

    def _retrieve_rec(
        self,
        query_bundle: QueryBundle,
        query_id: Optional[str] = None,
        cur_similarity: Optional[float] = None,
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Query recursively.

        Resolve ``query_id`` with ``_get_object`` and dispatch on the type of
        the resolved object:

        * ``BaseNode``: returned directly, scored with ``cur_similarity``.
        * ``BaseRetriever``: retrieved inside a RETRIEVE callback event. Any
          linked ``IndexNode`` results are then explored recursively.
        * ``BaseQueryEngine``: queried. The response is rendered into a
          ``TextNode`` scored with ``cur_similarity``, and the response's
          source nodes are returned as additional nodes.

        Args:
            query_bundle: The query.
            query_id: Id to resolve; ``None`` means the root id.
            cur_similarity: Score assigned to the node produced when
                ``query_id`` resolves to a node or query engine. ``None``
                means ``1.0``; an explicit ``0.0`` is preserved.

        Returns:
            A ``(nodes_to_add, additional_nodes)`` tuple.

        Raises:
            ValueError: If ``query_id`` cannot be resolved, or resolves to an
                unsupported object type.
        """
        # Use `is None` rather than `or`: a similarity of 0.0 (or an
        # empty-string id) is falsy and must not be replaced by the default.
        if query_id is None:
            query_id = self._root_id
        if cur_similarity is None:
            cur_similarity = 1.0
        if self._verbose:
            # Printed after defaulting so the root id is shown instead of None.
            print_text(
                f"Retrieving with query id {query_id}: {query_bundle.query_str}\n",
                color="blue",
            )

        obj = self._get_object(query_id)
        if isinstance(obj, BaseNode):
            nodes_to_add = [NodeWithScore(node=obj, score=cur_similarity)]
            additional_nodes: List[NodeWithScore] = []
        elif isinstance(obj, BaseRetriever):
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as event:
                nodes = obj.retrieve(query_bundle)
                event.on_end(payload={EventPayload.NODES: nodes})

            nodes_to_add, additional_nodes = self._query_retrieved_nodes(
                query_bundle, nodes
            )

        elif isinstance(obj, BaseQueryEngine):
            sub_resp = obj.query(query_bundle)
            if self._verbose:
                print_text(
                    f"Got response: {sub_resp!s}\n",
                    color="green",
                )
            # format with both the query and the response
            node_text = self._query_response_tmpl.format(
                query_str=query_bundle.query_str, response=str(sub_resp)
            )
            node = TextNode(text=node_text)
            nodes_to_add = [NodeWithScore(node=node, score=cur_similarity)]
            # The sub-engine's sources are surfaced separately, not as
            # primary results.
            additional_nodes = sub_resp.source_nodes
        else:
            raise ValueError("Must be a retriever or query engine.")

        return nodes_to_add, additional_nodes

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve from the root id, returning only the primary nodes."""
        retrieved_nodes, _ = self._retrieve_rec(query_bundle, query_id=None)
        return retrieved_nodes

    def retrieve_all(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Retrieve all nodes.

        Unlike default `retrieve` method, this also fetches additional sources.

        Returns:
            A ``(retrieved_nodes, additional_nodes)`` tuple, starting from the
            root id.
        """
        return self._retrieve_rec(query_bundle, query_id=None)