llama-index

Форк
0
1
import logging
2
from typing import Any, Callable, Generator, Optional, Sequence, Type, cast
3

4
from llama_index.legacy.bridge.pydantic import BaseModel, Field, ValidationError
5
from llama_index.legacy.indices.utils import truncate_text
6
from llama_index.legacy.llm_predictor.base import LLMPredictorType
7
from llama_index.legacy.prompts.base import BasePromptTemplate, PromptTemplate
8
from llama_index.legacy.prompts.default_prompt_selectors import (
9
    DEFAULT_REFINE_PROMPT_SEL,
10
    DEFAULT_TEXT_QA_PROMPT_SEL,
11
)
12
from llama_index.legacy.prompts.mixin import PromptDictType
13
from llama_index.legacy.response.utils import get_response_text
14
from llama_index.legacy.response_synthesizers.base import BaseSynthesizer
15
from llama_index.legacy.service_context import ServiceContext
16
from llama_index.legacy.types import RESPONSE_TEXT_TYPE, BasePydanticProgram
17

18
# Module-level logger named after this module, so its records can be routed
# or silenced through the standard ``logging`` hierarchy.
logger: logging.Logger = logging.getLogger(__name__)
19

20

21
class StructuredRefineResponse(BaseModel):
    """
    Used to answer a given query based on the provided context.

    Also indicates if the query was satisfied with the provided answer.
    """

    # NOTE: the class docstring above and the field descriptions below end up
    # in the model's JSON schema, which structured-output programs send to the
    # LLM, so their wording is part of the prompt and must stay stable.

    # The textual answer produced for the query.
    answer: str = Field(
        description=(
            "The answer for the given query, based on the context and not "
            "prior knowledge."
        ),
    )

    # Whether the supplied context was sufficient to answer the query; the
    # Refine synthesizer only accepts answers where this flag is True.
    query_satisfied: bool = Field(
        description=(
            "True if there was enough context given to provide an answer "
            "that satisfies the query."
        ),
    )
36

37

38
class DefaultRefineProgram(BasePydanticProgram):
    """
    Runs the query on the LLM as normal and always returns the answer with
    query_satisfied=True. In effect, doesn't do any answer filtering.

    When ``output_cls`` is given, the LLM is asked for that structure through
    ``structured_predict``/``astructured_predict`` and the resulting model is
    serialized with ``.json()``, so it fits into the string-typed
    ``StructuredRefineResponse.answer`` field.
    """

    def __init__(
        self,
        prompt: BasePromptTemplate,
        llm: LLMPredictorType,
        output_cls: Optional[Type[BaseModel]] = None,
    ):
        """
        Args:
            prompt: Prompt template, formatted with the keyword arguments
                passed to each call.
            llm: LLM used for (structured) prediction.
            output_cls: Optional model *class* to request from the LLM;
                ``None`` means plain-text prediction.
        """
        self._prompt = prompt
        self._llm = llm
        self._output_cls = output_cls

    @property
    def output_cls(self) -> Type[BaseModel]:
        # The program itself always emits the wrapper model, independently of
        # the inner ``_output_cls`` requested from the LLM.
        return StructuredRefineResponse

    def __call__(self, *args: Any, **kwds: Any) -> StructuredRefineResponse:
        """Run the prompt synchronously; positional ``args`` are ignored."""
        if self._output_cls is not None:
            structured = self._llm.structured_predict(
                self._output_cls,
                self._prompt,
                **kwds,
            )
            answer = structured.json()
        else:
            answer = self._llm.predict(
                self._prompt,
                **kwds,
            )
        return StructuredRefineResponse(answer=answer, query_satisfied=True)

    async def acall(self, *args: Any, **kwds: Any) -> StructuredRefineResponse:
        """Async counterpart of ``__call__``; positional ``args`` are ignored."""
        if self._output_cls is not None:
            structured = await self._llm.astructured_predict(
                self._output_cls,
                self._prompt,
                **kwds,
            )
            answer = structured.json()
        else:
            answer = await self._llm.apredict(
                self._prompt,
                **kwds,
            )
        return StructuredRefineResponse(answer=answer, query_satisfied=True)
84

85

