# llama-index
#
# Форк
# 0
# 385 строк · 15.0 Кб
1
import asyncio
import logging
from typing import Any, Callable, List, Optional, Sequence, Tuple

from llama_index.legacy.async_utils import run_async_tasks
from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
from llama_index.legacy.core.base_retriever import BaseRetriever
from llama_index.legacy.core.base_selector import BaseSelector
from llama_index.legacy.core.response.schema import (
    RESPONSE_TYPE,
    PydanticResponse,
    Response,
    StreamingResponse,
)
from llama_index.legacy.objects.base import ObjectRetriever
from llama_index.legacy.prompts.default_prompt_selectors import (
    DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
)
from llama_index.legacy.prompts.mixin import PromptMixinType
from llama_index.legacy.response_synthesizers import TreeSummarize
from llama_index.legacy.schema import BaseNode, NodeWithScore, QueryBundle
from llama_index.legacy.selectors.utils import get_selector_from_context
from llama_index.legacy.service_context import ServiceContext
from llama_index.legacy.tools.query_engine import QueryEngineTool
from llama_index.legacy.tools.types import ToolMetadata
from llama_index.legacy.utils import print_text
29

30
# Module-level logger; used by the response-combining helpers and by
# RouterQueryEngine to record which query engine(s) were selected.
logger = logging.getLogger(__name__)
31

32

33
def _split_sub_responses(
    responses: List[RESPONSE_TYPE],
) -> Tuple[List[str], List[NodeWithScore]]:
    """Collect the text and the source nodes of each sub-engine response.

    Streaming and Pydantic responses are first converted with
    ``get_response()`` so that their ``source_nodes`` can be read. The text is
    always taken from ``str(response)`` on the original response object.

    Args:
        responses: Responses returned by the individual sub-engines.

    Returns:
        ``(response_strs, source_nodes)``: one string per response, in input
        order, and the concatenated source nodes of all responses.
    """
    response_strs: List[str] = []
    source_nodes: List[NodeWithScore] = []
    for response in responses:
        if isinstance(response, (StreamingResponse, PydanticResponse)):
            response_obj = response.get_response()
        else:
            response_obj = response
        source_nodes.extend(response_obj.source_nodes)
        response_strs.append(str(response))
    return response_strs, source_nodes


def _wrap_summary(summary: Any, source_nodes: List[NodeWithScore]) -> RESPONSE_TYPE:
    """Wrap a summarizer result in the response class matching its type.

    A ``str`` becomes a ``Response`` and a pydantic ``BaseModel`` becomes a
    ``PydanticResponse``. Anything else is passed as the ``response_gen`` of a
    ``StreamingResponse``; presumably this is a token generator when the
    summarizer streams, but it is not checked here.
    """
    if isinstance(summary, str):
        return Response(response=summary, source_nodes=source_nodes)
    if isinstance(summary, BaseModel):
        return PydanticResponse(response=summary, source_nodes=source_nodes)
    return StreamingResponse(response_gen=summary, source_nodes=source_nodes)


def combine_responses(
    summarizer: TreeSummarize, responses: List[RESPONSE_TYPE], query_bundle: QueryBundle
) -> RESPONSE_TYPE:
    """Combine multiple responses from sub-engines into a single response.

    The text of every sub-response is passed to ``summarizer.get_response``
    together with the query string. The source nodes of all sub-responses are
    attached to the combined result.

    Args:
        summarizer: Summarizer used to merge the sub-response texts.
        responses: Responses returned by the individual sub-engines.
        query_bundle: The original query; only ``query_str`` is used.

    Returns:
        A ``Response``, ``PydanticResponse`` or ``StreamingResponse``,
        depending on what the summarizer returned.
    """
    logger.info("Combining responses from multiple query engines.")
    response_strs, source_nodes = _split_sub_responses(responses)
    summary = summarizer.get_response(query_bundle.query_str, response_strs)
    return _wrap_summary(summary, source_nodes)


