llama-index

Форк
0
177 строк · 6.6 Кб
1
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
2

3
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
4
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
5
from llama_index.legacy.core.response.schema import RESPONSE_TYPE
6
from llama_index.legacy.indices.query.query_transform.base import (
7
    StepDecomposeQueryTransform,
8
)
9
from llama_index.legacy.prompts.mixin import PromptMixinType
10
from llama_index.legacy.response_synthesizers import (
11
    BaseSynthesizer,
12
    get_response_synthesizer,
13
)
14
from llama_index.legacy.schema import NodeWithScore, QueryBundle, TextNode
15

16

17
def default_stop_fn(stop_dict: Dict) -> bool:
    """Stop function for multi-step query combiner.

    Args:
        stop_dict (Dict): Must contain a ``"query_bundle"`` entry holding the
            sub-question produced by the query transform for the next step.

    Returns:
        bool: True when the generated sub-question contains "none"
        (case-insensitive substring match), signalling that no further
        sub-question is needed.

    Raises:
        ValueError: If ``stop_dict`` has no ``"query_bundle"`` entry (or it is None).
    """
    # Annotated as Optional (instead of cast to QueryBundle) so type checkers
    # see the None check below as reachable; local annotations are not
    # evaluated at runtime.
    query_bundle: Optional[QueryBundle] = stop_dict.get("query_bundle")
    if query_bundle is None:
        raise ValueError("Query bundle must be provided to stop function.")

    return "none" in query_bundle.query_str.lower()
24

25

