llama-index

Форк
0
200 строк · 7.7 Кб
1
from typing import Any, List, Optional, Sequence, Type

from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
from llama_index.legacy.core.base_retriever import BaseRetriever
from llama_index.legacy.core.response.schema import RESPONSE_TYPE
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
from llama_index.legacy.prompts import BasePromptTemplate
from llama_index.legacy.prompts.mixin import PromptMixinType
from llama_index.legacy.response_synthesizers import (
    BaseSynthesizer,
    ResponseMode,
    get_response_synthesizer,
)
from llama_index.legacy.schema import NodeWithScore, QueryBundle
from llama_index.legacy.service_context import ServiceContext
19

20

21
class RetrieverQueryEngine(BaseQueryEngine):
    """Retriever query engine.

    Retrieves nodes with a retriever, runs them through the configured node
    postprocessors in order, and passes the result to a response synthesizer.

    Args:
        retriever (BaseRetriever): A retriever object.
        response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
            object. When omitted, one is built with
            ``get_response_synthesizer`` from the retriever's service context.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Node
            postprocessors applied, in list order, to the retrieved nodes.
        callback_manager (Optional[CallbackManager]): A callback manager.
            Defaults to an empty ``CallbackManager``.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._retriever = retriever
        # The fallback synthesizer receives the callback manager exactly as
        # passed in (possibly None); the empty-manager default below is only
        # applied afterwards.
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=retriever.get_service_context(),
            callback_manager=callback_manager,
        )

        self._node_postprocessors = node_postprocessors or []
        callback_manager = callback_manager or CallbackManager([])
        # Every postprocessor reports to the same manager as the engine.
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = callback_manager

        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"response_synthesizer": self._response_synthesizer}

    @classmethod
    def from_args(
        cls,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        service_context: Optional[ServiceContext] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.COMPACT,
        text_qa_template: Optional[BasePromptTemplate] = None,
        refine_template: Optional[BasePromptTemplate] = None,
        summary_template: Optional[BasePromptTemplate] = None,
        simple_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[Type[BaseModel]] = None,
        use_async: bool = False,
        streaming: bool = False,
        # class-specific args
        **kwargs: Any,
    ) -> "RetrieverQueryEngine":
        """Initialize a RetrieverQueryEngine object.

        Args:
            retriever (BaseRetriever): A retriever object.
            response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
                object. When given it is used as-is and the response
                synthesizer args below are not used to build one.
            service_context (Optional[ServiceContext]): A ServiceContext object.
                Its callback manager is used for the engine; without a service
                context an empty ``CallbackManager`` is used.
            node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list of
                node postprocessors.
            response_mode (ResponseMode): A ResponseMode object.
            text_qa_template (Optional[BasePromptTemplate]): A BasePromptTemplate
                object.
            refine_template (Optional[BasePromptTemplate]): A BasePromptTemplate object.
            summary_template (Optional[BasePromptTemplate]): A BasePromptTemplate
                object.
            simple_template (Optional[BasePromptTemplate]): A BasePromptTemplate object.
            output_cls (Optional[Type[BaseModel]]): A pydantic model class,
                forwarded unchanged to ``get_response_synthesizer``.
            use_async (bool): Whether to use async.
            streaming (bool): Whether to use streaming.
            **kwargs (Any): Extra keyword arguments are accepted but not used
                by this method.

        Returns:
            RetrieverQueryEngine: An instance of ``cls``.

        """
        response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            summary_template=summary_template,
            simple_template=simple_template,
            response_mode=response_mode,
            output_cls=output_cls,
            use_async=use_async,
            streaming=streaming,
        )

        callback_manager = (
            service_context.callback_manager if service_context else CallbackManager([])
        )

        return cls(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            callback_manager=callback_manager,
            node_postprocessors=node_postprocessors,
        )

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run ``nodes`` through each node postprocessor in order."""
        for node_postprocessor in self._node_postprocessors:
            nodes = node_postprocessor.postprocess_nodes(
                nodes, query_bundle=query_bundle
            )
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query and apply the node postprocessors."""
        nodes = self._retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of ``retrieve``.

        Only the retriever call is awaited; the node postprocessors are
        applied through their synchronous ``postprocess_nodes`` method.
        """
        nodes = await self._retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    def with_retriever(self, retriever: BaseRetriever) -> "RetrieverQueryEngine":
        """Return a new engine using ``retriever``.

        The new engine reuses this engine's response synthesizer, callback
        manager and node postprocessor list.
        """
        return RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=self._response_synthesizer,
            callback_manager=self.callback_manager,
            node_postprocessors=self._node_postprocessors,
        )

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Synthesize a response for the query from already-retrieved nodes."""
        return self._response_synthesizer.synthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async variant of ``synthesize``."""
        return await self._response_synthesizer.asynthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query.

        Retrieval and synthesis both run inside a single ``QUERY`` callback
        event, which is ended with the response as its payload.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes = self.retrieve(query_bundle)
            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously.

        Retrieval and synthesis are awaited inside a single ``QUERY`` callback
        event, which is ended with the response as its payload.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes = await self.aretrieve(query_bundle)

            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    @property
    def retriever(self) -> BaseRetriever:
        """Get the retriever object."""
        return self._retriever
201

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.