async def acombine_responses(
    summarizer: TreeSummarize, responses: List[RESPONSE_TYPE], query_bundle: QueryBundle
) -> RESPONSE_TYPE:
    """Async variant of ``combine_responses``.

    Identical to ``combine_responses`` except that the summarizer is invoked
    through ``await summarizer.aget_response(...)``.

    Args:
        summarizer: Summarizer used to merge the sub-response texts.
        responses: Responses returned by the individual sub-engines.
        query_bundle: The original query; only ``query_str`` is used.

    Returns:
        A ``Response``, ``PydanticResponse`` or ``StreamingResponse``,
        depending on what the summarizer returned.
    """
    logger.info("Combining responses from multiple query engines.")
    response_strs, source_nodes = _split_sub_responses(responses)
    summary = await summarizer.aget_response(query_bundle.query_str, response_strs)
    return _wrap_summary(summary, source_nodes)
83

84

85
class RouterQueryEngine(BaseQueryEngine):
    """Router query engine.

    Selects one out of several candidate query engines to execute a query.
    If the selector picks more than one engine, every selected engine is
    queried and the sub-responses are merged with the summarizer.

    Args:
        selector (BaseSelector): A selector that chooses one out of many options based
            on each candidate's metadata and query.
        query_engine_tools (Sequence[QueryEngineTool]): A sequence of candidate
            query engines. They must be wrapped as tools to expose metadata to
            the selector.
        service_context (Optional[ServiceContext]): A service context. Defaults to
            ``ServiceContext.from_defaults()``.
        summarizer (Optional[TreeSummarize]): Tree summarizer to summarize sub-results.
            Defaults to a ``TreeSummarize`` built from the service context.
        verbose (bool): If True, each routing decision is also printed via
            ``print_text`` in addition to being logged.

    """

    def __init__(
        self,
        selector: BaseSelector,
        query_engine_tools: Sequence[QueryEngineTool],
        service_context: Optional[ServiceContext] = None,
        summarizer: Optional[TreeSummarize] = None,
        verbose: bool = False,
    ) -> None:
        self.service_context = service_context or ServiceContext.from_defaults()
        self._selector = selector
        # Parallel lists. The selector is shown ``self._metadatas``, and the
        # indices it returns are used to index ``self._query_engines``.
        self._query_engines = [x.query_engine for x in query_engine_tools]
        self._metadatas = [x.metadata for x in query_engine_tools]
        self._summarizer = summarizer or TreeSummarize(
            service_context=self.service_context,
            summary_template=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
        )
        self._verbose = verbose

        super().__init__(self.service_context.callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"summarizer": self._summarizer, "selector": self._selector}

    @classmethod
    def from_defaults(
        cls,
        query_engine_tools: Sequence[QueryEngineTool],
        service_context: Optional[ServiceContext] = None,
        selector: Optional[BaseSelector] = None,
        summarizer: Optional[TreeSummarize] = None,
        select_multi: bool = False,
        verbose: bool = False,
    ) -> "RouterQueryEngine":
        """Build a router, deriving a selector from the service context if needed.

        Args:
            query_engine_tools: Candidate query engines wrapped as tools.
            service_context: Service context. Defaults to
                ``ServiceContext.from_defaults()``.
            selector: Selector to use. When omitted, one is obtained from
                ``get_selector_from_context``.
            summarizer: Summarizer for merging multiple sub-responses.
            select_multi: Passed as ``is_multi`` to ``get_selector_from_context``
                when no ``selector`` is given.
            verbose: Forwarded to the constructor. Previously this could not be
                set through this factory.
        """
        service_context = service_context or ServiceContext.from_defaults()

        selector = selector or get_selector_from_context(
            service_context, is_multi=select_multi
        )

        assert selector is not None

        return cls(
            selector,
            query_engine_tools,
            service_context=service_context,
            summarizer=summarizer,
            verbose=verbose,
        )

    def _log_selection(self, engine_ind: int, reason: str) -> None:
        """Log the chosen engine index and reason; also print it when verbose."""
        log_str = f"Selecting query engine {engine_ind}: {reason}."
        logger.info(log_str)
        if self._verbose:
            print_text(log_str + "\n", color="pink")

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            result = self._selector.select(self._metadatas, query_bundle)

            if len(result.inds) > 1:
                # Several engines were selected: query each one, then merge.
                responses = []
                for i, engine_ind in enumerate(result.inds):
                    self._log_selection(engine_ind, result.reasons[i])
                    selected_query_engine = self._query_engines[engine_ind]
                    responses.append(selected_query_engine.query(query_bundle))

                final_response = combine_responses(
                    self._summarizer, responses, query_bundle
                )
            else:
                try:
                    engine_ind = result.ind
                    reason = result.reason
                except ValueError as e:
                    # NOTE(review): presumably raised by the result when it does
                    # not hold exactly one selection; confirm against SelectorResult.
                    raise ValueError("Failed to select query engine") from e

                selected_query_engine = self._query_engines[engine_ind]
                self._log_selection(engine_ind, reason)
                final_response = selected_query_engine.query(query_bundle)

            # add selected result
            final_response.metadata = final_response.metadata or {}
            final_response.metadata["selector_result"] = result

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            result = await self._selector.aselect(self._metadatas, query_bundle)

            if len(result.inds) > 1:
                # Several engines were selected: query them concurrently, then merge.
                tasks = []
                for i, engine_ind in enumerate(result.inds):
                    self._log_selection(engine_ind, result.reasons[i])
                    selected_query_engine = self._query_engines[engine_ind]
                    tasks.append(selected_query_engine.aquery(query_bundle))

                # Await on the already-running event loop. The previous code
                # passed these coroutines to the synchronous ``run_async_tasks``
                # from inside this coroutine, which cannot drive a running loop.
                responses = await asyncio.gather(*tasks)
                final_response = await acombine_responses(
                    self._summarizer, list(responses), query_bundle
                )
            else:
                try:
                    engine_ind = result.ind
                    reason = result.reason
                except ValueError as e:
                    # NOTE(review): presumably raised by the result when it does
                    # not hold exactly one selection; confirm against SelectorResult.
                    raise ValueError("Failed to select query engine") from e

                selected_query_engine = self._query_engines[engine_ind]
                self._log_selection(engine_ind, reason)
                final_response = await selected_query_engine.aquery(query_bundle)

            # add selected result
            final_response.metadata = final_response.metadata or {}
            final_response.metadata["selector_result"] = result

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response
239