26
class MultiStepQueryEngine(BaseQueryEngine):
    """Multi-step query engine.

    This query engine can operate over an existing base query engine,
    along with the multi-step query transform. On each step the transform
    rewrites the original query into a follow-up sub-question (given the
    reasoning gathered so far), the wrapped engine answers it, and the
    question/answer pair is accumulated. The accumulated pairs are then
    passed as nodes to the response synthesizer to produce the final answer.

    Args:
        query_engine (BaseQueryEngine): A BaseQueryEngine object used to
            answer each sub-question.
        query_transform (StepDecomposeQueryTransform): A StepDecomposeQueryTransform
            object that generates the next sub-question.
        response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
            object. Defaults to ``get_response_synthesizer`` using the wrapped
            engine's callback manager.
        num_steps (Optional[int]): Maximum number of steps to run. ``None``
            means no step limit (only the stop function ends the loop).
        early_stopping (bool): Whether to stop early if the stop function
            returns True. When False the stop function is never consulted and
            exactly ``num_steps`` steps are run.
        index_summary (str): A string summary of the index, forwarded to the
            query transform.
        stop_fn (Optional[Callable[[Dict], bool]]): A stop function that takes in a
            dictionary of information (currently ``{"query_bundle": ...}``
            holding the generated sub-question) and returns a boolean.
            Defaults to ``default_stop_fn``.

    Raises:
        ValueError: If ``early_stopping`` is False and ``num_steps`` is None.
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        query_transform: StepDecomposeQueryTransform,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        num_steps: Optional[int] = 3,
        early_stopping: bool = True,
        index_summary: str = "None",
        stop_fn: Optional[Callable[[Dict], bool]] = None,
    ) -> None:
        # num_steps must be provided if early_stopping is False: otherwise
        # nothing would ever end the step loop. Checked first so we fail fast
        # before building a default synthesizer.
        if not early_stopping and num_steps is None:
            raise ValueError("Must specify num_steps if early_stopping is False.")

        self._query_engine = query_engine
        self._query_transform = query_transform
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            callback_manager=self._query_engine.callback_manager
        )

        self._index_summary = index_summary
        self._num_steps = num_steps
        self._early_stopping = early_stopping
        # TODO: make interface to stop function better
        self._stop_fn = stop_fn or default_stop_fn

        callback_manager = self._query_engine.callback_manager
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "response_synthesizer": self._response_synthesizer,
            "query_transform": self._query_transform,
        }

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Run the sub-question steps, then synthesize the final response."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes, source_nodes, metadata = self._query_multistep(query_bundle)

            final_response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )
            final_response.metadata = metadata

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async variant of ``_query``.

        Uses ``_aquery_multistep`` so the sub-questions are answered with the
        wrapped engine's ``aquery`` instead of blocking the event loop.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes, source_nodes, metadata = await self._aquery_multistep(query_bundle)

            final_response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )
            final_response.metadata = metadata

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    def _combine_queries(
        self, query_bundle: QueryBundle, prev_reasoning: str
    ) -> QueryBundle:
        """Ask the query transform for the next sub-question.

        The transform receives the reasoning accumulated so far and the index
        summary as metadata.
        """
        transform_metadata = {
            "prev_reasoning": prev_reasoning,
            "index_summary": self._index_summary,
        }
        return self._query_transform(query_bundle, metadata=transform_metadata)

    def _has_steps_left(self, cur_steps: int) -> bool:
        """Return True while the ``num_steps`` budget (if any) is not used up."""
        return self._num_steps is None or cur_steps < self._num_steps

    def _should_stop(self, step_query: QueryBundle) -> bool:
        """Return True when early stopping is enabled and the stop function fires."""
        if not self._early_stopping:
            # Early stopping disabled: run the full num_steps budget.
            return False
        # TODO: make stop logic better
        stop_dict = {"query_bundle": step_query}
        return bool(self._stop_fn(stop_dict))

    @staticmethod
    def _record_step(
        step_query: QueryBundle,
        step_response: RESPONSE_TYPE,
        text_chunks: List[str],
        source_nodes: List[NodeWithScore],
        metadata: Dict[str, Any],
    ) -> str:
        """Append one sub-question/answer to the accumulators.

        Returns the reasoning text to append to ``prev_reasoning``.
        """
        text_chunks.append(
            f"\nQuestion: {step_query.query_str}\n" f"Answer: {step_response!s}"
        )
        source_nodes.extend(step_response.source_nodes)
        metadata["sub_qa"].append((step_query.query_str, step_response))
        return f"- {step_query.query_str}\n" f"- {step_response!s}\n"

    @staticmethod
    def _build_nodes(text_chunks: List[str]) -> List[NodeWithScore]:
        """Wrap each accumulated question/answer text in an unscored text node."""
        return [NodeWithScore(node=TextNode(text=chunk)) for chunk in text_chunks]

    def _query_multistep(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore], Dict[str, Any]]:
        """Run query combiner.

        Returns:
            Tuple of (question/answer text nodes, source nodes gathered from
            every sub-response, metadata with a ``"sub_qa"`` list of
            ``(sub_question, response)`` pairs).
        """
        prev_reasoning = ""
        text_chunks: List[str] = []
        source_nodes: List[NodeWithScore] = []
        final_response_metadata: Dict[str, Any] = {"sub_qa": []}

        cur_steps = 0
        while self._has_steps_left(cur_steps):
            updated_query_bundle = self._combine_queries(query_bundle, prev_reasoning)
            # Stop is checked before querying so a terminal sub-question is
            # never sent to the wrapped engine.
            if self._should_stop(updated_query_bundle):
                break

            cur_response = self._query_engine.query(updated_query_bundle)
            prev_reasoning += self._record_step(
                updated_query_bundle,
                cur_response,
                text_chunks,
                source_nodes,
                final_response_metadata,
            )
            cur_steps += 1

        return self._build_nodes(text_chunks), source_nodes, final_response_metadata

    async def _aquery_multistep(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore], Dict[str, Any]]:
        """Async counterpart of ``_query_multistep``.

        Sub-questions are answered with ``self._query_engine.aquery`` so the
        event loop is not blocked while they run.
        """
        prev_reasoning = ""
        text_chunks: List[str] = []
        source_nodes: List[NodeWithScore] = []
        final_response_metadata: Dict[str, Any] = {"sub_qa": []}

        cur_steps = 0
        while self._has_steps_left(cur_steps):
            # NOTE(review): the query transform is still invoked synchronously;
            # switch to an async transform call if one becomes available.
            updated_query_bundle = self._combine_queries(query_bundle, prev_reasoning)
            if self._should_stop(updated_query_bundle):
                break

            cur_response = await self._query_engine.aquery(updated_query_bundle)
            prev_reasoning += self._record_step(
                updated_query_bundle,
                cur_response,
                text_chunks,
                source_nodes,
                final_response_metadata,
            )
            cur_steps += 1

        return self._build_nodes(text_chunks), source_nodes, final_response_metadata
178

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.