llama-index
304 строки · 12.3 Кб
from typing import Any, List, Optional, Sequence

from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
from llama_index.legacy.core.base_retriever import BaseRetriever
from llama_index.legacy.core.response.schema import RESPONSE_TYPE
from llama_index.legacy.indices.base import BaseGPTIndex
from llama_index.legacy.node_parser import SentenceSplitter, TextSplitter
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
from llama_index.legacy.prompts import PromptTemplate
from llama_index.legacy.prompts.base import BasePromptTemplate
from llama_index.legacy.prompts.mixin import PromptMixinType
from llama_index.legacy.response_synthesizers import (
    BaseSynthesizer,
    ResponseMode,
    get_response_synthesizer,
)
from llama_index.legacy.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode
# Prompt for the first answering pass: the LLM must answer only from the
# numbered sources substituted into {context_str} and cite them as [N].
CITATION_QA_TEMPLATE = PromptTemplate(
    # Answering rules.
    "Please provide an answer based solely on the provided sources. "
    "When referencing information from a source, cite the appropriate "
    "source(s) using their corresponding numbers. "
    "Every answer should include at least one source citation. "
    "Only cite a source when you are explicitly referencing it. "
    "If none of the sources are helpful, you should indicate that. "
    # Worked example showing the expected citation format.
    "For example:\n"
    "Source 1:\nThe sky is red in the evening and blue in the morning.\n"
    "Source 2:\nWater is wet when the sky is red.\n"
    "Query: When is water wet?\n"
    "Answer: Water will be wet when the sky is red [2], "
    "which occurs in the evening [1].\n"
    # The actual numbered sources and the user's query.
    "Now it's your turn. Below are several numbered sources of information:"
    "\n------\n{context_str}\n------\n"
    "Query: {query_str}\n"
    "Answer: "
)
43
# Prompt for refinement passes: given the answer produced so far
# ({existing_answer}) and another batch of numbered sources ({context_msg}),
# the LLM refines the answer while keeping the same citation rules as the
# initial QA prompt.
CITATION_REFINE_TEMPLATE = PromptTemplate(
    "Please provide an answer based solely on the provided sources. "
    "When referencing information from a source, "
    "cite the appropriate source(s) using their corresponding numbers. "
    "Every answer should include at least one source citation. "
    "Only cite a source when you are explicitly referencing it. "
    "If none of the sources are helpful, you should indicate that. "
    "For example:\n"
    "Source 1:\n"
    "The sky is red in the evening and blue in the morning.\n"
    "Source 2:\n"
    "Water is wet when the sky is red.\n"
    "Query: When is water wet?\n"
    "Answer: Water will be wet when the sky is red [2], "
    "which occurs in the evening [1].\n"
    "Now it's your turn. "
    # The trailing newline separates the existing answer from the next
    # instruction; without it the rendered prompt reads
    # "...<existing answer>Below are several numbered sources...".
    "We have provided an existing answer: {existing_answer}\n"
    "Below are several numbered sources of information. "
    "Use them to refine the existing answer. "
    "If the provided sources are not helpful, you will repeat the existing answer."
    "\nBegin refining!"
    "\n------\n"
    "{context_msg}"
    "\n------\n"
    "Query: {query_str}\n"
    "Answer: "
)
71
# Defaults for the SentenceSplitter that CitationQueryEngine builds when no
# text_splitter is supplied. The chunk size sets how granular each numbered
# "Source N" citation chunk is; the overlap is shared between adjacent chunks.
DEFAULT_CITATION_CHUNK_SIZE: int = 512
DEFAULT_CITATION_CHUNK_OVERLAP: int = 20
75
class CitationQueryEngine(BaseQueryEngine):
    """Citation query engine.

    Retrieves nodes, splits each one into smaller numbered chunks
    ("Source 1:", "Source 2:", ...) and hands those chunks to a response
    synthesizer so that the answer can cite its sources by number.

    Args:
        retriever (BaseRetriever): A retriever object.
        response_synthesizer (Optional[BaseSynthesizer]):
            A BaseSynthesizer object. If omitted, one is created with
            ``get_response_synthesizer`` using the retriever's service context.
        citation_chunk_size (int):
            Size of citation chunks, default=512. Useful for controlling
            granularity of sources. Ignored when ``text_splitter`` is given.
        citation_chunk_overlap (int): Overlap of citation nodes, default=20.
            Ignored when ``text_splitter`` is given.
        text_splitter (Optional[TextSplitter]):
            A text splitter for creating citation source nodes. Default is
            a SentenceSplitter.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Postprocessors applied, in order, to the retrieved nodes before
            they are split into citation sources.
        callback_manager (Optional[CallbackManager]): A callback manager.
            It is also assigned to every node postprocessor.
        metadata_mode (MetadataMode): A MetadataMode object that controls how
            metadata is included in the citation prompt.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        citation_chunk_size: int = DEFAULT_CITATION_CHUNK_SIZE,
        citation_chunk_overlap: int = DEFAULT_CITATION_CHUNK_OVERLAP,
        text_splitter: Optional[TextSplitter] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
        metadata_mode: MetadataMode = MetadataMode.NONE,
    ) -> None:
        self.text_splitter = text_splitter or SentenceSplitter(
            chunk_size=citation_chunk_size, chunk_overlap=citation_chunk_overlap
        )
        self._retriever = retriever
        # NOTE: callback_manager may still be None at this point; it is only
        # defaulted to a fresh CallbackManager further below.
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=retriever.get_service_context(),
            callback_manager=callback_manager,
        )
        self._node_postprocessors = node_postprocessors or []
        self._metadata_mode = metadata_mode

        callback_manager = callback_manager or CallbackManager()
        # Share the engine's callback manager with every postprocessor.
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = callback_manager

        super().__init__(callback_manager)

    @classmethod
    def from_args(
        cls,
        index: BaseGPTIndex,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        citation_chunk_size: int = DEFAULT_CITATION_CHUNK_SIZE,
        citation_chunk_overlap: int = DEFAULT_CITATION_CHUNK_OVERLAP,
        text_splitter: Optional[TextSplitter] = None,
        citation_qa_template: BasePromptTemplate = CITATION_QA_TEMPLATE,
        citation_refine_template: BasePromptTemplate = CITATION_REFINE_TEMPLATE,
        retriever: Optional[BaseRetriever] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.COMPACT,
        use_async: bool = False,
        streaming: bool = False,
        # class-specific args
        metadata_mode: MetadataMode = MetadataMode.NONE,
        **kwargs: Any,
    ) -> "CitationQueryEngine":
        """Initialize a CitationQueryEngine object.

        Args:
            index (BaseGPTIndex): Index to use for querying. Its service
                context supplies the engine's callback manager and configures
                the default response synthesizer.
            response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
                object. If given, it is used as-is and ``citation_qa_template``,
                ``citation_refine_template``, ``response_mode``, ``use_async``
                and ``streaming`` are ignored.
            citation_chunk_size (int):
                Size of citation chunks, default=512. Useful for controlling
                granularity of sources.
            citation_chunk_overlap (int): Overlap of citation nodes, default=20.
            text_splitter (Optional[TextSplitter]):
                A text splitter for creating citation source nodes. Default is
                a SentenceSplitter.
            citation_qa_template (BasePromptTemplate): Template for initial
                citation QA.
            citation_refine_template (BasePromptTemplate):
                Template for citation refinement.
            retriever (Optional[BaseRetriever]): A retriever object. Defaults to
                ``index.as_retriever(**kwargs)``.
            node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list of
                node postprocessors.
            response_mode (ResponseMode): A ResponseMode object.
            use_async (bool): Whether to use async.
            streaming (bool): Whether to use streaming.
            metadata_mode (MetadataMode): Controls how node metadata is included
                in the citation source text.
            **kwargs: Forwarded to ``index.as_retriever`` when ``retriever`` is
                not given; unused otherwise.

        Returns:
            CitationQueryEngine: The configured query engine.
        """
        # kwargs only configure the default retriever.
        retriever = retriever or index.as_retriever(**kwargs)

        response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=index.service_context,
            text_qa_template=citation_qa_template,
            refine_template=citation_refine_template,
            response_mode=response_mode,
            use_async=use_async,
            streaming=streaming,
        )

        return cls(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            callback_manager=index.service_context.callback_manager,
            citation_chunk_size=citation_chunk_size,
            citation_chunk_overlap=citation_chunk_overlap,
            text_splitter=text_splitter,
            node_postprocessors=node_postprocessors,
            metadata_mode=metadata_mode,
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"response_synthesizer": self._response_synthesizer}

    def _create_citation_nodes(self, nodes: List[NodeWithScore]) -> List[NodeWithScore]:
        """Modify retrieved nodes to be granular sources.

        Each node's content (rendered with ``self._metadata_mode``) is split with
        ``self.text_splitter``. Every resulting chunk becomes a new TextNode
        built from the original node, keeping its score, whose text is the
        chunk preceded by a ``Source <n>:`` header line. ``n`` counts chunks
        consecutively from 1 across all input nodes, which is what the
        citation prompts refer to.
        """
        new_nodes: List[NodeWithScore] = []
        for node in nodes:
            text_chunks = self.text_splitter.split_text(
                node.node.get_content(metadata_mode=self._metadata_mode)
            )

            for text_chunk in text_chunks:
                text = f"Source {len(new_nodes) + 1}:\n{text_chunk}\n"

                new_node = NodeWithScore(
                    node=TextNode.parse_obj(node.node), score=node.score
                )
                new_node.node.text = text
                new_nodes.append(new_node)
        return new_nodes

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run the configured node postprocessors over ``nodes`` in order."""
        for postprocessor in self._node_postprocessors:
            nodes = postprocessor.postprocess_nodes(nodes, query_bundle=query_bundle)
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query and apply the node postprocessors."""
        nodes = self._retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of :meth:`retrieve`.

        Only the retrieval itself is awaited; the postprocessors are run
        through their synchronous ``postprocess_nodes``.
        """
        nodes = await self._retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle)

    @property
    def retriever(self) -> BaseRetriever:
        """Get the retriever object."""
        return self._retriever

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Synthesize a response from retrieved ``nodes``.

        The nodes are split into numbered citation sources here, so pass the
        un-split nodes (e.g. from :meth:`retrieve`); passing already-split
        citation nodes would split and renumber them a second time.
        """
        nodes = self._create_citation_nodes(nodes)
        return self._response_synthesizer.synthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async variant of :meth:`synthesize`; expects un-split nodes as well."""
        nodes = self._create_citation_nodes(nodes)
        return await self._response_synthesizer.asynthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query.

        Retrieval (including postprocessing and citation splitting) runs inside
        a RETRIEVE callback event nested in the QUERY event.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = self.retrieve(query_bundle)
                nodes = self._create_citation_nodes(nodes)

                # The RETRIEVE event reports the numbered citation nodes.
                retrieve_event.on_end(payload={EventPayload.NODES: nodes})

            # Nodes are already split, so call the synthesizer directly rather
            # than self.synthesize (which would split them again).
            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously; mirrors :meth:`_query`."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = await self.aretrieve(query_bundle)
                nodes = self._create_citation_nodes(nodes)

                retrieve_event.on_end(payload={EventPayload.NODES: nodes})

            # Nodes are already split; bypass self.asynthesize to avoid
            # splitting them twice.
            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response