llama-index

Форк
0
256 строк · 10.4 Кб
1
"""Query engines based on the FLARE paper.
2

3
Active Retrieval Augmented Generation.
4

5
"""
6

7
from typing import Any, Dict, Optional
8

9
from llama_index.legacy.callbacks.base import CallbackManager
10
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
11
from llama_index.legacy.core.response.schema import RESPONSE_TYPE, Response
12
from llama_index.legacy.prompts.base import BasePromptTemplate, PromptTemplate
13
from llama_index.legacy.prompts.mixin import PromptDictType, PromptMixinType
14
from llama_index.legacy.query_engine.flare.answer_inserter import (
15
    BaseLookaheadAnswerInserter,
16
    LLMLookaheadAnswerInserter,
17
)
18
from llama_index.legacy.query_engine.flare.output_parser import (
19
    IsDoneOutputParser,
20
    QueryTaskOutputParser,
21
)
22
from llama_index.legacy.schema import QueryBundle
23
from llama_index.legacy.service_context import ServiceContext
24
from llama_index.legacy.utils import print_text
25

26
# These prompts are taken from the FLARE repo:
# https://github.com/jzbjyb/FLARE/blob/main/src/templates.py
#
# NOTE: every string below is prompt text that ends up in the LLM call made by
# FLAREInstructQueryEngine (via DEFAULT_INSTRUCT_PROMPT). Any edit, including
# whitespace, changes what the model sees.
# A trailing backslash inside these triple-quoted strings is a line
# continuation: the newline is dropped but the next line's leading indentation
# is kept, so indented continuation lines insert a run of spaces into the
# prompt text.

# Few-shot examples showing how to emit "[Search(query)]" tags in place of
# information the model does not have yet. Interpolated into
# DEFAULT_FIRST_SKILL below.
DEFAULT_EXAMPLES = """
Query: But what are the risks during production of nanomaterials?
Answer: [Search(What are some nanomaterial production risks?)]

Query: The colors on the flag of Ghana have the following meanings.
Answer: Red is for [Search(What is the meaning of Ghana's flag being red?)], \
    green for forests, and gold for mineral wealth.

Query: What did the author do during his time in college?
Answer: The author took classes in [Search(What classes did the author take in \
    college?)].

"""

# Skill 1: how to call the Search API, followed by DEFAULT_EXAMPLES.
# This is an f-string only to splice in DEFAULT_EXAMPLES.
DEFAULT_FIRST_SKILL = f"""\
Skill 1. Use the Search API to look up relevant information by writing \
    "[Search(query)]" where "query" is the search query you want to look up. \
    For example:
{DEFAULT_EXAMPLES}

"""

# Skill 2: step-by-step generation that stops at the first missing fact and
# asks for it with a search tag.
DEFAULT_SECOND_SKILL = """\
Skill 2. Solve more complex generation tasks by thinking step by step. For example:

Query: Give a summary of the author's life and career.
Answer: The author was born in 1990. Growing up, he [Search(What did the \
    author do during his childhood?)].

Query: Can you write a summary of the Great Gatsby.
Answer: The Great Gatsby is a novel written by F. Scott Fitzgerald. It is about \
    [Search(What is the Great Gatsby about?)].

"""

# Closing instructions: continue the existing answer, prefer a search tag over
# made-up facts, and write "done" once no "[Search(query)]" tags remain.
DEFAULT_END = """
Now given the following task, and the stub of an existing answer, generate the \
next portion of the answer. You may use the Search API \
"[Search(query)]" whenever possible.
If the answer is complete and no longer contains any "[Search(query)]" tags, write \
    "done" to finish the task.
Do not write "done" if the answer still contains "[Search(query)]" tags.
Do not make up answers. It is better to generate one "[Search(query)]" tag and stop \
generation
than to fill in the answer with made up information with no "[Search(query)]" tags
or multiple "[Search(query)]" tags that assume a structure in the answer.
Try to limit generation to one sentence if possible.

"""

