llama-index

Форк
0
1
"""Response builder class.

This class provides general functions for taking in a set of text chunks
and generating a response.

Will support different modes: 1) stuffing chunks into the prompt,
2) creating and refining separately over each chunk, 3) tree summarization.
"""
10

11
import logging
from abc import abstractmethod
from typing import Any, Dict, Generator, List, Optional, Sequence, Type, Union

from llama_index.legacy.bridge.pydantic import BaseModel, Field
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
from llama_index.legacy.core.query_pipeline.query_component import (
    ChainableMixin,
    InputKeys,
    OutputKeys,
    QueryComponent,
    validate_and_convert_stringable,
)
from llama_index.legacy.core.response.schema import (
    RESPONSE_TYPE,
    PydanticResponse,
    Response,
    StreamingResponse,
)
from llama_index.legacy.prompts.mixin import PromptMixin
from llama_index.legacy.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle
from llama_index.legacy.service_context import ServiceContext
from llama_index.legacy.types import RESPONSE_TEXT_TYPE
35

36
# Logger named after this module.
logger = logging.getLogger(name=__name__)

# A query may be passed either as raw text or as an already-built QueryBundle.
QueryTextType = Union[str, QueryBundle]

def _empty_response_generator() -> Generator[str, None, None]:
    """Yield the placeholder text used when there are no nodes to synthesize from."""
    yield "Empty Response"


class BaseSynthesizer(ChainableMixin, PromptMixin):
    """Response builder class.

    Subclasses implement ``get_response`` / ``aget_response``, which turn a
    query string plus a list of text chunks into an answer. ``synthesize`` /
    ``asynthesize`` drive those methods inside a SYNTHESIZE callback event and
    wrap the result in a ``Response``, ``StreamingResponse`` or
    ``PydanticResponse`` carrying the source nodes and their metadata.
    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        streaming: bool = False,
        output_cls: Optional[Type[BaseModel]] = None,
    ) -> None:
        """Init params.

        Args:
            service_context: Service context providing the LLM, the
                ``llama_logger`` and the callback manager. Defaults to
                ``ServiceContext.from_defaults()``.
            streaming: Streaming flag, stored as ``self._streaming``. When set,
                an empty node list yields a ``StreamingResponse``.
            output_cls: Optional pydantic model *class*. Responses that are
                instances of it are wrapped in a ``PydanticResponse``.
        """
        self._service_context = service_context or ServiceContext.from_defaults()
        self._callback_manager = self._service_context.callback_manager
        self._streaming = streaming
        self._output_cls = output_cls

    def _get_prompt_modules(self) -> Dict[str, Any]:
        """Get prompt modules."""
        # TODO: keep this for now since response synthesizers don't generally have sub-modules
        return {}

    @property
    def service_context(self) -> ServiceContext:
        """Service context this synthesizer was built with."""
        return self._service_context

    @property
    def callback_manager(self) -> CallbackManager:
        """Callback manager used for SYNTHESIZE events."""
        return self._callback_manager

    @callback_manager.setter
    def callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager.

        The new manager is also assigned to the service context and to its
        LLM, embedding model and node parser.
        """
        self._callback_manager = callback_manager
        # TODO: please fix this later -- this reaches into the internals of
        # the service context's components.
        self._service_context.callback_manager = callback_manager
        self._service_context.llm.callback_manager = callback_manager
        self._service_context.embed_model.callback_manager = callback_manager
        self._service_context.node_parser.callback_manager = callback_manager

    @abstractmethod
    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get response.

        Produce an answer to ``query_str`` from ``text_chunks``. The result is
        passed to ``_prepare_response_output``, which accepts a string, a
        generator, or an instance of ``output_cls``.
        """
        ...

    @abstractmethod
    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get response (async counterpart of ``get_response``)."""
        ...

    def _log_prompt_and_response(
        self,
        formatted_prompt: str,
        response: RESPONSE_TEXT_TYPE,
        log_prefix: str = "",
    ) -> None:
        """Log prompt and response from LLM.

        Both are emitted at DEBUG level and recorded in the service context's
        ``llama_logger``. A falsy response is recorded as ``"Empty Response"``.
        """
        # Lazy %-style arguments: the message is only formatted if DEBUG is on.
        logger.debug("> %s prompt template: %s", log_prefix, formatted_prompt)
        self._service_context.llama_logger.add_log(
            {"formatted_prompt_template": formatted_prompt}
        )
        logger.debug("> %s response: %s", log_prefix, response)
        self._service_context.llama_logger.add_log(
            {f"{log_prefix.lower()}_response": response or "Empty Response"}
        )

    def _get_metadata_for_response(
        self,
        nodes: List[BaseNode],
    ) -> Optional[Dict[str, Any]]:
        """Get metadata for response: a mapping of node id to node metadata."""
        return {node.node_id: node.metadata for node in nodes}

    def _prepare_response_output(
        self,
        response_str: Optional[RESPONSE_TEXT_TYPE],
        source_nodes: List[NodeWithScore],
    ) -> RESPONSE_TYPE:
        """Prepare response object from response string.

        Returns:
            ``Response`` for a string, ``StreamingResponse`` for a generator,
            ``PydanticResponse`` for an instance of ``output_cls``.

        Raises:
            ValueError: if ``response_str`` is none of the above (including
                ``None``).
        """
        response_metadata = self._get_metadata_for_response(
            [node_with_score.node for node_with_score in source_nodes]
        )

        if isinstance(response_str, str):
            return Response(
                response_str,
                source_nodes=source_nodes,
                metadata=response_metadata,
            )
        if isinstance(response_str, Generator):
            return StreamingResponse(
                response_str,
                source_nodes=source_nodes,
                metadata=response_metadata,
            )
        # ``isinstance(x, None)`` raises TypeError, so only test against the
        # structured output class when one was actually configured.
        if self._output_cls is not None and isinstance(
            response_str, self._output_cls
        ):
            return PydanticResponse(
                response_str, source_nodes=source_nodes, metadata=response_metadata
            )

        raise ValueError(
            "Response must be a string or a generator (or an instance of "
            f"output_cls, if one is set). Found {type(response_str)}"
        )

    def _build_empty_response(self) -> RESPONSE_TYPE:
        """Build the response returned when there are no nodes to synthesize from.

        Streaming synthesizers return a ``StreamingResponse`` so callers that
        consume ``response_gen`` keep working; others return a plain
        ``Response``.
        """
        if self._streaming:
            return StreamingResponse(_empty_response_generator())
        return Response("Empty Response")

    @staticmethod
    def _node_text_chunks(nodes: Sequence[NodeWithScore]) -> List[str]:
        """Render every node's content using the LLM metadata mode."""
        return [n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes]

    @staticmethod
    def _merge_source_nodes(
        nodes: Sequence[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]],
    ) -> List[NodeWithScore]:
        """Concatenate the retrieved nodes with any additional source nodes."""
        return list(nodes) + list(additional_source_nodes or [])

    def synthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        """Synthesize a response to ``query`` from ``nodes``.

        Args:
            query: Query text or a ``QueryBundle``.
            nodes: Nodes whose content is fed to ``get_response``.
            additional_source_nodes: Extra nodes attached to the response as
                sources without being used as text chunks.
            **response_kwargs: Forwarded to ``get_response``.

        Returns:
            The wrapped response (see ``_prepare_response_output``). With no
            nodes, an "Empty Response" is returned without calling the LLM.
        """
        if not nodes:
            return self._build_empty_response()

        if isinstance(query, str):
            query = QueryBundle(query_str=query)

        with self._callback_manager.event(
            CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str}
        ) as event:
            response_str = self.get_response(
                query_str=query.query_str,
                text_chunks=self._node_text_chunks(nodes),
                **response_kwargs,
            )
            source_nodes = self._merge_source_nodes(nodes, additional_source_nodes)
            response = self._prepare_response_output(response_str, source_nodes)
            event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def asynthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        """Async counterpart of ``synthesize``; awaits ``aget_response``."""
        if not nodes:
            return self._build_empty_response()

        if isinstance(query, str):
            query = QueryBundle(query_str=query)

        with self._callback_manager.event(
            CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str}
        ) as event:
            response_str = await self.aget_response(
                query_str=query.query_str,
                text_chunks=self._node_text_chunks(nodes),
                **response_kwargs,
            )
            source_nodes = self._merge_source_nodes(nodes, additional_source_nodes)
            response = self._prepare_response_output(response_str, source_nodes)
            event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """As query component: wrap this synthesizer in a ``SynthesizerComponent``."""
        return SynthesizerComponent(synthesizer=self)
