llama-index

Форк
0
223 строки · 8.4 Кб
1
import asyncio
2
from typing import Any, Optional, Sequence
3

4
from llama_index.legacy.async_utils import run_async_tasks
5
from llama_index.legacy.prompts import BasePromptTemplate
6
from llama_index.legacy.prompts.default_prompt_selectors import (
7
    DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
8
)
9
from llama_index.legacy.prompts.mixin import PromptDictType
10
from llama_index.legacy.response_synthesizers.base import BaseSynthesizer
11
from llama_index.legacy.service_context import ServiceContext
12
from llama_index.legacy.types import RESPONSE_TEXT_TYPE, BaseModel
13

14

15
class TreeSummarize(BaseSynthesizer):
    """
    Tree summarize response builder.

    This response builder recursively merges text chunks and summarizes them
    in a bottom-up fashion (i.e. building a tree from leaves to root).

    More concretely, at each recursive step:
    1. we repack the text chunks so that each chunk fills the context window of the LLM
    2. if there is only one chunk, we give the final response
    3. otherwise, we summarize each chunk and recursively summarize the summaries.

    Args:
        summary_template: Prompt used at every level of the tree. It is
            partially formatted with ``query_str`` and filled with
            ``context_str``. Defaults to ``DEFAULT_TREE_SUMMARIZE_PROMPT_SEL``.
        service_context: Passed to ``BaseSynthesizer``; the ``llm`` and
            ``prompt_helper`` used here are read from ``self._service_context``.
        output_cls: If given, the LLM is called via ``(a)structured_predict``
            with this class, and intermediate summaries are serialized with
            ``.json()`` before being summarized again.
        streaming: If True, the final (root) response is produced with
            ``llm.stream`` instead of a single prediction.
        use_async: Only consulted by ``get_response``: when True, the chunks
            of each level are summarized concurrently via ``run_async_tasks``.
        verbose: If True, print the number of chunks after each repack.
    """

    def __init__(
        self,
        summary_template: Optional[BasePromptTemplate] = None,
        service_context: Optional[ServiceContext] = None,
        output_cls: Optional[BaseModel] = None,
        streaming: bool = False,
        use_async: bool = False,
        verbose: bool = False,
    ) -> None:
        super().__init__(
            service_context=service_context, streaming=streaming, output_cls=output_cls
        )
        self._summary_template = summary_template or DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
        self._use_async = use_async
        self._verbose = verbose

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"summary_template": self._summary_template}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "summary_template" in prompts:
            self._summary_template = prompts["summary_template"]

    # ------------------------------------------------------------------
    # Private helpers shared by the sync and async entry points.
    # ------------------------------------------------------------------

    def _repack(
        self, summary_template: BasePromptTemplate, text_chunks: Sequence[str]
    ) -> Sequence[str]:
        """Repack ``text_chunks`` to fill the context window; never returns empty.

        Prints the repacked chunk count when ``verbose`` is set.
        """
        text_chunks = self._service_context.prompt_helper.repack(
            summary_template, text_chunks=text_chunks
        )

        if self._verbose:
            print(f"{len(text_chunks)} text chunks after repacking")

        if not text_chunks:
            # An empty repack result would schedule zero summaries and then
            # recurse on an empty list, which can repeat until RecursionError.
            # NOTE(review): presumably this happens when all inputs are blank;
            # confirm against PromptHelper.repack. Falling back to one empty
            # chunk makes the LLM answer once and stops the recursion.
            text_chunks = [""]
        return text_chunks

    def _async_summary_tasks(
        self,
        summary_template: BasePromptTemplate,
        text_chunks: Sequence[str],
        llm_kwargs: dict,
    ) -> list:
        """Build one un-awaited summarization coroutine per chunk, in order."""
        if self._output_cls is None:
            return [
                self._service_context.llm.apredict(
                    summary_template,
                    context_str=text_chunk,
                    **llm_kwargs,
                )
                for text_chunk in text_chunks
            ]
        return [
            self._service_context.llm.astructured_predict(
                self._output_cls,
                summary_template,
                context_str=text_chunk,
                **llm_kwargs,
            )
            for text_chunk in text_chunks
        ]

    def _sync_summaries(
        self,
        summary_template: BasePromptTemplate,
        text_chunks: Sequence[str],
        llm_kwargs: dict,
    ) -> list:
        """Summarize each chunk sequentially, returning raw LLM results in order."""
        if self._output_cls is None:
            return [
                self._service_context.llm.predict(
                    summary_template,
                    context_str=text_chunk,
                    **llm_kwargs,
                )
                for text_chunk in text_chunks
            ]
        return [
            self._service_context.llm.structured_predict(
                self._output_cls,
                summary_template,
                context_str=text_chunk,
                **llm_kwargs,
            )
            for text_chunk in text_chunks
        ]

    def _summaries_to_text(self, summary_responses: Sequence[Any]) -> Sequence[str]:
        """Turn per-chunk results into text for the next tree level."""
        if self._output_cls is not None:
            # Structured outputs are serialized so they can be repacked as text.
            return [summary.json() for summary in summary_responses]
        return summary_responses

    # ------------------------------------------------------------------
    # Public entry points.
    # ------------------------------------------------------------------

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get tree summarize response asynchronously.

        Each level's chunks are always summarized concurrently with
        ``asyncio.gather`` (``use_async`` is not consulted here). Extra
        ``response_kwargs`` are forwarded to every LLM call.
        """
        summary_template = self._summary_template.partial_format(query_str=query_str)
        # repack text_chunks so that each chunk fills the context window
        text_chunks = self._repack(summary_template, text_chunks)

        # give final response if there is only one chunk
        if len(text_chunks) == 1:
            if self._streaming:
                # NOTE(review): the async path uses the synchronous
                # ``llm.stream``; confirm whether an async stream is intended.
                return self._service_context.llm.stream(
                    summary_template, context_str=text_chunks[0], **response_kwargs
                )
            if self._output_cls is None:
                return await self._service_context.llm.apredict(
                    summary_template,
                    context_str=text_chunks[0],
                    **response_kwargs,
                )
            # return pydantic object if output_cls is specified
            return await self._service_context.llm.astructured_predict(
                self._output_cls,
                summary_template,
                context_str=text_chunks[0],
                **response_kwargs,
            )

        # summarize each chunk concurrently
        tasks = self._async_summary_tasks(summary_template, text_chunks, response_kwargs)
        summary_responses = await asyncio.gather(*tasks)

        # recursively summarize the summaries
        return await self.aget_response(
            query_str=query_str,
            text_chunks=self._summaries_to_text(summary_responses),
            **response_kwargs,
        )

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get tree summarize response.

        When ``use_async`` is set, each level's chunks are summarized
        concurrently via ``run_async_tasks``; otherwise sequentially. Extra
        ``response_kwargs`` are forwarded to every LLM call.
        """
        summary_template = self._summary_template.partial_format(query_str=query_str)
        # repack text_chunks so that each chunk fills the context window
        text_chunks = self._repack(summary_template, text_chunks)

        # give final response if there is only one chunk
        if len(text_chunks) == 1:
            if self._streaming:
                return self._service_context.llm.stream(
                    summary_template, context_str=text_chunks[0], **response_kwargs
                )
            if self._output_cls is None:
                return self._service_context.llm.predict(
                    summary_template,
                    context_str=text_chunks[0],
                    **response_kwargs,
                )
            # return pydantic object if output_cls is specified
            return self._service_context.llm.structured_predict(
                self._output_cls,
                summary_template,
                context_str=text_chunks[0],
                **response_kwargs,
            )

        # summarize each chunk
        if self._use_async:
            tasks = self._async_summary_tasks(
                summary_template, text_chunks, response_kwargs
            )
            summary_responses = run_async_tasks(tasks)
        else:
            summary_responses = self._sync_summaries(
                summary_template, text_chunks, response_kwargs
            )

        # recursively summarize the summaries
        return self.get_response(
            query_str=query_str,
            text_chunks=self._summaries_to_text(summary_responses),
            **response_kwargs,
        )
224

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.