240

241
def default_node_to_metadata_fn(node: BaseNode) -> ToolMetadata:
    """Turn a node into tool metadata.

    The tool name comes from ``node.metadata["tool_name"]``. The tool
    description is the node's content, as returned by ``node.get_content()``.

    Raises:
        ValueError: if the node's metadata has no ``tool_name`` key.

    """
    node_metadata = node.metadata or {}
    if "tool_name" in node_metadata:
        return ToolMetadata(
            name=node_metadata["tool_name"],
            description=node.get_content(),
        )
    raise ValueError("Node must have a tool_name in metadata.")
251

252

253
class RetrieverRouterQueryEngine(BaseQueryEngine):
    """Retriever-based router query engine.

    NOTE: this is deprecated, please use our new ToolRetrieverRouterQueryEngine

    Uses a retriever to pick a single node for the query. The node is mapped
    to a query engine with ``node_to_query_engine_fn``, and the query is
    delegated to that engine.

    NOTE: this is a beta feature. We are figuring out the right interface
    between the retriever and query engine.

    Args:
        retriever (BaseRetriever): Retriever that returns the candidate node(s).
            Exactly one node must be retrieved per query.
        node_to_query_engine_fn (Callable): Called with the retrieved node.
            Must return an object with a ``query`` method that answers the query.
        callback_manager (Optional[CallbackManager]): A callback manager.

    Raises (at query time):
        ValueError: if the retriever returns zero nodes or more than one node.

    """

    def __init__(
        self,
        retriever: BaseRetriever,
        node_to_query_engine_fn: Callable,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._retriever = retriever
        self._node_to_query_engine_fn = node_to_query_engine_fn
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"retriever": self._retriever}

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        nodes_with_score = self._retriever.retrieve(query_bundle)
        # TODO: for now we only support retrieving one node
        if not nodes_with_score:
            # Previously this surfaced as an opaque IndexError on ``[0]``.
            raise ValueError("Retrieved no nodes.")
        if len(nodes_with_score) > 1:
            raise ValueError("Retrieved more than one node.")

        node = nodes_with_score[0].node
        query_engine = self._node_to_query_engine_fn(node)
        return query_engine.query(query_bundle)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        # NOTE: there is no native async path; this runs the synchronous
        # query and therefore blocks the event loop while it executes.
        return self._query(query_bundle)
