llama-index
272 lines · 10.5 KB
1import asyncio
2import logging
3from typing import List, Optional, Sequence, cast
4
5from llama_index.legacy.async_utils import run_async_tasks
6from llama_index.legacy.bridge.pydantic import BaseModel, Field
7from llama_index.legacy.callbacks.base import CallbackManager
8from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
9from llama_index.legacy.core.base_query_engine import BaseQueryEngine
10from llama_index.legacy.core.response.schema import RESPONSE_TYPE
11from llama_index.legacy.prompts.mixin import PromptMixinType
12from llama_index.legacy.question_gen.llm_generators import LLMQuestionGenerator
13from llama_index.legacy.question_gen.openai_generator import OpenAIQuestionGenerator
14from llama_index.legacy.question_gen.types import BaseQuestionGenerator, SubQuestion
15from llama_index.legacy.response_synthesizers import (
16BaseSynthesizer,
17get_response_synthesizer,
18)
19from llama_index.legacy.schema import NodeWithScore, QueryBundle, TextNode
20from llama_index.legacy.service_context import ServiceContext
21from llama_index.legacy.tools.query_engine import QueryEngineTool
22from llama_index.legacy.utils import get_color_mapping, print_text
23
# Module-level logger, namespaced by this module's import path.
logger = logging.getLogger(__name__)

26
class SubQuestionAnswerPair(BaseModel):
    """A sub question together with its (optional) answer.

    ``answer`` and ``sources`` stay at their defaults until the sub question
    has actually been executed against its target query engine.
    """

    sub_q: SubQuestion
    # None until the sub question has been answered.
    answer: Optional[str] = None
    # Source nodes retrieved while answering the sub question.
    sources: List[NodeWithScore] = Field(default_factory=list)
