llama-index

import asyncio
import logging
from typing import List, Optional, Sequence, cast

from llama_index.legacy.async_utils import run_async_tasks
from llama_index.legacy.bridge.pydantic import BaseModel, Field
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
from llama_index.legacy.core.response.schema import RESPONSE_TYPE
from llama_index.legacy.prompts.mixin import PromptMixinType
from llama_index.legacy.question_gen.llm_generators import LLMQuestionGenerator
from llama_index.legacy.question_gen.openai_generator import OpenAIQuestionGenerator
from llama_index.legacy.question_gen.types import BaseQuestionGenerator, SubQuestion
from llama_index.legacy.response_synthesizers import (
    BaseSynthesizer,
    get_response_synthesizer,
)
from llama_index.legacy.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.legacy.service_context import ServiceContext
from llama_index.legacy.tools.query_engine import QueryEngineTool
from llama_index.legacy.utils import get_color_mapping, print_text

logger = logging.getLogger(__name__)

class SubQuestionAnswerPair(BaseModel):
    """
    Pair of a sub question and, optionally, its answer (if it has been
    answered yet).
    """

    sub_q: SubQuestion
    answer: Optional[str] = None
    sources: List[NodeWithScore] = Field(default_factory=list)


class SubQuestionQueryEngine(BaseQueryEngine):
    """Sub question query engine.

    A query engine that breaks down a complex query (e.g. compare and
    contrast) into many sub questions, each routed to a target query engine
    for execution. After all sub questions have been executed, their
    responses are gathered and sent to a response synthesizer to produce the
    final response. A usage sketch is included at the bottom of this module.

    Args:
        question_gen (BaseQuestionGenerator): A module for generating sub questions
            given a complex question and tools.
        response_synthesizer (BaseSynthesizer): A response synthesizer for
            generating the final response.
        query_engine_tools (Sequence[QueryEngineTool]): Tools to answer the
            sub questions.
        callback_manager (Optional[CallbackManager]): A callback manager for
            tracing intermediate events. Defaults to None.
        verbose (bool): Whether to print intermediate questions and answers.
            Defaults to True.
        use_async (bool): Whether to execute the sub questions with asyncio.
            Defaults to False.
    """

    def __init__(
        self,
        question_gen: BaseQuestionGenerator,
        response_synthesizer: BaseSynthesizer,
        query_engine_tools: Sequence[QueryEngineTool],
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = True,
        use_async: bool = False,
    ) -> None:
        self._question_gen = question_gen
        self._response_synthesizer = response_synthesizer
        self._metadatas = [x.metadata for x in query_engine_tools]
        self._query_engines = {
            tool.metadata.name: tool.query_engine for tool in query_engine_tools
        }
        self._verbose = verbose
        self._use_async = use_async
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "question_gen": self._question_gen,
            "response_synthesizer": self._response_synthesizer,
        }

    @classmethod
    def from_defaults(
        cls,
        query_engine_tools: Sequence[QueryEngineTool],
        question_gen: Optional[BaseQuestionGenerator] = None,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        service_context: Optional[ServiceContext] = None,
        verbose: bool = True,
        use_async: bool = True,
    ) -> "SubQuestionQueryEngine":
        callback_manager = None
        if service_context is not None:
            callback_manager = service_context.callback_manager
        elif len(query_engine_tools) > 0:
            callback_manager = query_engine_tools[0].query_engine.callback_manager

        service_context = service_context or ServiceContext.from_defaults()
        if question_gen is None:
            # Try to use the OpenAI function-calling-based question generator.
            # If the LLM is incompatible, fall back to the general LLM
            # question generator.
            try:
                question_gen = OpenAIQuestionGenerator.from_defaults(
                    llm=service_context.llm
                )
            except ValueError:
                question_gen = LLMQuestionGenerator.from_defaults(
                    service_context=service_context
                )

        synth = response_synthesizer or get_response_synthesizer(
            callback_manager=callback_manager,
            service_context=service_context,
            use_async=use_async,
        )

        return cls(
            question_gen,
            synth,
            query_engine_tools,
            callback_manager=callback_manager,
            verbose=verbose,
            use_async=use_async,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            sub_questions = self._question_gen.generate(self._metadatas, query_bundle)

            colors = get_color_mapping([str(i) for i in range(len(sub_questions))])

            if self._verbose:
                print_text(f"Generated {len(sub_questions)} sub questions.\n")

            if self._use_async:
                tasks = [
                    self._aquery_subq(sub_q, color=colors[str(ind)])
                    for ind, sub_q in enumerate(sub_questions)
                ]

                qa_pairs_all = run_async_tasks(tasks)
                qa_pairs_all = cast(List[Optional[SubQuestionAnswerPair]], qa_pairs_all)
            else:
                qa_pairs_all = [
                    self._query_subq(sub_q, color=colors[str(ind)])
                    for ind, sub_q in enumerate(sub_questions)
                ]

            # filter out sub questions that failed
            qa_pairs: List[SubQuestionAnswerPair] = list(filter(None, qa_pairs_all))

            nodes = [self._construct_node(pair) for pair in qa_pairs]

            source_nodes = [node for qa_pair in qa_pairs for node in qa_pair.sources]
            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            sub_questions = await self._question_gen.agenerate(
                self._metadatas, query_bundle
            )

            colors = get_color_mapping([str(i) for i in range(len(sub_questions))])

            if self._verbose:
                print_text(f"Generated {len(sub_questions)} sub questions.\n")

            tasks = [
                self._aquery_subq(sub_q, color=colors[str(ind)])
                for ind, sub_q in enumerate(sub_questions)
            ]

            qa_pairs_all = await asyncio.gather(*tasks)
            qa_pairs_all = cast(List[Optional[SubQuestionAnswerPair]], qa_pairs_all)

            # filter out sub questions that failed
            qa_pairs: List[SubQuestionAnswerPair] = list(filter(None, qa_pairs_all))

            nodes = [self._construct_node(pair) for pair in qa_pairs]

            source_nodes = [node for qa_pair in qa_pairs for node in qa_pair.sources]
            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    def _construct_node(self, qa_pair: SubQuestionAnswerPair) -> NodeWithScore:
        node_text = (
            f"Sub question: {qa_pair.sub_q.sub_question}\nResponse: {qa_pair.answer}"
        )
        return NodeWithScore(node=TextNode(text=node_text))

    async def _aquery_subq(
        self, sub_q: SubQuestion, color: Optional[str] = None
    ) -> Optional[SubQuestionAnswerPair]:
        try:
            with self.callback_manager.event(
                CBEventType.SUB_QUESTION,
                payload={EventPayload.SUB_QUESTION: SubQuestionAnswerPair(sub_q=sub_q)},
            ) as event:
                question = sub_q.sub_question
                query_engine = self._query_engines[sub_q.tool_name]

                if self._verbose:
                    print_text(f"[{sub_q.tool_name}] Q: {question}\n", color=color)

                response = await query_engine.aquery(question)
                response_text = str(response)

                if self._verbose:
                    print_text(f"[{sub_q.tool_name}] A: {response_text}\n", color=color)

                qa_pair = SubQuestionAnswerPair(
                    sub_q=sub_q, answer=response_text, sources=response.source_nodes
                )

                event.on_end(payload={EventPayload.SUB_QUESTION: qa_pair})

            return qa_pair
        except ValueError:
            # Log from sub_q directly: `question` may be unbound if the error
            # was raised before it was assigned.
            logger.warning(f"[{sub_q.tool_name}] Failed to run {sub_q.sub_question}")
            return None

    def _query_subq(
        self, sub_q: SubQuestion, color: Optional[str] = None
    ) -> Optional[SubQuestionAnswerPair]:
        try:
            with self.callback_manager.event(
                CBEventType.SUB_QUESTION,
                payload={EventPayload.SUB_QUESTION: SubQuestionAnswerPair(sub_q=sub_q)},
            ) as event:
                question = sub_q.sub_question
                query_engine = self._query_engines[sub_q.tool_name]

                if self._verbose:
                    print_text(f"[{sub_q.tool_name}] Q: {question}\n", color=color)

                response = query_engine.query(question)
                response_text = str(response)

                if self._verbose:
                    print_text(f"[{sub_q.tool_name}] A: {response_text}\n", color=color)

                qa_pair = SubQuestionAnswerPair(
                    sub_q=sub_q, answer=response_text, sources=response.source_nodes
                )

                event.on_end(payload={EventPayload.SUB_QUESTION: qa_pair})

            return qa_pair
        except ValueError:
            # Log from sub_q directly: `question` may be unbound if the error
            # was raised before it was assigned.
            logger.warning(f"[{sub_q.tool_name}] Failed to run {sub_q.sub_question}")
            return None
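

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# wiring up a SubQuestionQueryEngine, as referenced in the class docstring.
# It shows the intended flow: wrap per-source query engines as
# QueryEngineTool instances, build the engine via `from_defaults`, and ask a
# compare-and-contrast question. The data directories and tool names
# (`data/lyft`, `data/uber`, `lyft_10k`, `uber_10k`) are hypothetical
# placeholders; substitute your own indexes and descriptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from llama_index.legacy import SimpleDirectoryReader, VectorStoreIndex
    from llama_index.legacy.tools import QueryEngineTool, ToolMetadata

    # Build one index per document collection (paths are assumptions).
    lyft_index = VectorStoreIndex.from_documents(
        SimpleDirectoryReader("data/lyft").load_data()
    )
    uber_index = VectorStoreIndex.from_documents(
        SimpleDirectoryReader("data/uber").load_data()
    )

    # Each tool pairs a query engine with metadata the question generator
    # uses to route sub questions to the right source.
    query_engine_tools = [
        QueryEngineTool(
            query_engine=lyft_index.as_query_engine(),
            metadata=ToolMetadata(
                name="lyft_10k",
                description="Provides information about Lyft financials",
            ),
        ),
        QueryEngineTool(
            query_engine=uber_index.as_query_engine(),
            metadata=ToolMetadata(
                name="uber_10k",
                description="Provides information about Uber financials",
            ),
        ),
    ]

    # `from_defaults` picks a question generator (OpenAI function-calling if
    # available, else a generic LLM one) and a response synthesizer.
    engine = SubQuestionQueryEngine.from_defaults(
        query_engine_tools=query_engine_tools
    )

    # The engine decomposes the query into per-tool sub questions, answers
    # each one, and synthesizes a single final response.
    response = engine.query(
        "Compare and contrast the revenue growth of Lyft and Uber"
    )
    print(response)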