302

303

304
class ToolRetrieverRouterQueryEngine(BaseQueryEngine):
    """Tool Retriever router query engine.

    Selects a set of candidate query engines to execute a query. Every
    retrieved query engine tool is queried. When more than one tool is
    retrieved, the sub-responses are merged with the summarizer.

    Args:
        retriever (ObjectRetriever): A retriever that retrieves a set of
            query engine tools.
        service_context (Optional[ServiceContext]): A service context. Defaults to
            ``ServiceContext.from_defaults()``.
        summarizer (Optional[TreeSummarize]): Tree summarizer to summarize sub-results.
            Defaults to a ``TreeSummarize`` built from the service context.

    Raises (at query time):
        ValueError: if the retriever returns no query engine tools.

    """

    def __init__(
        self,
        retriever: ObjectRetriever[QueryEngineTool],
        service_context: Optional[ServiceContext] = None,
        summarizer: Optional[TreeSummarize] = None,
    ) -> None:
        self.service_context = service_context or ServiceContext.from_defaults()
        self._summarizer = summarizer or TreeSummarize(
            service_context=self.service_context,
            summary_template=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
        )
        self._retriever = retriever

        super().__init__(self.service_context.callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"summarizer": self._summarizer}

    @staticmethod
    def _ensure_tools_retrieved(query_engine_tools: Sequence[QueryEngineTool]) -> None:
        """Raise a descriptive error when the retriever returned no tools.

        Without this check an empty retrieval surfaced as an IndexError on
        ``responses[0]``.
        """
        if not query_engine_tools:
            raise ValueError("No query engine tools were retrieved for the query.")

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            query_engine_tools = self._retriever.retrieve(query_bundle)
            self._ensure_tools_retrieved(query_engine_tools)

            responses = [
                tool.query_engine.query(query_bundle) for tool in query_engine_tools
            ]

            if len(responses) > 1:
                final_response = combine_responses(
                    self._summarizer, responses, query_bundle
                )
            else:
                final_response = responses[0]

            # add selected result
            final_response.metadata = final_response.metadata or {}
            final_response.metadata["retrieved_tools"] = query_engine_tools

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            query_engine_tools = self._retriever.retrieve(query_bundle)
            self._ensure_tools_retrieved(query_engine_tools)

            # Await the sub-queries concurrently on the already-running event
            # loop. The previous code passed them to the synchronous
            # ``run_async_tasks`` from inside this coroutine, which cannot
            # drive a running loop.
            responses = list(
                await asyncio.gather(
                    *(tool.query_engine.aquery(query_bundle) for tool in query_engine_tools)
                )
            )

            if len(responses) > 1:
                final_response = await acombine_responses(
                    self._summarizer, responses, query_bundle
                )
            else:
                final_response = responses[0]

            # add selected result
            final_response.metadata = final_response.metadata or {}
            final_response.metadata["retrieved_tools"] = query_engine_tools

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response
386

# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.