llama-index

Форк
0
232 строки · 8.9 Кб
1
from typing import Any, Dict, List, Optional, Sequence, Tuple
2

3
from llama_index.legacy.callbacks.base import CallbackManager
4
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
5
from llama_index.legacy.core.response.schema import RESPONSE_TYPE, Response
6
from llama_index.legacy.indices.multi_modal import MultiModalVectorIndexRetriever
7
from llama_index.legacy.indices.query.base import BaseQueryEngine
8
from llama_index.legacy.indices.query.schema import QueryBundle, QueryType
9
from llama_index.legacy.multi_modal_llms.base import MultiModalLLM
10
from llama_index.legacy.multi_modal_llms.openai import OpenAIMultiModal
11
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
12
from llama_index.legacy.prompts import BasePromptTemplate
13
from llama_index.legacy.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT
14
from llama_index.legacy.prompts.mixin import PromptMixinType
15
from llama_index.legacy.schema import ImageNode, NodeWithScore
16

17

18
def _get_image_and_text_nodes(
    nodes: List[NodeWithScore],
) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
    """Partition scored nodes into image-backed and non-image nodes.

    A scored node counts as an image node when its wrapped ``node`` is an
    ``ImageNode`` instance; everything else is treated as text. Relative
    order within each group follows the order of ``nodes``.

    Args:
        nodes: Scored nodes to partition.

    Returns:
        A ``(image_nodes, text_nodes)`` tuple of two new lists.
    """
    image_nodes: List[NodeWithScore] = []
    text_nodes: List[NodeWithScore] = []
    for scored in nodes:
        # Pick the destination list first, then append once.
        bucket = image_nodes if isinstance(scored.node, ImageNode) else text_nodes
        bucket.append(scored)
    return image_nodes, text_nodes
29

30

