llama-index

Форк
0
304 строки · 12.3 Кб
1
from typing import Any, List, Optional, Sequence
2

3
from llama_index.legacy.callbacks.base import CallbackManager
4
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
5
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
6
from llama_index.legacy.core.base_retriever import BaseRetriever
7
from llama_index.legacy.core.response.schema import RESPONSE_TYPE
8
from llama_index.legacy.indices.base import BaseGPTIndex
9
from llama_index.legacy.node_parser import SentenceSplitter, TextSplitter
10
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
11
from llama_index.legacy.prompts import PromptTemplate
12
from llama_index.legacy.prompts.base import BasePromptTemplate
13
from llama_index.legacy.prompts.mixin import PromptMixinType
14
from llama_index.legacy.response_synthesizers import (
15
    BaseSynthesizer,
16
    ResponseMode,
17
    get_response_synthesizer,
18
)
19
from llama_index.legacy.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode
20

21
# Instructions and worked example shared by both citation prompts. Keeping
# them in a single constant guarantees the QA and refine prompts ask for
# citations in exactly the same way.
_CITATION_INSTRUCTIONS = (
    "Please provide an answer based solely on the provided sources. "
    "When referencing information from a source, "
    "cite the appropriate source(s) using their corresponding numbers. "
    "Every answer should include at least one source citation. "
    "Only cite a source when you are explicitly referencing it. "
    "If none of the sources are helpful, you should indicate that. "
    "For example:\n"
    "Source 1:\n"
    "The sky is red in the evening and blue in the morning.\n"
    "Source 2:\n"
    "Water is wet when the sky is red.\n"
    "Query: When is water wet?\n"
    "Answer: Water will be wet when the sky is red [2], "
    "which occurs in the evening [1].\n"
    "Now it's your turn. "
)

# Prompt for the initial answer over numbered sources.
# Template variables: {context_str}, {query_str}.
CITATION_QA_TEMPLATE = PromptTemplate(
    _CITATION_INSTRUCTIONS
    + (
        "Below are several numbered sources of information:"
        "\n------\n"
        "{context_str}"
        "\n------\n"
        "Query: {query_str}\n"
        "Answer: "
    )
)

# Prompt for refining an existing answer with further numbered sources.
# Template variables: {existing_answer}, {context_msg}, {query_str}.
CITATION_REFINE_TEMPLATE = PromptTemplate(
    _CITATION_INSTRUCTIONS
    + (
        # The trailing newline keeps the existing answer from running straight
        # into the next sentence of the rendered prompt.
        "We have provided an existing answer: {existing_answer}\n"
        "Below are several numbered sources of information. "
        "Use them to refine the existing answer. "
        "If the provided sources are not helpful, you will repeat the existing answer."
        "\nBegin refining!"
        "\n------\n"
        "{context_msg}"
        "\n------\n"
        "Query: {query_str}\n"
        "Answer: "
    )
)

# Default chunk size / overlap for the SentenceSplitter that re-splits
# retrieved nodes into numbered citation sources (used by CitationQueryEngine).
DEFAULT_CITATION_CHUNK_SIZE = 512
DEFAULT_CITATION_CHUNK_OVERLAP = 20
74

75

