llama-index

Форк
0
142 строки · 5.7 Кб
1
"""Router retriever."""
2

3
import asyncio
4
import logging
5
from typing import List, Optional, Sequence
6

7
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
8
from llama_index.legacy.core.base_retriever import BaseRetriever
9
from llama_index.legacy.core.base_selector import BaseSelector
10
from llama_index.legacy.prompts.mixin import PromptMixinType
11
from llama_index.legacy.schema import IndexNode, NodeWithScore, QueryBundle
12
from llama_index.legacy.selectors.utils import get_selector_from_context
13
from llama_index.legacy.service_context import ServiceContext
14
from llama_index.legacy.tools.retriever_tool import RetrieverTool
15

16
# Module-scoped logger, keyed by this module's import path; the router uses it
# to report which candidate retriever(s) were chosen for a query.
logger: logging.Logger = logging.getLogger(__name__)
17

18

19
class RouterRetriever(BaseRetriever):
    """Router retriever.

    Selects one (or multiple) out of several candidate retrievers to execute a query.

    Args:
        selector (BaseSelector): A selector that chooses one out of many options based
            on each candidate's metadata and query.
        retriever_tools (Sequence[RetrieverTool]): A sequence of candidate
            retrievers. They must be wrapped as tools to expose metadata to
            the selector.
        service_context (Optional[ServiceContext]): A service context. When
            omitted, ``ServiceContext.from_defaults()`` is used; its callback
            manager is handed to the base retriever.
        objects (Optional[List[IndexNode]]): Forwarded unchanged to
            ``BaseRetriever.__init__``.
        object_map (Optional[dict]): Forwarded unchanged to
            ``BaseRetriever.__init__``.
        verbose (bool): Forwarded unchanged to ``BaseRetriever.__init__``.

    """

    def __init__(
        self,
        selector: BaseSelector,
        retriever_tools: Sequence[RetrieverTool],
        service_context: Optional[ServiceContext] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        self.service_context = service_context or ServiceContext.from_defaults()
        self._selector = selector
        # Materialize once: the tools are read twice below, and a one-shot
        # iterable would otherwise be empty on the second pass.
        tools = list(retriever_tools)
        self._retrievers: List[BaseRetriever] = [t.retriever for t in tools]
        # Index-aligned with ``self._retrievers``: the selector is given this
        # list and its answer is used as positions into ``self._retrievers``.
        self._metadatas = [t.metadata for t in tools]

        super().__init__(
            callback_manager=self.service_context.callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"selector": self._selector}

    @classmethod
    def from_defaults(
        cls,
        retriever_tools: Sequence[RetrieverTool],
        service_context: Optional[ServiceContext] = None,
        selector: Optional[BaseSelector] = None,
        select_multi: bool = False,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> "RouterRetriever":
        """Build a router, deriving a selector from the service context if needed.

        Args:
            retriever_tools: Candidate retrievers wrapped as tools.
            service_context: Service context; ``ServiceContext.from_defaults()``
                is used when omitted.
            selector: Explicit selector. When omitted, one is obtained from
                ``get_selector_from_context``.
            select_multi: Passed as ``is_multi`` to ``get_selector_from_context``
                when no explicit selector is given.
            objects, object_map, verbose: Forwarded to the constructor.

        Returns:
            A configured ``RouterRetriever``.
        """
        # Resolve the default context once and share it between the selector
        # and the router, instead of building a second default context in
        # ``__init__``.
        service_context = service_context or ServiceContext.from_defaults()
        selector = selector or get_selector_from_context(
            service_context, is_multi=select_multi
        )

        return cls(
            selector,
            retriever_tools,
            service_context=service_context,
            objects=objects,
            object_map=object_map,
            verbose=verbose,
        )

    def _pick_retriever(self, ind: int) -> BaseRetriever:
        """Return the candidate retriever at position ``ind``.

        Raises:
            ValueError: If ``ind`` is outside ``[0, len(self._retrievers))``.
        """
        # Reject negative indices explicitly: plain list indexing would
        # silently wrap them around to the end of the list.
        if not 0 <= ind < len(self._retrievers):
            raise ValueError(
                f"Failed to select retriever: index {ind} is out of range for "
                f"{len(self._retrievers)} candidate retrievers."
            )
        return self._retrievers[ind]

    def _resolve_selection(self, result) -> List[BaseRetriever]:
        """Map a selector result to the list of retrievers to query.

        ``result`` is the value returned by ``select``/``aselect``; it is read
        through its ``inds``/``reasons`` (multi) or ``ind``/``reason`` (single)
        attributes. All indices are validated before any retriever runs.

        Raises:
            ValueError: If no usable selection can be obtained from ``result``.
        """
        if len(result.inds) > 1:
            selected: List[BaseRetriever] = []
            for i, ind in enumerate(result.inds):
                logger.info("Selecting retriever %s: %s.", ind, result.reasons[i])
                selected.append(self._pick_retriever(ind))
            return selected

        try:
            # Reading ``result.ind`` may itself raise ValueError (presumably
            # when the selector did not make exactly one choice - confirm
            # against the selector result type).
            ind = result.ind
        except ValueError as e:
            raise ValueError("Failed to select retriever") from e

        retriever = self._pick_retriever(ind)
        logger.info("Selecting retriever %s: %s.", ind, result.reason)
        return [retriever]

    @staticmethod
    def _dedupe_by_node_id(nodes: Sequence[NodeWithScore]) -> List[NodeWithScore]:
        """De-duplicate nodes by ``node.node_id``.

        For repeated ids, the later node's value is kept at the position where
        the id first appeared (standard dict insertion-order semantics).
        """
        return list({n.node.node_id: n for n in nodes}.values())

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Select retriever(s) for ``query_bundle`` and return their merged nodes.

        Each selected retriever is queried in selection order; the combined
        nodes are de-duplicated by node id.

        Raises:
            ValueError: If the selector result cannot be resolved to candidate
                retrievers.
        """
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = self._selector.select(self._metadatas, query_bundle)
            retrievers = self._resolve_selection(result)

            collected: List[NodeWithScore] = []
            for retriever in retrievers:
                collected.extend(retriever.retrieve(query_bundle))
            nodes = self._dedupe_by_node_id(collected)

            # Report a concrete list (not a dict view) to callback handlers.
            query_event.on_end(payload={EventPayload.NODES: nodes})

        return nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async counterpart of ``_retrieve``.

        With several selected retrievers, their ``aretrieve`` calls run
        concurrently via ``asyncio.gather``; results are concatenated in
        selection order before de-duplication by node id.

        Raises:
            ValueError: If the selector result cannot be resolved to candidate
                retrievers.
        """
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = await self._selector.aselect(self._metadatas, query_bundle)
            retrievers = self._resolve_selection(result)

            if len(retrievers) == 1:
                # Await directly (as before) rather than wrapping in a task.
                collected = await retrievers[0].aretrieve(query_bundle)
            else:
                batches = await asyncio.gather(
                    *(r.aretrieve(query_bundle) for r in retrievers)
                )
                collected = [node for batch in batches for node in batch]
            nodes = self._dedupe_by_node_id(collected)

            # Report a concrete list (not a dict view) to callback handlers.
            query_event.on_end(payload={EventPayload.NODES: nodes})

        return nodes
143

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.