llama-index

Форк
0
1
"""ReAct multimodal agent."""
2

3
import uuid
4
from typing import (
5
    Any,
6
    Dict,
7
    List,
8
    Optional,
9
    Sequence,
10
    Tuple,
11
    cast,
12
)
13

14
from llama_index.legacy.agent.react.formatter import ReActChatFormatter
15
from llama_index.legacy.agent.react.output_parser import ReActOutputParser
16
from llama_index.legacy.agent.react.types import (
17
    ActionReasoningStep,
18
    BaseReasoningStep,
19
    ObservationReasoningStep,
20
    ResponseReasoningStep,
21
)
22
from llama_index.legacy.agent.react_multimodal.prompts import (
23
    REACT_MM_CHAT_SYSTEM_HEADER,
24
)
25
from llama_index.legacy.agent.types import (
26
    BaseAgentWorker,
27
    Task,
28
    TaskStep,
29
    TaskStepOutput,
30
)
31
from llama_index.legacy.callbacks import (
32
    CallbackManager,
33
    CBEventType,
34
    EventPayload,
35
    trace_method,
36
)
37
from llama_index.legacy.chat_engine.types import (
38
    AGENT_CHAT_RESPONSE_TYPE,
39
    AgentChatResponse,
40
)
41
from llama_index.legacy.core.llms.types import MessageRole
42
from llama_index.legacy.llms.base import ChatMessage, ChatResponse
43
from llama_index.legacy.memory.chat_memory_buffer import ChatMemoryBuffer
44
from llama_index.legacy.memory.types import BaseMemory
45
from llama_index.legacy.multi_modal_llms.base import MultiModalLLM
46
from llama_index.legacy.multi_modal_llms.openai import OpenAIMultiModal
47
from llama_index.legacy.multi_modal_llms.openai_utils import (
48
    generate_openai_multi_modal_chat_message,
49
)
50
from llama_index.legacy.objects.base import ObjectRetriever
51
from llama_index.legacy.tools import BaseTool, ToolOutput, adapt_to_async_tool
52
from llama_index.legacy.tools.types import AsyncBaseTool
53
from llama_index.legacy.utils import print_text
54

55
# Default text-model name constant. It is not referenced anywhere in this file;
# ``MultimodalReActAgentWorker.from_tools`` defaults to "gpt-4-vision-preview"
# instead. NOTE(review): possibly kept only for import compatibility — confirm
# before changing or removing.
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
56

57

58
def add_user_step_to_reasoning(
    step: TaskStep,
    memory: BaseMemory,
    current_reasoning: List[BaseReasoningStep],
    verbose: bool = False,
) -> None:
    """Add a user-provided step input to the agent's reasoning state.

    On the first step of a task (``step.step_state["is_first"]`` truthy), the
    text input and the task's image documents are combined into one
    multi-modal user message, which is written to ``memory``. On any other
    step that carries user input (an intermediate step supplied in the middle
    of a task), the text is appended to ``current_reasoning`` as an
    observation. Images are not supported on that path.

    Args:
        step: The task step whose ``input`` is being added. On the first step,
            ``step.step_state`` must contain ``"image_docs"``. It may also
            contain ``"image_kwargs"``, which are forwarded as extra keyword
            arguments to the message builder.
        memory: Memory that receives the multi-modal message on the first step.
        current_reasoning: Reasoning steps list. It is extended in place on
            non-first steps.
        verbose: If True, print the intermediate user input.

    Raises:
        ValueError: If ``step.input`` is None.
        KeyError: If this is the first step and ``"image_docs"`` is missing
            from ``step.step_state``.
    """
    if step.input is None:
        raise ValueError("Step input is None.")
    # TODO: support gemini as well. Currently just supports OpenAI

    if step.step_state.get("is_first", False):
        # TODO: currently assume that you can't generate images in the loop,
        # so step_state contains the original image_docs from the task
        # (it doesn't change).
        # Image data is only read here, so intermediate user steps do not need
        # to carry "image_docs" in their step_state.
        image_docs = step.step_state["image_docs"]
        image_kwargs = step.step_state.get("image_kwargs", {})
        mm_message = generate_openai_multi_modal_chat_message(
            prompt=step.input,
            role=MessageRole.USER,
            image_documents=image_docs,
            **image_kwargs,
        )
        # add to new memory
        memory.put(mm_message)
        # Clear the flag on this step_state so it is not treated as the first
        # step again.
        step.step_state["is_first"] = False
    else:
        # NOTE: this is where the user specifies an intermediate step in the middle
        # TODO: don't support specifying image_docs here for now
        reasoning_step = ObservationReasoningStep(observation=step.input)
        current_reasoning.append(reasoning_step)
        if verbose:
            print(f"Added user message to memory: {step.input}")