76
class CitationQueryEngine(BaseQueryEngine):
    """Query engine whose answers cite numbered sources.

    Retrieved nodes are re-split into smaller chunks, each prefixed with
    ``Source N:``. The LLM can then refer to them by number in its answer.

    Args:
        retriever (BaseRetriever): A retriever object.
        response_synthesizer (Optional[BaseSynthesizer]):
            A BaseSynthesizer object. If omitted, one is built from the
            retriever's service context using ``CITATION_QA_TEMPLATE`` and
            ``CITATION_REFINE_TEMPLATE``.
        citation_chunk_size (int):
            Size of citation chunks, default=512. Useful for controlling
            granularity of sources. Ignored when ``text_splitter`` is given.
        citation_chunk_overlap (int): Overlap of citation nodes, default=20.
            Ignored when ``text_splitter`` is given.
        text_splitter (Optional[TextSplitter]):
            A text splitter for creating citation source nodes. Default is
            a SentenceSplitter.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Postprocessors applied, in order, to the retrieved nodes before
            they are split into citation sources.
        callback_manager (Optional[CallbackManager]): A callback manager.
        metadata_mode (MetadataMode): A MetadataMode object that controls how
            metadata is included in the citation prompt.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        citation_chunk_size: int = DEFAULT_CITATION_CHUNK_SIZE,
        citation_chunk_overlap: int = DEFAULT_CITATION_CHUNK_OVERLAP,
        text_splitter: Optional[TextSplitter] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
        metadata_mode: MetadataMode = MetadataMode.NONE,
    ) -> None:
        self.text_splitter = text_splitter or SentenceSplitter(
            chunk_size=citation_chunk_size, chunk_overlap=citation_chunk_overlap
        )
        self._retriever = retriever
        # Default to the citation prompts, matching ``from_args``. Without
        # them the LLM receives the "Source N:" chunks but is never asked to
        # cite them. ``callback_manager`` is passed through before it is
        # defaulted below, so a None here leaves the choice to the synthesizer.
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=retriever.get_service_context(),
            text_qa_template=CITATION_QA_TEMPLATE,
            refine_template=CITATION_REFINE_TEMPLATE,
            callback_manager=callback_manager,
        )
        self._node_postprocessors = node_postprocessors or []
        self._metadata_mode = metadata_mode

        callback_manager = callback_manager or CallbackManager()
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = callback_manager

        super().__init__(callback_manager)

    @classmethod
    def from_args(
        cls,
        index: BaseGPTIndex,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        citation_chunk_size: int = DEFAULT_CITATION_CHUNK_SIZE,
        citation_chunk_overlap: int = DEFAULT_CITATION_CHUNK_OVERLAP,
        text_splitter: Optional[TextSplitter] = None,
        citation_qa_template: BasePromptTemplate = CITATION_QA_TEMPLATE,
        citation_refine_template: BasePromptTemplate = CITATION_REFINE_TEMPLATE,
        retriever: Optional[BaseRetriever] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.COMPACT,
        use_async: bool = False,
        streaming: bool = False,
        # class-specific args
        metadata_mode: MetadataMode = MetadataMode.NONE,
        **kwargs: Any,
    ) -> "CitationQueryEngine":
        """Build a CitationQueryEngine from an index.

        Args:
            index (BaseGPTIndex): Index to use for querying.
            response_synthesizer (Optional[BaseSynthesizer]):
                A BaseSynthesizer object. When given, ``citation_qa_template``,
                ``citation_refine_template``, ``response_mode``, ``use_async``
                and ``streaming`` are ignored.
            citation_chunk_size (int):
                Size of citation chunks, default=512. Useful for controlling
                granularity of sources.
            citation_chunk_overlap (int): Overlap of citation nodes, default=20.
            text_splitter (Optional[TextSplitter]):
                A text splitter for creating citation source nodes. Default is
                a SentenceSplitter.
            citation_qa_template (BasePromptTemplate):
                Template for initial citation QA.
            citation_refine_template (BasePromptTemplate):
                Template for citation refinement.
            retriever (Optional[BaseRetriever]): A retriever object. Defaults
                to ``index.as_retriever(**kwargs)``.
            node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list
                of node postprocessors.
            response_mode (ResponseMode): A ResponseMode object.
            use_async (bool): Whether to use async.
            streaming (bool): Whether to use streaming.
            metadata_mode (MetadataMode): Controls how node metadata is
                included in the citation source text.
            **kwargs: Passed to ``index.as_retriever`` when ``retriever`` is
                not given; otherwise unused.

        Returns:
            CitationQueryEngine: The configured engine. Its callback manager
            is taken from ``index.service_context``.
        """
        retriever = retriever or index.as_retriever(**kwargs)

        response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=index.service_context,
            text_qa_template=citation_qa_template,
            refine_template=citation_refine_template,
            response_mode=response_mode,
            use_async=use_async,
            streaming=streaming,
        )

        return cls(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            callback_manager=index.service_context.callback_manager,
            citation_chunk_size=citation_chunk_size,
            citation_chunk_overlap=citation_chunk_overlap,
            text_splitter=text_splitter,
            node_postprocessors=node_postprocessors,
            metadata_mode=metadata_mode,
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"response_synthesizer": self._response_synthesizer}

    def _create_citation_nodes(self, nodes: List[NodeWithScore]) -> List[NodeWithScore]:
        """Split retrieved nodes into granular, numbered citation sources.

        Each chunk becomes a new node whose text is ``"Source N:\\n<chunk>\\n"``.
        N runs from 1 across all input nodes, and each chunk keeps the score
        of the node it came from.
        """
        new_nodes: List[NodeWithScore] = []
        for node in nodes:
            text_chunks = self.text_splitter.split_text(
                node.node.get_content(metadata_mode=self._metadata_mode)
            )

            for text_chunk in text_chunks:
                # Numbering is global, so every chunk gets a distinct label.
                text = f"Source {len(new_nodes) + 1}:\n{text_chunk}\n"

                # Copy the node into a fresh TextNode, then overwrite its text.
                # The original retrieved node object is left untouched.
                # NOTE(review): chunks from the same node appear to keep that
                # node's id_; confirm consumers tolerate duplicate ids.
                new_node = NodeWithScore(
                    node=TextNode.parse_obj(node.node), score=node.score
                )
                new_node.node.text = text
                new_nodes.append(new_node)
        return new_nodes

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run the configured node postprocessors over ``nodes`` in order."""
        for postprocessor in self._node_postprocessors:
            nodes = postprocessor.postprocess_nodes(nodes, query_bundle=query_bundle)
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query and apply the node postprocessors."""
        nodes = self._retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async retrieval; the postprocessors still run synchronously."""
        nodes = await self._retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle)

    @property
    def retriever(self) -> BaseRetriever:
        """Get the retriever object."""
        return self._retriever

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Turn ``nodes`` into citation sources and synthesize a response."""
        nodes = self._create_citation_nodes(nodes)
        return self._response_synthesizer.synthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async variant of :meth:`synthesize`."""
        nodes = self._create_citation_nodes(nodes)
        return await self._response_synthesizer.asynthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, emitting QUERY and RETRIEVE callback events."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = self.retrieve(query_bundle)
                # Citation nodes are built inside the RETRIEVE event so its
                # payload reports the numbered sources actually sent to the LLM.
                nodes = self._create_citation_nodes(nodes)

                retrieve_event.on_end(payload={EventPayload.NODES: nodes})

            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async variant of :meth:`_query`."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = await self.aretrieve(query_bundle)
                nodes = self._create_citation_nodes(nodes)

                retrieve_event.on_end(payload={EventPayload.NODES: nodes})

            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response
305

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.