# Full instruct template: both skills, the closing instructions, and then the
# per-iteration slots. FLAREInstructQueryEngine fills {query_str} with the user
# query and {existing_answer} with the answer accumulated so far; the model's
# continuation is expected after the final "Answer: ".
# NOTE(review): the skill/example text above is concatenated into a template
# with {placeholder} slots, so it presumably must not contain literal braces —
# confirm against PromptTemplate's formatting rules before editing it.
DEFAULT_INSTRUCT_PROMPT_TMPL = (
    DEFAULT_FIRST_SKILL
    + DEFAULT_SECOND_SKILL
    + DEFAULT_END
    + (
        """
Query: {query_str}
Existing Answer: {existing_answer}
Answer: """
    )
)

# Prompt used by FLAREInstructQueryEngine when no instruct_prompt is supplied.
DEFAULT_INSTRUCT_PROMPT = PromptTemplate(DEFAULT_INSTRUCT_PROMPT_TMPL)
92

93

94
class FLAREInstructQueryEngine(BaseQueryEngine):
    """FLARE Instruct query engine.

    This is the version of FLARE that uses retrieval-encouraging instructions:
    the LLM is prompted to continue an answer and to emit ``[Search(query)]``
    tags wherever it needs information. Those sub-queries are answered by the
    wrapped query engine and the answers are inserted back into the text.

    NOTE: this is a beta feature. Interfaces might change, and it might not
    always give correct answers.

    Args:
        query_engine (BaseQueryEngine): query engine used to answer the
            ``[Search(query)]`` sub-queries.
        service_context (Optional[ServiceContext]): service context whose LLM
            generates the lookahead responses. Defaults to
            ``ServiceContext.from_defaults()``.
        instruct_prompt (Optional[BasePromptTemplate]): instruct prompt; it is
            formatted with ``query_str`` and ``existing_answer``. Defaults to
            ``DEFAULT_INSTRUCT_PROMPT``.
        lookahead_answer_inserter (Optional[BaseLookaheadAnswerInserter]):
            inserts sub-query answers into the lookahead response. Defaults to
            an ``LLMLookaheadAnswerInserter`` built on the same service context.
        done_output_parser (Optional[IsDoneOutputParser]): parses a lookahead
            response into ``(is_done, formatted_text)``. Defaults to
            ``IsDoneOutputParser()``.
        query_task_output_parser (Optional[QueryTaskOutputParser]): parses a
            lookahead response into query tasks. Defaults to
            ``QueryTaskOutputParser()``.
        max_iterations (int): maximum number of generate/retrieve rounds.
            Defaults to 10.
        max_lookahead_query_tasks (int): maximum number of sub-queries answered
            per round. Defaults to 1.
        callback_manager (Optional[CallbackManager]): callback manager.
            Defaults to None.
        verbose (bool): print the query and intermediate responses.
            Defaults to False.

    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        service_context: Optional[ServiceContext] = None,
        instruct_prompt: Optional[BasePromptTemplate] = None,
        lookahead_answer_inserter: Optional[BaseLookaheadAnswerInserter] = None,
        done_output_parser: Optional[IsDoneOutputParser] = None,
        query_task_output_parser: Optional[QueryTaskOutputParser] = None,
        max_iterations: int = 10,
        max_lookahead_query_tasks: int = 1,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
    ) -> None:
        """Init params."""
        super().__init__(callback_manager=callback_manager)
        self._query_engine = query_engine
        self._service_context = service_context or ServiceContext.from_defaults()
        self._instruct_prompt = instruct_prompt or DEFAULT_INSTRUCT_PROMPT
        self._lookahead_answer_inserter = lookahead_answer_inserter or (
            LLMLookaheadAnswerInserter(service_context=self._service_context)
        )
        self._done_output_parser = done_output_parser or IsDoneOutputParser()
        self._query_task_output_parser = (
            query_task_output_parser or QueryTaskOutputParser()
        )
        self._max_iterations = max_iterations
        self._max_lookahead_query_tasks = max_lookahead_query_tasks
        self._verbose = verbose

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {
            "instruct_prompt": self._instruct_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "instruct_prompt" in prompts:
            self._instruct_prompt = prompts["instruct_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "query_engine": self._query_engine,
            "lookahead_answer_inserter": self._lookahead_answer_inserter,
        }

    @staticmethod
    def _append_to_response(cur_response: str, new_text: str) -> str:
        """Join ``new_text`` onto ``cur_response`` with a single space.

        Both parts are stripped and empty parts are skipped, so an empty side
        (e.g. the initial empty answer, or an empty chunk) never leaves a
        stray leading or trailing space in the accumulated answer.
        """
        parts = (cur_response.strip(), new_text.strip())
        return " ".join(part for part in parts if part)

    def _get_relevant_lookahead_response(self, updated_lookahead_resp: str) -> str:
        """Return the part of the updated lookahead response that can be kept.

        Query tasks may remain in the response because only the first
        ``max_lookahead_query_tasks`` tags are answered per round. In that
        case the text is truncated at the start of the first remaining tag;
        otherwise it is returned unchanged.
        """
        remaining_query_tasks = self._query_task_output_parser.parse(
            updated_lookahead_resp
        )
        if not remaining_query_tasks:
            return updated_lookahead_resp
        return updated_lookahead_resp[: remaining_query_tasks[0].start_idx]

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Iteratively generate an answer, retrieving whenever the LLM asks to.

        Each round asks the LLM (via the instruct prompt) to continue the
        current answer. If the done parser reports the answer is complete, its
        formatted text is appended and the loop stops. Otherwise up to
        ``max_lookahead_query_tasks`` query tasks are answered with the wrapped
        query engine, the answers are inserted into the lookahead response,
        and the text up to the first unanswered tag is appended. At most
        ``max_iterations`` rounds are run.

        Raises:
            ValueError: if the wrapped query engine returns anything other than
                a ``Response`` (e.g. a streaming response).
        """
        if self._verbose:
            print_text(f"Query: {query_bundle.query_str}\n", color="green")
        cur_response = ""
        source_nodes = []
        for _ in range(self._max_iterations):
            if self._verbose:
                print_text(f"Current response: {cur_response}\n", color="blue")
            # generate "lookahead response" that contains "[Search(query)]" tags
            # e.g.
            # The colors on the flag of Ghana have the following meanings. Red is
            # for [Search(Ghana flag meaning)],...
            lookahead_resp = self._service_context.llm.predict(
                self._instruct_prompt,
                query_str=query_bundle.query_str,
                existing_answer=cur_response,
            ).strip()
            if self._verbose:
                print_text(f"Lookahead response: {lookahead_resp}\n", color="pink")

            is_done, fmt_lookahead = self._done_output_parser.parse(lookahead_resp)
            if is_done:
                cur_response = self._append_to_response(cur_response, fmt_lookahead)
                break

            # parse the lookahead response into query tasks; only the first
            # ``max_lookahead_query_tasks`` are answered this round
            query_tasks = self._query_task_output_parser.parse(lookahead_resp)
            query_tasks = query_tasks[: self._max_lookahead_query_tasks]

            # answer each query task with the wrapped query engine
            query_answers = []
            for query_task in query_tasks:
                answer_obj = self._query_engine.query(query_task.query_str)
                # only plain Response objects are supported (no streaming)
                if not isinstance(answer_obj, Response):
                    raise ValueError(
                        f"Expected Response object, got {type(answer_obj)} instead."
                    )
                query_answers.append(str(answer_obj))
                source_nodes.extend(answer_obj.source_nodes)

            # fill in the lookahead response template with the query answers
            # from the query engine; the answer so far is passed as context
            updated_lookahead_resp = self._lookahead_answer_inserter.insert(
                lookahead_resp, query_tasks, query_answers, prev_response=cur_response
            )

            # keep only the text before the first still-unanswered tag
            relevant_lookahead_resp = self._get_relevant_lookahead_response(
                updated_lookahead_resp
            )

            if self._verbose:
                print_text(
                    "Updated lookahead response: " + f"{relevant_lookahead_resp}\n",
                    color="pink",
                )

            # append the relevant lookahead response to the final response
            cur_response = self._append_to_response(
                cur_response, relevant_lookahead_resp
            )

        # NOTE: at the moment, does not support streaming
        return Response(response=cur_response, source_nodes=source_nodes)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async entry point; delegates to the synchronous ``_query``.

        NOTE: the LLM and query-engine calls inside ``_query`` are synchronous,
        so awaiting this blocks the event loop for the whole FLARE loop.
        """
        return self._query(query_bundle)
257

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.