llama-index

Форк
0
148 строк · 4.9 Кб
1
import asyncio
2
from typing import Any, Callable, List, Optional, Sequence
3

4
from llama_index.legacy.async_utils import run_async_tasks
5
from llama_index.legacy.prompts import BasePromptTemplate
6
from llama_index.legacy.prompts.default_prompt_selectors import (
7
    DEFAULT_TEXT_QA_PROMPT_SEL,
8
)
9
from llama_index.legacy.prompts.mixin import PromptDictType
10
from llama_index.legacy.response_synthesizers.base import BaseSynthesizer
11
from llama_index.legacy.service_context import ServiceContext
12
from llama_index.legacy.types import RESPONSE_TEXT_TYPE
13

14

15
class Accumulate(BaseSynthesizer):
16
    """Accumulate responses from multiple text chunks."""
17

18
    def __init__(
19
        self,
20
        text_qa_template: Optional[BasePromptTemplate] = None,
21
        service_context: Optional[ServiceContext] = None,
22
        output_cls: Optional[Any] = None,
23
        streaming: bool = False,
24
        use_async: bool = False,
25
    ) -> None:
26
        super().__init__(
27
            service_context=service_context,
28
            streaming=streaming,
29
        )
30
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
31
        self._use_async = use_async
32
        self._output_cls = output_cls
33

34
    def _get_prompts(self) -> PromptDictType:
35
        """Get prompts."""
36
        return {"text_qa_template": self._text_qa_template}
37

38
    def _update_prompts(self, prompts: PromptDictType) -> None:
39
        """Update prompts."""
40
        if "text_qa_template" in prompts:
41
            self._text_qa_template = prompts["text_qa_template"]
42

43
    def flatten_list(self, md_array: List[List[Any]]) -> List[Any]:
44
        return [item for sublist in md_array for item in sublist]
45

46
    def _format_response(self, outputs: List[Any], separator: str) -> str:
47
        responses: List[str] = []
48
        for response in outputs:
49
            responses.append(response or "Empty Response")
50

51
        return separator.join(
52
            [f"Response {index + 1}: {item}" for index, item in enumerate(responses)]
53
        )
54

55
    async def aget_response(
56
        self,
57
        query_str: str,
58
        text_chunks: Sequence[str],
59
        separator: str = "\n---------------------\n",
60
        **response_kwargs: Any,
61
    ) -> RESPONSE_TEXT_TYPE:
62
        """Apply the same prompt to text chunks and return async responses."""
63
        if self._streaming:
64
            raise ValueError("Unable to stream in Accumulate response mode")
65

66
        tasks = [
67
            self._give_responses(
68
                query_str, text_chunk, use_async=True, **response_kwargs
69
            )
70
            for text_chunk in text_chunks
71
        ]
72

73
        flattened_tasks = self.flatten_list(tasks)
74
        outputs = await asyncio.gather(*flattened_tasks)
75

76
        return self._format_response(outputs, separator)
77

78
    def get_response(
79
        self,
80
        query_str: str,
81
        text_chunks: Sequence[str],
82
        separator: str = "\n---------------------\n",
83
        **response_kwargs: Any,
84
    ) -> RESPONSE_TEXT_TYPE:
85
        """Apply the same prompt to text chunks and return responses."""
86
        if self._streaming:
87
            raise ValueError("Unable to stream in Accumulate response mode")
88

89
        tasks = [
90
            self._give_responses(
91
                query_str, text_chunk, use_async=self._use_async, **response_kwargs
92
            )
93
            for text_chunk in text_chunks
94
        ]
95

96
        outputs = self.flatten_list(tasks)
97

98
        if self._use_async:
99
            outputs = run_async_tasks(outputs)
100

101
        return self._format_response(outputs, separator)
102

103
    def _give_responses(
104
        self,
105
        query_str: str,
106
        text_chunk: str,
107
        use_async: bool = False,
108
        **response_kwargs: Any,
109
    ) -> List[Any]:
110
        """Give responses given a query and a corresponding text chunk."""
111
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
112

113
        text_chunks = self._service_context.prompt_helper.repack(
114
            text_qa_template, [text_chunk]
115
        )
116

117
        predictor: Callable
118
        if self._output_cls is None:
119
            predictor = (
120
                self._service_context.llm.apredict
121
                if use_async
122
                else self._service_context.llm.predict
123
            )
124

125
            return [
126
                predictor(
127
                    text_qa_template,
128
                    context_str=cur_text_chunk,
129
                    **response_kwargs,
130
                )
131
                for cur_text_chunk in text_chunks
132
            ]
133
        else:
134
            predictor = (
135
                self._service_context.llm.astructured_predict
136
                if use_async
137
                else self._service_context.llm.structured_predict
138
            )
139

140
            return [
141
                predictor(
142
                    self._output_cls,
143
                    text_qa_template,
144
                    context_str=cur_text_chunk,
145
                    **response_kwargs,
146
                )
147
                for cur_text_chunk in text_chunks
148
            ]
149

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.