llama-index
273 lines · 9.2 KB
"""Response builder class.

This class provides general functions for taking in a set of text
and generating a response.

Will support different modes, from 1) stuffing chunks into prompt,
2) create and refine separately over each chunk, 3) tree summarization.

"""
10
11import logging
12from abc import abstractmethod
13from typing import Any, Dict, Generator, List, Optional, Sequence, Union
14
15from llama_index.legacy.bridge.pydantic import BaseModel, Field
16from llama_index.legacy.callbacks.base import CallbackManager
17from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
18from llama_index.legacy.core.query_pipeline.query_component import (
19ChainableMixin,
20InputKeys,
21OutputKeys,
22QueryComponent,
23validate_and_convert_stringable,
24)
25from llama_index.legacy.core.response.schema import (
26RESPONSE_TYPE,
27PydanticResponse,
28Response,
29StreamingResponse,
30)
31from llama_index.legacy.prompts.mixin import PromptMixin
32from llama_index.legacy.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle
33from llama_index.legacy.service_context import ServiceContext
34from llama_index.legacy.types import RESPONSE_TEXT_TYPE
35
36logger = logging.getLogger(__name__)
37
38QueryTextType = Union[str, QueryBundle]
39
40
41class BaseSynthesizer(ChainableMixin, PromptMixin):
42"""Response builder class."""
43
44def __init__(
45self,
46service_context: Optional[ServiceContext] = None,
47streaming: bool = False,
48output_cls: BaseModel = None,
49) -> None:
50"""Init params."""
51self._service_context = service_context or ServiceContext.from_defaults()
52self._callback_manager = self._service_context.callback_manager
53self._streaming = streaming
54self._output_cls = output_cls
55
56def _get_prompt_modules(self) -> Dict[str, Any]:
57"""Get prompt modules."""
58# TODO: keep this for now since response synthesizers don't generally have sub-modules
59return {}
60
61@property
62def service_context(self) -> ServiceContext:
63return self._service_context
64
65@property
66def callback_manager(self) -> CallbackManager:
67return self._callback_manager
68
69@callback_manager.setter
70def callback_manager(self, callback_manager: CallbackManager) -> None:
71"""Set callback manager."""
72self._callback_manager = callback_manager
73# TODO: please fix this later
74self._service_context.callback_manager = callback_manager
75self._service_context.llm.callback_manager = callback_manager
76self._service_context.embed_model.callback_manager = callback_manager
77self._service_context.node_parser.callback_manager = callback_manager
78
79@abstractmethod
80def get_response(
81self,
82query_str: str,
83text_chunks: Sequence[str],
84**response_kwargs: Any,
85) -> RESPONSE_TEXT_TYPE:
86"""Get response."""
87...
88
89@abstractmethod
90async def aget_response(
91self,
92query_str: str,
93text_chunks: Sequence[str],
94**response_kwargs: Any,
95) -> RESPONSE_TEXT_TYPE:
96"""Get response."""
97...
98
99def _log_prompt_and_response(
100self,
101formatted_prompt: str,
102response: RESPONSE_TEXT_TYPE,
103log_prefix: str = "",
104) -> None:
105"""Log prompt and response from LLM."""
106logger.debug(f"> {log_prefix} prompt template: {formatted_prompt}")
107self._service_context.llama_logger.add_log(
108{"formatted_prompt_template": formatted_prompt}
109)
110logger.debug(f"> {log_prefix} response: {response}")
111self._service_context.llama_logger.add_log(
112{f"{log_prefix.lower()}_response": response or "Empty Response"}
113)
114
115def _get_metadata_for_response(
116self,
117nodes: List[BaseNode],
118) -> Optional[Dict[str, Any]]:
119"""Get metadata for response."""
120return {node.node_id: node.metadata for node in nodes}
121
122def _prepare_response_output(
123self,
124response_str: Optional[RESPONSE_TEXT_TYPE],
125source_nodes: List[NodeWithScore],
126) -> RESPONSE_TYPE:
127"""Prepare response object from response string."""
128response_metadata = self._get_metadata_for_response(
129[node_with_score.node for node_with_score in source_nodes]
130)
131
132if isinstance(response_str, str):
133return Response(
134response_str,
135source_nodes=source_nodes,
136metadata=response_metadata,
137)
138if isinstance(response_str, Generator):
139return StreamingResponse(
140response_str,
141source_nodes=source_nodes,
142metadata=response_metadata,
143)
144if isinstance(response_str, self._output_cls):
145return PydanticResponse(
146response_str, source_nodes=source_nodes, metadata=response_metadata
147)
148
149raise ValueError(
150f"Response must be a string or a generator. Found {type(response_str)}"
151)
152
153def synthesize(
154self,
155query: QueryTextType,
156nodes: List[NodeWithScore],
157additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
158**response_kwargs: Any,
159) -> RESPONSE_TYPE:
160if len(nodes) == 0:
161return Response("Empty Response")
162
163if isinstance(query, str):
164query = QueryBundle(query_str=query)
165
166with self._callback_manager.event(
167CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str}
168) as event:
169response_str = self.get_response(
170query_str=query.query_str,
171text_chunks=[
172n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
173],
174**response_kwargs,
175)
176
177additional_source_nodes = additional_source_nodes or []
178source_nodes = list(nodes) + list(additional_source_nodes)
179
180response = self._prepare_response_output(response_str, source_nodes)
181
182event.on_end(payload={EventPayload.RESPONSE: response})
183
184return response
185
186async def asynthesize(
187self,
188query: QueryTextType,
189nodes: List[NodeWithScore],
190additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
191**response_kwargs: Any,
192) -> RESPONSE_TYPE:
193if len(nodes) == 0:
194return Response("Empty Response")
195
196if isinstance(query, str):
197query = QueryBundle(query_str=query)
198
199with self._callback_manager.event(
200CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str}
201) as event:
202response_str = await self.aget_response(
203query_str=query.query_str,
204text_chunks=[
205n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
206],
207**response_kwargs,
208)
209
210additional_source_nodes = additional_source_nodes or []
211source_nodes = list(nodes) + list(additional_source_nodes)
212
213response = self._prepare_response_output(response_str, source_nodes)
214
215event.on_end(payload={EventPayload.RESPONSE: response})
216
217return response
218
219def _as_query_component(self, **kwargs: Any) -> QueryComponent:
220"""As query component."""
221return SynthesizerComponent(synthesizer=self)
222
223
224class SynthesizerComponent(QueryComponent):
225"""Synthesizer component."""
226
227synthesizer: BaseSynthesizer = Field(..., description="Synthesizer")
228
229class Config:
230arbitrary_types_allowed = True
231
232def set_callback_manager(self, callback_manager: CallbackManager) -> None:
233"""Set callback manager."""
234self.synthesizer.callback_manager = callback_manager
235
236def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
237"""Validate component inputs during run_component."""
238# make sure both query_str and nodes are there
239if "query_str" not in input:
240raise ValueError("Input must have key 'query_str'")
241input["query_str"] = validate_and_convert_stringable(input["query_str"])
242
243if "nodes" not in input:
244raise ValueError("Input must have key 'nodes'")
245nodes = input["nodes"]
246if not isinstance(nodes, list):
247raise ValueError("Input nodes must be a list")
248for node in nodes:
249if not isinstance(node, NodeWithScore):
250raise ValueError("Input nodes must be a list of NodeWithScore")
251return input
252
253def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
254"""Run component."""
255output = self.synthesizer.synthesize(kwargs["query_str"], kwargs["nodes"])
256return {"output": output}
257
258async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
259"""Run component."""
260output = await self.synthesizer.asynthesize(
261kwargs["query_str"], kwargs["nodes"]
262)
263return {"output": output}
264
265@property
266def input_keys(self) -> InputKeys:
267"""Input keys."""
268return InputKeys.from_keys({"query_str", "nodes"})
269
270@property
271def output_keys(self) -> OutputKeys:
272"""Output keys."""
273return OutputKeys.from_keys({"output"})
274