97

98

99
class MultimodalReActAgentWorker(BaseAgentWorker):
    """Multimodal ReAct Agent worker.

    Each step does the following:

    1. Formats the available tools, the chat history and the reasoning
       accumulated so far into a ReAct prompt.
    2. Sends the prompt to a multi-modal LLM and parses the reply.
    3. Either calls the selected tool, recording its output as an
       observation, or finishes the task with the parsed response.

    The task's image documents are attached to the first user message.

    **NOTE**: This is a BETA feature.

    """

    def __init__(
        self,
        tools: Sequence[BaseTool],
        multi_modal_llm: MultiModalLLM,
        max_iterations: int = 10,
        react_chat_formatter: Optional[ReActChatFormatter] = None,
        output_parser: Optional[ReActOutputParser] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ) -> None:
        """Initialize the worker.

        Args:
            tools: Fixed set of tools. Must be empty if ``tool_retriever`` is given.
            multi_modal_llm: LLM used for every reasoning step.
            max_iterations: Limit on the number of accumulated reasoning steps
                (see ``_get_response``).
            react_chat_formatter: Prompt formatter. Defaults to a
                ``ReActChatFormatter`` using the multi-modal system header.
            output_parser: Parser for LLM replies. Defaults to ``ReActOutputParser``.
            callback_manager: Callback manager. Defaults to an empty one.
            verbose: If True, print reasoning steps and observations.
            tool_retriever: Retriever that picks tools per input message.

        Raises:
            ValueError: If both a non-empty ``tools`` and ``tool_retriever`` are given.
        """
        self._multi_modal_llm = multi_modal_llm
        self.callback_manager = callback_manager or CallbackManager([])
        self._max_iterations = max_iterations
        self._react_chat_formatter = react_chat_formatter or ReActChatFormatter(
            system_header=REACT_MM_CHAT_SYSTEM_HEADER
        )
        self._output_parser = output_parser or ReActOutputParser()
        self._verbose = verbose

        if len(tools) > 0 and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")
        elif len(tools) > 0:
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            self._get_tools = lambda _: []

    @classmethod
    def from_tools(
        cls,
        tools: Optional[Sequence[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        multi_modal_llm: Optional[MultiModalLLM] = None,
        max_iterations: int = 10,
        react_chat_formatter: Optional[ReActChatFormatter] = None,
        output_parser: Optional[ReActOutputParser] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "MultimodalReActAgentWorker":
        """Convenience constructor from a set of BaseTools (optional).

        If ``multi_modal_llm`` is not given, an ``OpenAIMultiModal`` LLM with
        ``model="gpt-4-vision-preview"`` and ``max_new_tokens=1000`` is created.

        NOTE: kwargs should have been exhausted by this point. In other words
        the various upstream components such as BaseSynthesizer (response synthesizer)
        or BaseRetriever should have picked up off their respective kwargs in their
        constructions. Any remaining ``kwargs`` are ignored.

        Returns:
            MultimodalReActAgentWorker
        """
        multi_modal_llm = multi_modal_llm or OpenAIMultiModal(
            model="gpt-4-vision-preview", max_new_tokens=1000
        )
        return cls(
            tools=tools or [],
            tool_retriever=tool_retriever,
            multi_modal_llm=multi_modal_llm,
            max_iterations=max_iterations,
            react_chat_formatter=react_chat_formatter,
            output_parser=output_parser,
            callback_manager=callback_manager,
            verbose=verbose,
        )

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Initialize the task state and return the first step.

        Stores empty ``sources`` and ``current_reasoning`` lists in
        ``task.extra_state``. It also stores a fresh ``new_memory`` buffer
        there, which holds the messages produced during this task until
        ``finalize_task``.

        Raises:
            ValueError: If ``task.extra_state`` has no ``"image_docs"``.
        """
        sources: List[ToolOutput] = []
        current_reasoning: List[BaseReasoningStep] = []
        # temporary memory for new messages
        new_memory = ChatMemoryBuffer.from_defaults()

        # validation
        if "image_docs" not in task.extra_state:
            raise ValueError("Image docs not found in task extra state.")

        # initialize task state
        task_state = {
            "sources": sources,
            "current_reasoning": current_reasoning,
            "new_memory": new_memory,
        }
        task.extra_state.update(task_state)

        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
            step_state={"is_first": True, "image_docs": task.extra_state["image_docs"]},
        )

    def get_tools(self, input: str) -> List[AsyncBaseTool]:
        """Get the tools for ``input``, adapted to the async tool interface."""
        return [adapt_to_async_tool(t) for t in self._get_tools(input)]

    @staticmethod
    def _build_tools_dict(tools: Sequence[AsyncBaseTool]) -> Dict[str, AsyncBaseTool]:
        """Map tool names (via ``metadata.get_name()``) to tools.

        Shared by the sync and async paths so both resolve names the same way.
        """
        return {tool.metadata.get_name(): tool for tool in tools}

    def _extract_reasoning_step(
        self, output: ChatResponse, is_streaming: bool = False
    ) -> Tuple[str, List[BaseReasoningStep], bool]:
        """
        Extract the reasoning step from the given LLM output.

        Returns:
            A tuple of the raw message content, a one-element list holding
            the parsed reasoning step, and whether that step is final.

        Raises:
            ValueError: If the message content is None, the output cannot be
                parsed, or a non-final step is not an ``ActionReasoningStep``.
        """
        if output.message.content is None:
            raise ValueError("Got empty message.")
        message_content = output.message.content
        current_reasoning = []
        try:
            reasoning_step = self._output_parser.parse(message_content, is_streaming)
        except Exception as exc:
            # Catch ``Exception`` (not ``BaseException``) so that
            # KeyboardInterrupt / SystemExit still propagate unchanged.
            raise ValueError(f"Could not parse output: {message_content}") from exc
        if self._verbose:
            print_text(f"{reasoning_step.get_content()}\n", color="pink")
        current_reasoning.append(reasoning_step)

        if reasoning_step.is_done:
            return message_content, current_reasoning, True

        if not isinstance(reasoning_step, ActionReasoningStep):
            raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")

        return message_content, current_reasoning, False

    def _record_observation(
        self,
        task: Task,
        current_reasoning: List[BaseReasoningStep],
        tool_output: ToolOutput,
    ) -> None:
        """Store a tool output as a task source and append it as an observation."""
        task.extra_state["sources"].append(tool_output)

        observation_step = ObservationReasoningStep(observation=str(tool_output))
        current_reasoning.append(observation_step)
        if self._verbose:
            print_text(f"{observation_step.get_content()}\n", color="blue")

    def _process_actions(
        self,
        task: Task,
        tools: Sequence[AsyncBaseTool],
        output: ChatResponse,
        is_streaming: bool = False,
    ) -> Tuple[List[BaseReasoningStep], bool]:
        """Parse the LLM output and, unless it is final, call the chosen tool.

        Returns:
            The reasoning steps produced by this call and whether the task is
            done. The steps are the parsed step, plus an observation step if
            a tool was called.

        Raises:
            ValueError: Propagated from ``_extract_reasoning_step``.
            KeyError: If the LLM names a tool that is not in ``tools``.
        """
        tools_dict = self._build_tools_dict(tools)
        _, current_reasoning, is_done = self._extract_reasoning_step(
            output, is_streaming
        )

        if is_done:
            return current_reasoning, True

        # call tool with input
        reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
        tool = tools_dict[reasoning_step.action]
        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={
                EventPayload.FUNCTION_CALL: reasoning_step.action_input,
                EventPayload.TOOL: tool.metadata,
            },
        ) as event:
            tool_output = tool.call(**reasoning_step.action_input)
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})

        self._record_observation(task, current_reasoning, tool_output)
        return current_reasoning, False

    async def _aprocess_actions(
        self,
        task: Task,
        tools: Sequence[AsyncBaseTool],
        output: ChatResponse,
        is_streaming: bool = False,
    ) -> Tuple[List[BaseReasoningStep], bool]:
        """Async counterpart of ``_process_actions`` (uses ``tool.acall``)."""
        tools_dict = self._build_tools_dict(tools)
        _, current_reasoning, is_done = self._extract_reasoning_step(
            output, is_streaming
        )

        if is_done:
            return current_reasoning, True

        # call tool with input
        reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
        tool = tools_dict[reasoning_step.action]
        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={
                EventPayload.FUNCTION_CALL: reasoning_step.action_input,
                EventPayload.TOOL: tool.metadata,
            },
        ) as event:
            tool_output = await tool.acall(**reasoning_step.action_input)
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})

        self._record_observation(task, current_reasoning, tool_output)
        return current_reasoning, False

    def _get_response(
        self,
        current_reasoning: List[BaseReasoningStep],
        sources: List[ToolOutput],
    ) -> AgentChatResponse:
        """Build the agent response from the accumulated reasoning steps.

        The response text comes from the last reasoning step. If that step is
        a ``ResponseReasoningStep``, its parsed answer is used. Otherwise the
        step's content is used, e.g. the latest tool observation.

        Raises:
            ValueError: If there are no reasoning steps. Also raised if the
                number of reasoning steps has reached ``max_iterations`` and
                the last step is not final. This counts reasoning steps, not
                LLM calls: each tool-using step adds two (action and
                observation).
        """
        if len(current_reasoning) == 0:
            raise ValueError("No reasoning steps were taken.")
        # ``>=`` rather than ``==``: an odd number of steps (e.g. from a
        # user-supplied intermediate observation) can jump past the limit, and
        # an equality check would then never stop the loop. A step that just
        # produced the final answer is returned rather than discarded.
        if (
            len(current_reasoning) >= self._max_iterations
            and not current_reasoning[-1].is_done
        ):
            raise ValueError("Reached max iterations.")

        if isinstance(current_reasoning[-1], ResponseReasoningStep):
            response_step = cast(ResponseReasoningStep, current_reasoning[-1])
            response_str = response_step.response
        else:
            response_str = current_reasoning[-1].get_content()

        # TODO: add sources from reasoning steps
        return AgentChatResponse(response=response_str, sources=sources)

    def _get_task_step_response(
        self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
    ) -> TaskStepOutput:
        """Wrap the response in a ``TaskStepOutput``, queuing a next step if not done."""
        if is_done:
            new_steps = []
        else:
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]

        return TaskStepOutput(
            output=agent_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    def _prepare_step(
        self, step: TaskStep, task: Task
    ) -> Tuple[List[AsyncBaseTool], List[ChatMessage]]:
        """Add any user input to the reasoning state and build the LLM prompt."""
        # This is either not None on the first step or if the user specifies
        # an intermediate step in the middle
        if step.input is not None:
            add_user_step_to_reasoning(
                step,
                task.extra_state["new_memory"],
                task.extra_state["current_reasoning"],
                verbose=self._verbose,
            )
        # TODO: see if we want to do step-based inputs
        tools = self.get_tools(task.input)

        input_chat = self._react_chat_formatter.format(
            tools,
            chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
            current_reasoning=task.extra_state["current_reasoning"],
        )
        return tools, input_chat

    def _finish_step(
        self,
        step: TaskStep,
        task: Task,
        reasoning_steps: List[BaseReasoningStep],
        is_done: bool,
    ) -> TaskStepOutput:
        """Record new reasoning steps, build the response and the step output.

        When the task is done, the final response is also stored in the task's
        ``new_memory`` as an assistant message.
        """
        task.extra_state["current_reasoning"].extend(reasoning_steps)
        agent_response = self._get_response(
            task.extra_state["current_reasoning"], task.extra_state["sources"]
        )
        if is_done:
            task.extra_state["new_memory"].put(
                ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
            )

        return self._get_task_step_response(agent_response, step, is_done)

    def _run_step(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run one reasoning step synchronously."""
        tools, input_chat = self._prepare_step(step, task)
        # send prompt
        chat_response = self._multi_modal_llm.chat(input_chat)
        # given react prompt outputs, call tools or return response
        reasoning_steps, is_done = self._process_actions(
            task, tools, output=chat_response
        )
        return self._finish_step(step, task, reasoning_steps, is_done)

    async def _arun_step(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run one reasoning step asynchronously."""
        tools, input_chat = self._prepare_step(step, task)
        # send prompt
        chat_response = await self._multi_modal_llm.achat(input_chat)
        # given react prompt outputs, call tools or return response
        reasoning_steps, is_done = await self._aprocess_actions(
            task, tools, output=chat_response
        )
        return self._finish_step(step, task, reasoning_steps, is_done)

    def _run_step_stream(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run step (stream). Not implemented for this worker."""
        raise NotImplementedError("Stream step not implemented yet.")

    async def _arun_step_stream(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run step (async stream). Not implemented for this worker."""
        raise NotImplementedError("Stream step not implemented yet.")

    @trace_method("run_step")
    def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step."""
        return self._run_step(step, task)

    @trace_method("run_step")
    async def arun_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async)."""
        return await self._arun_step(step, task)

    @trace_method("run_step")
    def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (stream)."""
        # TODO: figure out if we need a different type for TaskStepOutput
        return self._run_step_stream(step, task)

    @trace_method("run_step")
    async def astream_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        return await self._arun_step_stream(step, task)

    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed.

        Appends the messages collected in ``new_memory`` to ``task.memory``
        and then resets ``new_memory``.
        """
        # add new messages to memory
        task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
        # reset new memory
        task.extra_state["new_memory"].reset()

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        # TODO: make this abstractmethod (right now will break some agent impls)
        self.callback_manager = callback_manager
482

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.