222

223

224
class SynthesizerComponent(QueryComponent):
    """Query-pipeline component wrapping a ``BaseSynthesizer``.

    Consumes ``query_str`` and ``nodes`` and emits the synthesized response
    under the ``output`` key.
    """

    synthesizer: BaseSynthesizer = Field(..., description="Synthesizer")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager by delegating to the synthesizer's setter."""
        self.synthesizer.callback_manager = callback_manager

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component.

        ``query_str`` must be present; it is converted in place with
        ``validate_and_convert_stringable``. ``nodes`` must be present and be
        a list made up solely of ``NodeWithScore`` objects.

        Raises:
            ValueError: if any of these requirements is not met.
        """
        if "query_str" not in input:
            raise ValueError("Input must have key 'query_str'")
        input["query_str"] = validate_and_convert_stringable(input["query_str"])

        if "nodes" not in input:
            raise ValueError("Input must have key 'nodes'")
        candidate_nodes = input["nodes"]
        if not isinstance(candidate_nodes, list):
            raise ValueError("Input nodes must be a list")
        if not all(isinstance(item, NodeWithScore) for item in candidate_nodes):
            raise ValueError("Input nodes must be a list of NodeWithScore")
        return input

    def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run component synchronously."""
        query_str, nodes = kwargs["query_str"], kwargs["nodes"]
        return {"output": self.synthesizer.synthesize(query_str, nodes)}

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run component asynchronously."""
        query_str, nodes = kwargs["query_str"], kwargs["nodes"]
        return {"output": await self.synthesizer.asynthesize(query_str, nodes)}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys: ``query_str`` and ``nodes``."""
        return InputKeys.from_keys({"nodes", "query_str"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys: ``output``."""
        return OutputKeys.from_keys({"output"})
274

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.