35
36
class SubQuestionQueryEngine(BaseQueryEngine):
    """Sub question query engine.

    A query engine that breaks down a complex query (e.g. compare and contrast)
    into many sub questions and their target query engines for execution.
    After executing all sub questions, all responses are gathered and sent to a
    response synthesizer to produce the final response.

    Args:
        question_gen (BaseQuestionGenerator): A module for generating sub questions
            given a complex question and tools.
        response_synthesizer (BaseSynthesizer): A response synthesizer for
            generating the final response.
        query_engine_tools (Sequence[QueryEngineTool]): Tools to answer the
            sub questions.
        callback_manager (Optional[CallbackManager]): Callback manager used to
            emit QUERY / SUB_QUESTION instrumentation events.
        verbose (bool): whether to print intermediate questions and answers.
            Defaults to True.
        use_async (bool): whether to execute the sub questions with asyncio.
            Defaults to False here; note ``from_defaults`` defaults it to True.
    """

    def __init__(
        self,
        question_gen: BaseQuestionGenerator,
        response_synthesizer: BaseSynthesizer,
        query_engine_tools: Sequence[QueryEngineTool],
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = True,
        use_async: bool = False,
    ) -> None:
        self._question_gen = question_gen
        self._response_synthesizer = response_synthesizer
        # Tool metadata is handed to the question generator so it can route
        # each generated sub question to a tool by name.
        self._metadatas = [x.metadata for x in query_engine_tools]
        # Name -> engine lookup used when executing each sub question.
        self._query_engines = {
            tool.metadata.name: tool.query_engine for tool in query_engine_tools
        }
        self._verbose = verbose
        self._use_async = use_async
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "question_gen": self._question_gen,
            "response_synthesizer": self._response_synthesizer,
        }

    @classmethod
    def from_defaults(
        cls,
        query_engine_tools: Sequence[QueryEngineTool],
        question_gen: Optional[BaseQuestionGenerator] = None,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        service_context: Optional[ServiceContext] = None,
        verbose: bool = True,
        use_async: bool = True,
    ) -> "SubQuestionQueryEngine":
        """Construct an engine with sensible defaults.

        The callback manager is taken from the service context when provided,
        otherwise inherited from the first query engine tool (if any).
        """
        callback_manager = None
        if service_context is not None:
            callback_manager = service_context.callback_manager
        elif len(query_engine_tools) > 0:
            callback_manager = query_engine_tools[0].query_engine.callback_manager

        service_context = service_context or ServiceContext.from_defaults()
        if question_gen is None:
            # Try the OpenAI function-calling based question generator first;
            # if the configured LLM is incompatible it raises ValueError, and
            # we fall back to the generic LLM question generator.
            try:
                question_gen = OpenAIQuestionGenerator.from_defaults(
                    llm=service_context.llm
                )
            except ValueError:
                question_gen = LLMQuestionGenerator.from_defaults(
                    service_context=service_context
                )

        synth = response_synthesizer or get_response_synthesizer(
            callback_manager=callback_manager,
            service_context=service_context,
            use_async=use_async,
        )

        return cls(
            question_gen,
            synth,
            query_engine_tools,
            callback_manager=callback_manager,
            verbose=verbose,
            use_async=use_async,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Decompose the query, answer each sub question, and synthesize."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            sub_questions = self._question_gen.generate(self._metadatas, query_bundle)

            colors = get_color_mapping([str(i) for i in range(len(sub_questions))])

            if self._verbose:
                print_text(f"Generated {len(sub_questions)} sub questions.\n")

            if self._use_async:
                tasks = [
                    self._aquery_subq(sub_q, color=colors[str(ind)])
                    for ind, sub_q in enumerate(sub_questions)
                ]

                qa_pairs_all = run_async_tasks(tasks)
                qa_pairs_all = cast(List[Optional[SubQuestionAnswerPair]], qa_pairs_all)
            else:
                qa_pairs_all = [
                    self._query_subq(sub_q, color=colors[str(ind)])
                    for ind, sub_q in enumerate(sub_questions)
                ]

            # Filter out sub questions that failed (helpers return None).
            qa_pairs: List[SubQuestionAnswerPair] = list(filter(None, qa_pairs_all))

            nodes = [self._construct_node(pair) for pair in qa_pairs]

            source_nodes = [node for qa_pair in qa_pairs for node in qa_pair.sources]
            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async variant of ``_query``; sub questions always run concurrently."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            sub_questions = await self._question_gen.agenerate(
                self._metadatas, query_bundle
            )

            colors = get_color_mapping([str(i) for i in range(len(sub_questions))])

            if self._verbose:
                print_text(f"Generated {len(sub_questions)} sub questions.\n")

            tasks = [
                self._aquery_subq(sub_q, color=colors[str(ind)])
                for ind, sub_q in enumerate(sub_questions)
            ]

            qa_pairs_all = await asyncio.gather(*tasks)
            qa_pairs_all = cast(List[Optional[SubQuestionAnswerPair]], qa_pairs_all)

            # Filter out sub questions that failed (helpers return None).
            qa_pairs: List[SubQuestionAnswerPair] = list(filter(None, qa_pairs_all))

            nodes = [self._construct_node(pair) for pair in qa_pairs]

            source_nodes = [node for qa_pair in qa_pairs for node in qa_pair.sources]
            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    def _construct_node(self, qa_pair: SubQuestionAnswerPair) -> NodeWithScore:
        """Wrap a Q/A pair as a text node for the response synthesizer."""
        node_text = (
            f"Sub question: {qa_pair.sub_q.sub_question}\nResponse: {qa_pair.answer}"
        )
        return NodeWithScore(node=TextNode(text=node_text))

    async def _aquery_subq(
        self, sub_q: SubQuestion, color: Optional[str] = None
    ) -> Optional[SubQuestionAnswerPair]:
        """Answer one sub question asynchronously.

        Returns None on ValueError (e.g. an incompatible tool/LLM failure) so
        the caller can filter failed sub questions out.
        """
        try:
            with self.callback_manager.event(
                CBEventType.SUB_QUESTION,
                payload={EventPayload.SUB_QUESTION: SubQuestionAnswerPair(sub_q=sub_q)},
            ) as event:
                question = sub_q.sub_question
                query_engine = self._query_engines[sub_q.tool_name]

                if self._verbose:
                    print_text(f"[{sub_q.tool_name}] Q: {question}\n", color=color)

                response = await query_engine.aquery(question)
                response_text = str(response)

                if self._verbose:
                    print_text(f"[{sub_q.tool_name}] A: {response_text}\n", color=color)

                qa_pair = SubQuestionAnswerPair(
                    sub_q=sub_q, answer=response_text, sources=response.source_nodes
                )

                event.on_end(payload={EventPayload.SUB_QUESTION: qa_pair})

            return qa_pair
        except ValueError:
            # Log sub_q.sub_question (not the local `question`): the failure may
            # occur before `question` is bound, and referencing it here would
            # raise NameError and mask the original error. %-style args keep the
            # formatting lazy.
            logger.warning("[%s] Failed to run %s", sub_q.tool_name, sub_q.sub_question)
            return None

    def _query_subq(
        self, sub_q: SubQuestion, color: Optional[str] = None
    ) -> Optional[SubQuestionAnswerPair]:
        """Answer one sub question synchronously.

        Returns None on ValueError (e.g. an incompatible tool/LLM failure) so
        the caller can filter failed sub questions out.
        """
        try:
            with self.callback_manager.event(
                CBEventType.SUB_QUESTION,
                payload={EventPayload.SUB_QUESTION: SubQuestionAnswerPair(sub_q=sub_q)},
            ) as event:
                question = sub_q.sub_question
                query_engine = self._query_engines[sub_q.tool_name]

                if self._verbose:
                    print_text(f"[{sub_q.tool_name}] Q: {question}\n", color=color)

                response = query_engine.query(question)
                response_text = str(response)

                if self._verbose:
                    print_text(f"[{sub_q.tool_name}] A: {response_text}\n", color=color)

                qa_pair = SubQuestionAnswerPair(
                    sub_q=sub_q, answer=response_text, sources=response.source_nodes
                )

                event.on_end(payload={EventPayload.SUB_QUESTION: qa_pair})

            return qa_pair
        except ValueError:
            # Same rationale as _aquery_subq: avoid referencing a possibly
            # unbound local in the handler, and use lazy logger arguments.
            logger.warning("[%s] Failed to run %s", sub_q.tool_name, sub_q.sub_question)
            return None
273