31
class SimpleMultiModalQueryEngine(BaseQueryEngine):
    """Simple Multi Modal Retriever query engine.

    Assumes that retrieved text context fits within context window of LLM, along with images.

    Args:
        retriever (MultiModalVectorIndexRetriever): A retriever object.
        multi_modal_llm (Optional[MultiModalLLM]): MultiModalLLM Models.
            When omitted, an ``OpenAIMultiModal`` instance is created with
            ``model="gpt-4-vision-preview"`` and ``max_new_tokens=1000``.
        text_qa_template (Optional[BasePromptTemplate]): Text QA Prompt Template.
            Falls back to ``DEFAULT_TEXT_QA_PROMPT``.
        image_qa_template (Optional[BasePromptTemplate]): Image QA Prompt Template.
            Falls back to ``DEFAULT_TEXT_QA_PROMPT``.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Node Postprocessors,
            applied in list order to the retrieved nodes.
        callback_manager (Optional[CallbackManager]): A callback manager. The same
            manager is assigned to every node postprocessor.
        **kwargs: Accepted but not used by this class.
    """

    def __init__(
        self,
        retriever: MultiModalVectorIndexRetriever,
        multi_modal_llm: Optional[MultiModalLLM] = None,
        text_qa_template: Optional[BasePromptTemplate] = None,
        image_qa_template: Optional[BasePromptTemplate] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        self._retriever = retriever
        self._multi_modal_llm = multi_modal_llm or OpenAIMultiModal(
            model="gpt-4-vision-preview", max_new_tokens=1000
        )
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT
        self._image_qa_template = image_qa_template or DEFAULT_TEXT_QA_PROMPT

        self._node_postprocessors = node_postprocessors or []
        callback_manager = callback_manager or CallbackManager([])
        # Postprocessors report through the same callback manager as the engine.
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = callback_manager

        super().__init__(callback_manager)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {"text_qa_template": self._text_qa_template}

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {}

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run every configured postprocessor over ``nodes``, in order.

        Each postprocessor receives the output of the previous one.
        """
        for node_postprocessor in self._node_postprocessors:
            nodes = node_postprocessor.postprocess_nodes(
                nodes, query_bundle=query_bundle
            )
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query and apply the node postprocessors."""
        nodes = self._retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of :meth:`retrieve`.

        Only the retriever call is awaited; postprocessing runs synchronously.
        """
        nodes = await self._retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    def _build_text_qa_inputs(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
    ) -> Tuple[str, List[NodeWithScore], List[NodeWithScore]]:
        """Split ``nodes`` by modality and format the text QA prompt.

        Shared by :meth:`synthesize` and :meth:`asynthesize` so that the sync
        and async paths build identical LLM inputs.

        Returns:
            ``(fmt_prompt, image_nodes, text_nodes)``. The prompt is the text
            QA template filled with the text nodes' content joined by blank
            lines as ``context_str`` and the query string as ``query_str``.
        """
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
        context_str = "\n\n".join([r.get_content() for r in text_nodes])
        fmt_prompt = self._text_qa_template.format(
            context_str=context_str, query_str=query_bundle.query_str
        )
        return fmt_prompt, image_nodes, text_nodes

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Generate a response from already-retrieved nodes.

        Args:
            query_bundle: The query being answered.
            nodes: Retrieved nodes; text nodes feed the prompt context and
                image nodes are passed to the LLM as image documents.
            additional_source_nodes: Currently unused by this implementation.
                NOTE(review): confirm whether callers expect these to be
                included in ``source_nodes``.

        Returns:
            A ``Response`` whose ``source_nodes`` is ``nodes`` and whose
            metadata holds the ``text_nodes`` / ``image_nodes`` split.
        """
        fmt_prompt, image_nodes, text_nodes = self._build_text_qa_inputs(
            query_bundle, nodes
        )

        llm_response = self._multi_modal_llm.complete(
            prompt=fmt_prompt,
            # The LLM receives the underlying nodes, not the scored wrappers.
            image_documents=[image_node.node for image_node in image_nodes],
        )
        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )

    def _get_response_with_images(
        self,
        prompt_str: str,
        image_nodes: List[NodeWithScore],
    ) -> RESPONSE_TYPE:
        """Ask the LLM ``prompt_str`` about the given image nodes.

        Args:
            prompt_str: Text substituted into the image QA template as
                ``query_str``.
            image_nodes: Scored nodes wrapping the images; each element's
                ``.node`` is forwarded to the LLM.

        Returns:
            A ``Response`` whose ``source_nodes`` and ``image_nodes`` metadata
            are both ``image_nodes``.
        """
        # NOTE(review): only ``query_str`` is supplied here. If the configured
        # template declares further variables (the fallback is
        # DEFAULT_TEXT_QA_PROMPT), confirm how ``format`` treats them.
        fmt_prompt = self._image_qa_template.format(
            query_str=prompt_str,
        )

        llm_response = self._multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[image_node.node for image_node in image_nodes],
        )
        return Response(
            response=str(llm_response),
            source_nodes=image_nodes,
            metadata={"image_nodes": image_nodes},
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async variant of :meth:`synthesize`.

        Args:
            query_bundle: The query being answered.
            nodes: Retrieved nodes; text nodes feed the prompt context and
                image nodes are passed to the LLM as image documents.
            additional_source_nodes: Currently unused by this implementation.
                NOTE(review): confirm whether callers expect these to be
                included in ``source_nodes``.

        Returns:
            A ``Response`` whose ``source_nodes`` is ``nodes`` and whose
            metadata holds the ``text_nodes`` / ``image_nodes`` split.
        """
        fmt_prompt, image_nodes, text_nodes = self._build_text_qa_inputs(
            query_bundle, nodes
        )
        llm_response = await self._multi_modal_llm.acomplete(
            prompt=fmt_prompt,
            # Unwrap the scored nodes exactly as the sync path does; passing
            # the NodeWithScore wrappers themselves was inconsistent with
            # ``synthesize`` and ``_get_response_with_images``.
            image_documents=[image_node.node for image_node in image_nodes],
        )
        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            # The RETRIEVE event is nested inside the QUERY event.
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = self.retrieve(query_bundle)

                retrieve_event.on_end(
                    payload={EventPayload.NODES: nodes},
                )

            response = self.synthesize(
                query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    def image_query(self, image_path: QueryType, prompt_str: str) -> RESPONSE_TYPE:
        """Answer a image query.

        Retrieves nodes with the retriever's ``image_to_image_retrieve`` for
        ``image_path``, keeps only the image nodes, and asks the LLM
        ``prompt_str`` about them. Node postprocessors are not applied on
        this path.

        Args:
            image_path: Image query handed to the retriever; its ``str()``
                form is reported as the query string in callback events.
            prompt_str: Question to ask about the retrieved images.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: str(image_path)}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: str(image_path)},
            ) as retrieve_event:
                nodes = self._retriever.image_to_image_retrieve(image_path)

                retrieve_event.on_end(
                    payload={EventPayload.NODES: nodes},
                )

            # Any non-image nodes returned by the retriever are discarded.
            image_nodes, _ = _get_image_and_text_nodes(nodes)
            response = self._get_response_with_images(
                prompt_str=prompt_str,
                image_nodes=image_nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            # The RETRIEVE event is nested inside the QUERY event.
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = await self.aretrieve(query_bundle)

                retrieve_event.on_end(
                    payload={EventPayload.NODES: nodes},
                )

            response = await self.asynthesize(
                query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    @property
    def retriever(self) -> MultiModalVectorIndexRetriever:
        """Get the retriever object."""
        return self._retriever
233

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.