86
class Refine(BaseSynthesizer):
    """
    Refine a response to a query across text chunks.

    The first chunk is answered with ``text_qa_template``; every later chunk
    refines the current answer with ``refine_template``. Non-streaming LLM
    calls go through a program built by ``program_factory`` that returns a
    ``StructuredRefineResponse``; an answer is only accepted when its
    ``query_satisfied`` flag is True (always the case unless
    ``structured_answer_filtering`` is enabled).
    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        text_qa_template: Optional[BasePromptTemplate] = None,
        refine_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[Type[BaseModel]] = None,
        streaming: bool = False,
        verbose: bool = False,
        structured_answer_filtering: bool = False,
        program_factory: Optional[
            Callable[[BasePromptTemplate], BasePydanticProgram]
        ] = None,
    ) -> None:
        """
        Args:
            service_context: Forwarded to ``BaseSynthesizer``; its ``llm`` and
                ``prompt_helper`` are used for every call.
            text_qa_template: Prompt for the first chunk; defaults to
                ``DEFAULT_TEXT_QA_PROMPT_SEL``.
            refine_template: Prompt used to refine with later chunks; defaults
                to ``DEFAULT_REFINE_PROMPT_SEL``.
            output_cls: Optional pydantic model *class*. The default
                non-filtering program asks the LLM for this structure, and a
                final string answer is parsed back into it with ``parse_raw``.
            streaming: Return LLM token streams instead of strings.
            verbose: Also print each refine context to stdout.
            structured_answer_filtering: Discard answers whose
                ``query_satisfied`` flag is False.
            program_factory: Custom builder for the per-prompt program; only
                allowed together with ``structured_answer_filtering``.

        Raises:
            ValueError: If ``streaming`` is combined with
                ``structured_answer_filtering``, or ``program_factory`` is
                given without ``structured_answer_filtering``.
        """
        super().__init__(service_context=service_context, streaming=streaming)
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
        self._refine_template = refine_template or DEFAULT_REFINE_PROMPT_SEL
        self._verbose = verbose
        self._structured_answer_filtering = structured_answer_filtering
        self._output_cls = output_cls

        if self._streaming and self._structured_answer_filtering:
            raise ValueError(
                "Streaming not supported with structured answer filtering."
            )
        if not self._structured_answer_filtering and program_factory is not None:
            raise ValueError(
                "Program factory not supported without structured answer filtering."
            )
        self._program_factory = program_factory or self._default_program_factory

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "text_qa_template": self._text_qa_template,
            "refine_template": self._refine_template,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "text_qa_template" in prompts:
            self._text_qa_template = prompts["text_qa_template"]
        if "refine_template" in prompts:
            self._refine_template = prompts["refine_template"]

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """
        Give response over chunks.

        If ``prev_response`` is None the first chunk is answered from scratch;
        every following chunk refines the answer so far. Returns
        ``"Empty Response"`` when ``text_chunks`` is empty.
        """
        response: Optional[RESPONSE_TEXT_TYPE] = None
        for text_chunk in text_chunks:
            if prev_response is None:
                # if this is the first chunk, and text chunk already
                # is an answer, then return it
                response = self._give_response_single(
                    query_str, text_chunk, **response_kwargs
                )
            else:
                # refine response if possible
                response = self._refine_response_single(
                    prev_response, query_str, text_chunk, **response_kwargs
                )
            prev_response = response
        return self._postprocess_refine_response(response)

    def _postprocess_refine_response(
        self, response: Optional[RESPONSE_TEXT_TYPE]
    ) -> RESPONSE_TEXT_TYPE:
        """Turn the result of the chunk loop into the public return value."""
        if response is None:
            # No chunk was processed. Return the sentinel directly: it is not
            # JSON, so passing it to ``output_cls.parse_raw`` would only raise.
            return "Empty Response"
        if isinstance(response, str):
            if self._output_cls is not None:
                return self._output_cls.parse_raw(response)
            return response or "Empty Response"
        return cast(Generator, response)

    def _default_program_factory(
        self, prompt: BasePromptTemplate
    ) -> BasePydanticProgram:
        """Build the program used to run ``prompt`` against the LLM."""
        if self._structured_answer_filtering:
            from llama_index.legacy.program.utils import get_program_for_llm

            return get_program_for_llm(
                StructuredRefineResponse,
                prompt,
                self._service_context.llm,
                verbose=self._verbose,
            )
        else:
            return DefaultRefineProgram(
                prompt=prompt,
                llm=self._service_context.llm,
                output_cls=self._output_cls,
            )

    def _give_response_single(
        self,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Give response given a query and a corresponding text chunk."""
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
        # The chunk may not fit next to the prompt; repack into sub-chunks.
        text_chunks = self._service_context.prompt_helper.repack(
            text_qa_template, [text_chunk]
        )

        response: Optional[RESPONSE_TEXT_TYPE] = None
        program = self._program_factory(text_qa_template)
        # TODO: consolidate with loop in get_response_default
        for cur_text_chunk in text_chunks:
            if response is None and not self._streaming:
                try:
                    structured_response = cast(
                        StructuredRefineResponse,
                        program(
                            context_str=cur_text_chunk,
                            **response_kwargs,
                        ),
                    )
                    if structured_response.query_satisfied:
                        response = structured_response.answer
                except ValidationError as e:
                    logger.warning(
                        f"Validation error on structured response: {e}", exc_info=True
                    )
            elif response is None and self._streaming:
                response = self._service_context.llm.stream(
                    text_qa_template,
                    context_str=cur_text_chunk,
                    **response_kwargs,
                )
            else:
                # An answer already exists: refine it with the next sub-chunk.
                response = self._refine_response_single(
                    cast(RESPONSE_TEXT_TYPE, response),
                    query_str,
                    cur_text_chunk,
                    **response_kwargs,
                )
        if response is None:
            response = "Empty Response"
        if isinstance(response, str):
            response = response or "Empty Response"
        else:
            response = cast(Generator, response)
        return response

    def _refine_response_single(
        self,
        response: RESPONSE_TEXT_TYPE,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> Optional[RESPONSE_TEXT_TYPE]:
        """Refine ``response`` with ``text_chunk``; return the refined answer."""
        # TODO: consolidate with logic in response/schema.py
        if isinstance(response, Generator):
            response = get_response_text(response)

        fmt_text_chunk = truncate_text(text_chunk, 50)
        logger.debug(f"> Refine context: {fmt_text_chunk}")
        if self._verbose:
            print(f"> Refine context: {fmt_text_chunk}")

        # NOTE: partial format refine template with query_str and existing_answer here
        refine_template = self._refine_template.partial_format(
            query_str=query_str, existing_answer=response
        )

        # compute available chunk size to see if there is any available space
        # determine if the refine template is too big (which can happen if
        # prompt template + query + existing answer is too large)
        avail_chunk_size = (
            self._service_context.prompt_helper._get_available_chunk_size(
                refine_template
            )
        )

        if avail_chunk_size < 0:
            # if the available chunk size is negative, then the refine template
            # is too big and we just return the original response
            return response

        # obtain text chunks to add to the refine template
        text_chunks = self._service_context.prompt_helper.repack(
            refine_template, text_chunks=[text_chunk]
        )

        program = self._program_factory(refine_template)
        for cur_text_chunk in text_chunks:
            if self._streaming:
                # TODO: structured response not supported for streaming
                if isinstance(response, Generator):
                    response = "".join(response)

                refine_template = self._refine_template.partial_format(
                    query_str=query_str, existing_answer=response
                )

                response = self._service_context.llm.stream(
                    refine_template,
                    context_msg=cur_text_chunk,
                    **response_kwargs,
                )
                continue

            try:
                structured_response = cast(
                    StructuredRefineResponse,
                    program(
                        context_msg=cur_text_chunk,
                        **response_kwargs,
                    ),
                )
            except ValidationError as e:
                logger.warning(
                    f"Validation error on structured response: {e}", exc_info=True
                )
                continue

            if structured_response.query_satisfied:
                response = structured_response.answer
                # The program is bound to the template it was built from, so
                # rebuild both; otherwise remaining sub-chunks would refine the
                # stale answer instead of the one just produced.
                refine_template = self._refine_template.partial_format(
                    query_str=query_str, existing_answer=response
                )
                program = self._program_factory(refine_template)

        return response

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Async counterpart of ``get_response`` (streaming is not supported)."""
        response: Optional[RESPONSE_TEXT_TYPE] = None
        for text_chunk in text_chunks:
            if prev_response is None:
                # if this is the first chunk, and text chunk already
                # is an answer, then return it
                response = await self._agive_response_single(
                    query_str, text_chunk, **response_kwargs
                )
            else:
                response = await self._arefine_response_single(
                    prev_response, query_str, text_chunk, **response_kwargs
                )
            prev_response = response
        return self._postprocess_refine_response(response)

    async def _arefine_response_single(
        self,
        response: RESPONSE_TEXT_TYPE,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> Optional[RESPONSE_TEXT_TYPE]:
        """Async counterpart of ``_refine_response_single``.

        Raises:
            ValueError: If streaming is enabled and there is text to refine with.
        """
        # TODO: consolidate with logic in response/schema.py
        if isinstance(response, Generator):
            response = get_response_text(response)

        fmt_text_chunk = truncate_text(text_chunk, 50)
        logger.debug(f"> Refine context: {fmt_text_chunk}")
        if self._verbose:
            print(f"> Refine context: {fmt_text_chunk}")

        # NOTE: partial format refine template with query_str and existing_answer here
        refine_template = self._refine_template.partial_format(
            query_str=query_str, existing_answer=response
        )

        # compute available chunk size to see if there is any available space
        # determine if the refine template is too big (which can happen if
        # prompt template + query + existing answer is too large)
        avail_chunk_size = (
            self._service_context.prompt_helper._get_available_chunk_size(
                refine_template
            )
        )

        if avail_chunk_size < 0:
            # if the available chunk size is negative, then the refine template
            # is too big and we just return the original response
            return response

        # obtain text chunks to add to the refine template
        text_chunks = self._service_context.prompt_helper.repack(
            refine_template, text_chunks=[text_chunk]
        )

        program = self._program_factory(refine_template)
        for cur_text_chunk in text_chunks:
            if self._streaming:
                raise ValueError("Streaming not supported for async")

            try:
                structured_response = await program.acall(
                    context_msg=cur_text_chunk,
                    **response_kwargs,
                )
                structured_response = cast(
                    StructuredRefineResponse, structured_response
                )
            except ValidationError as e:
                logger.warning(
                    f"Validation error on structured response: {e}", exc_info=True
                )
                continue

            if structured_response.query_satisfied:
                response = structured_response.answer
                # Rebuild the template AND the program bound to it, so that
                # remaining sub-chunks refine the newly accepted answer.
                refine_template = self._refine_template.partial_format(
                    query_str=query_str, existing_answer=response
                )
                program = self._program_factory(refine_template)

        return response

    async def _agive_response_single(
        self,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Give response given a query and a corresponding text chunk.

        Raises:
            ValueError: If streaming is enabled.
        """
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
        # The chunk may not fit next to the prompt; repack into sub-chunks.
        text_chunks = self._service_context.prompt_helper.repack(
            text_qa_template, [text_chunk]
        )

        response: Optional[RESPONSE_TEXT_TYPE] = None
        program = self._program_factory(text_qa_template)
        # TODO: consolidate with loop in get_response_default
        for cur_text_chunk in text_chunks:
            if response is None and not self._streaming:
                try:
                    structured_response = await program.acall(
                        context_str=cur_text_chunk,
                        **response_kwargs,
                    )
                    structured_response = cast(
                        StructuredRefineResponse, structured_response
                    )
                    if structured_response.query_satisfied:
                        response = structured_response.answer
                except ValidationError as e:
                    logger.warning(
                        f"Validation error on structured response: {e}", exc_info=True
                    )
            elif response is None and self._streaming:
                raise ValueError("Streaming not supported for async")
            else:
                # An answer already exists: refine it with the next sub-chunk.
                response = await self._arefine_response_single(
                    cast(RESPONSE_TEXT_TYPE, response),
                    query_str,
                    cur_text_chunk,
                    **response_kwargs,
                )
        if response is None:
            response = "Empty Response"
        if isinstance(response, str):
            response = response or "Empty Response"
        else:
            response = cast(Generator, response)
        return response
460

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.