llama-index
142 строки · 5.7 Кб
1"""Router retriever."""
2
3import asyncio
4import logging
5from typing import List, Optional, Sequence
6
7from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
8from llama_index.legacy.core.base_retriever import BaseRetriever
9from llama_index.legacy.core.base_selector import BaseSelector
10from llama_index.legacy.prompts.mixin import PromptMixinType
11from llama_index.legacy.schema import IndexNode, NodeWithScore, QueryBundle
12from llama_index.legacy.selectors.utils import get_selector_from_context
13from llama_index.legacy.service_context import ServiceContext
14from llama_index.legacy.tools.retriever_tool import RetrieverTool
15
# Module-level logger; RouterRetriever uses it to report which candidate
# retriever(s) the selector chose and the selector's stated reason.
logger = logging.getLogger(__name__)
17
18
class RouterRetriever(BaseRetriever):
    """Router retriever.

    Selects one (or multiple) out of several candidate retrievers to execute a query.

    When the selector picks more than one retriever, every selected retriever is
    queried and the results are merged and de-duplicated by node id. Each node id
    keeps the position of its first occurrence, and the last ``NodeWithScore``
    seen for that id is the one returned.

    Args:
        selector (BaseSelector): A selector that chooses one out of many options based
            on each candidate's metadata and query.
        retriever_tools (Sequence[RetrieverTool]): A sequence of candidate
            retrievers. They must be wrapped as tools to expose metadata to
            the selector.
        service_context (Optional[ServiceContext]): A service context. When omitted,
            ``ServiceContext.from_defaults()`` is used. Its callback manager becomes
            this retriever's callback manager.
        objects (Optional[List[IndexNode]]): Forwarded to ``BaseRetriever``.
        object_map (Optional[dict]): Forwarded to ``BaseRetriever``.
        verbose (bool): Forwarded to ``BaseRetriever``.

    """

    def __init__(
        self,
        selector: BaseSelector,
        retriever_tools: Sequence[RetrieverTool],
        service_context: Optional[ServiceContext] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        self.service_context = service_context or ServiceContext.from_defaults()
        self._selector = selector
        # Parallel lists: the selector chooses by position in ``_metadatas``, and
        # that same position indexes ``_retrievers``.
        self._retrievers: List[BaseRetriever] = [x.retriever for x in retriever_tools]
        self._metadatas = [x.metadata for x in retriever_tools]

        super().__init__(
            callback_manager=self.service_context.callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"selector": self._selector}

    @classmethod
    def from_defaults(
        cls,
        retriever_tools: Sequence[RetrieverTool],
        service_context: Optional[ServiceContext] = None,
        selector: Optional[BaseSelector] = None,
        select_multi: bool = False,
    ) -> "RouterRetriever":
        """Build a router retriever, creating a selector when none is supplied.

        Args:
            retriever_tools: Candidate retrievers wrapped as tools.
            service_context: Service context shared by the default selector and
                the retriever. Defaults to ``ServiceContext.from_defaults()``.
            selector: Explicit selector. If omitted, one is derived from the
                service context via ``get_selector_from_context``.
            select_multi: Passed as ``is_multi`` when deriving the default
                selector. Ignored when ``selector`` is given.

        Returns:
            A configured ``RouterRetriever``.
        """
        # Resolve the default context once, so the derived selector and the
        # retriever's callback manager come from the same ServiceContext.
        # Previously a second, independent default was built in __init__.
        service_context = service_context or ServiceContext.from_defaults()
        selector = selector or get_selector_from_context(
            service_context, is_multi=select_multi
        )

        return cls(
            selector,
            retriever_tools,
            service_context=service_context,
        )

    def _get_retriever(self, ind: int) -> BaseRetriever:
        """Return the candidate retriever at position ``ind``.

        Raises:
            ValueError: If ``ind`` does not index one of the candidate retrievers.
        """
        # Reject negative indices explicitly: plain list indexing would wrap them
        # around and silently route the query to the wrong retriever.
        if not 0 <= ind < len(self._retrievers):
            raise ValueError(
                f"Failed to select retriever: index {ind} is out of range for "
                f"{len(self._retrievers)} candidate retrievers."
            )
        return self._retrievers[ind]

    def _resolve_selection(self, result) -> List[BaseRetriever]:
        """Map a selector result to the retriever(s) it chose, logging each choice.

        Every index is validated before any retrieval starts, so an invalid
        selection fails fast instead of after some retrievers have already run.

        Raises:
            ValueError: If the selector made no usable selection or chose an
                index outside the candidate list.
        """
        if len(result.inds) > 1:
            selected = []
            for i, ind in enumerate(result.inds):
                logger.info("Selecting retriever %s: %s.", ind, result.reasons[i])
                selected.append(self._get_retriever(ind))
            return selected

        try:
            # ``result.ind`` / ``result.reason`` raise ValueError when the
            # selector did not produce exactly one selection.
            ind = result.ind
            reason = result.reason
        except ValueError as e:
            raise ValueError("Failed to select retriever") from e
        logger.info("Selecting retriever %s: %s.", ind, reason)
        return [self._get_retriever(ind)]

    @staticmethod
    def _dedupe_by_node_id(nodes: Sequence[NodeWithScore]) -> List[NodeWithScore]:
        """De-duplicate retrieved nodes by node id.

        Each id keeps the position of its first occurrence, and the last
        ``NodeWithScore`` seen for that id wins (plain dict semantics).
        """
        return list({n.node.node_id: n for n in nodes}.values())

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Select retriever(s) for the query and return their merged results."""
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = self._selector.select(self._metadatas, query_bundle)
            retrievers = self._resolve_selection(result)

            nodes = [
                node
                for retriever in retrievers
                for node in retriever.retrieve(query_bundle)
            ]
            retrieved_nodes = self._dedupe_by_node_id(nodes)

            # Hand callbacks a concrete list rather than a live dict view.
            query_event.on_end(payload={EventPayload.NODES: retrieved_nodes})

        return retrieved_nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of ``_retrieve``; multiple selections run concurrently."""
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = await self._selector.aselect(self._metadatas, query_bundle)
            retrievers = self._resolve_selection(result)

            if len(retrievers) == 1:
                # Await a single retriever directly (no extra Task), as before.
                per_retriever = [await retrievers[0].aretrieve(query_bundle)]
            else:
                # Coroutines are created only after every index was validated,
                # so a bad selection cannot leave un-awaited coroutines behind.
                per_retriever = await asyncio.gather(
                    *(retriever.aretrieve(query_bundle) for retriever in retrievers)
                )

            nodes = [node for sublist in per_retriever for node in sublist]
            retrieved_nodes = self._dedupe_by_node_id(nodes)

            # Hand callbacks a concrete list rather than a live dict view.
            query_event.on_end(payload={EventPayload.NODES: retrieved_nodes})

        return retrieved_nodes
143