llama-index
93 строки · 3.4 Кб
1from typing import List, Optional, Sequence
2
3from llama_index.legacy.callbacks.base import CallbackManager
4from llama_index.legacy.core.base_query_engine import BaseQueryEngine
5from llama_index.legacy.core.response.schema import RESPONSE_TYPE
6from llama_index.legacy.indices.query.query_transform.base import BaseQueryTransform
7from llama_index.legacy.prompts.mixin import PromptMixinType
8from llama_index.legacy.schema import NodeWithScore, QueryBundle
9
10
class TransformQueryEngine(BaseQueryEngine):
    """Transform query engine.

    Applies a query transform to a query bundle before passing
    it to a query engine.

    Each of ``retrieve``, ``synthesize``, ``asynthesize``, ``_query`` and
    ``_aquery`` runs the transform on the bundle it receives. Calling, for
    example, ``retrieve`` and then ``synthesize`` with the same bundle
    therefore runs the transform twice.

    Args:
        query_engine (BaseQueryEngine): A query engine object.
        query_transform (BaseQueryTransform): A query transform object.
        transform_metadata (Optional[dict]): metadata to pass to the
            query transform.
        callback_manager (Optional[CallbackManager]): A callback manager.

    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        query_transform: BaseQueryTransform,
        transform_metadata: Optional[dict] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Store the wrapped engine, the transform and its metadata.

        The attributes are assigned before the base-class initializer is
        called with ``callback_manager``.
        """
        self._query_engine = query_engine
        self._query_transform = query_transform
        self._transform_metadata = transform_metadata
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules.

        Returns:
            A mapping exposing the transform under ``"query_transform"`` and
            the wrapped engine under ``"query_engine"``.
        """
        return {
            "query_transform": self._query_transform,
            "query_engine": self._query_engine,
        }

    def _run_transform(self, query_bundle: QueryBundle) -> QueryBundle:
        """Run the configured query transform on ``query_bundle``.

        Single place for the transform call shared by every entry point, so
        the stored ``transform_metadata`` is always forwarded the same way.

        Args:
            query_bundle: The incoming, untransformed query bundle.

        Returns:
            The result of the transform's ``run`` call, which callers pass on
            to the wrapped query engine.
        """
        return self._query_transform.run(
            query_bundle, metadata=self._transform_metadata
        )

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Transform the query, then delegate retrieval to the wrapped engine.

        Args:
            query_bundle: The query to transform and retrieve for.

        Returns:
            Whatever the wrapped engine's ``retrieve`` returns for the
            transformed bundle.
        """
        return self._query_engine.retrieve(self._run_transform(query_bundle))

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Transform the query, then delegate synthesis to the wrapped engine.

        Args:
            query_bundle: The query to transform before synthesis.
            nodes: Nodes forwarded unchanged to the wrapped engine.
            additional_source_nodes: Optional extra nodes, forwarded
                unchanged to the wrapped engine.

        Returns:
            The wrapped engine's ``synthesize`` result.
        """
        return self._query_engine.synthesize(
            query_bundle=self._run_transform(query_bundle),
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async counterpart of ``synthesize``.

        The transform itself is invoked synchronously through ``run``; only
        the wrapped engine's ``asynthesize`` call is awaited.

        Args:
            query_bundle: The query to transform before synthesis.
            nodes: Nodes forwarded unchanged to the wrapped engine.
            additional_source_nodes: Optional extra nodes, forwarded
                unchanged to the wrapped engine.

        Returns:
            The wrapped engine's ``asynthesize`` result.
        """
        transformed_bundle = self._run_transform(query_bundle)
        return await self._query_engine.asynthesize(
            query_bundle=transformed_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query.

        Transforms ``query_bundle`` and passes the result to the wrapped
        engine's ``query`` method.
        """
        return self._query_engine.query(self._run_transform(query_bundle))

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously.

        The transform is invoked synchronously through ``run``; the wrapped
        engine's ``aquery`` call is then awaited with the transformed bundle.
        """
        transformed_bundle = self._run_transform(query_bundle)
        return await self._query_engine.aquery(transformed_bundle)
94