llama-index

Форк
0
640 строк · 23.4 Кб
1
"""ReAct agent worker."""
2

3
import asyncio
4
import uuid
5
from itertools import chain
6
from threading import Thread
7
from typing import (
8
    Any,
9
    AsyncGenerator,
10
    Dict,
11
    Generator,
12
    List,
13
    Optional,
14
    Sequence,
15
    Tuple,
16
    cast,
17
)
18

19
from llama_index.legacy.agent.react.formatter import ReActChatFormatter
20
from llama_index.legacy.agent.react.output_parser import ReActOutputParser
21
from llama_index.legacy.agent.react.types import (
22
    ActionReasoningStep,
23
    BaseReasoningStep,
24
    ObservationReasoningStep,
25
    ResponseReasoningStep,
26
)
27
from llama_index.legacy.agent.types import (
28
    BaseAgentWorker,
29
    Task,
30
    TaskStep,
31
    TaskStepOutput,
32
)
33
from llama_index.legacy.callbacks import (
34
    CallbackManager,
35
    CBEventType,
36
    EventPayload,
37
    trace_method,
38
)
39
from llama_index.legacy.chat_engine.types import (
40
    AGENT_CHAT_RESPONSE_TYPE,
41
    AgentChatResponse,
42
    StreamingAgentChatResponse,
43
)
44
from llama_index.legacy.core.llms.types import MessageRole
45
from llama_index.legacy.llms.base import ChatMessage, ChatResponse
46
from llama_index.legacy.llms.llm import LLM
47
from llama_index.legacy.llms.openai import OpenAI
48
from llama_index.legacy.memory.chat_memory_buffer import ChatMemoryBuffer
49
from llama_index.legacy.memory.types import BaseMemory
50
from llama_index.legacy.objects.base import ObjectRetriever
51
from llama_index.legacy.prompts.base import PromptTemplate
52
from llama_index.legacy.prompts.mixin import PromptDictType
53
from llama_index.legacy.tools import BaseTool, ToolOutput, adapt_to_async_tool
54
from llama_index.legacy.tools.types import AsyncBaseTool
55
from llama_index.legacy.utils import print_text, unit_generator
56

57
# OpenAI model name used by ``ReActAgentWorker.from_tools`` to build a default
# LLM when the caller does not pass one.
# NOTE(review): this pins a dated OpenAI model snapshot — confirm it is still
# served by the API before relying on this default.
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
58

59

60
def add_user_step_to_reasoning(
    step: TaskStep,
    memory: BaseMemory,
    current_reasoning: List[BaseReasoningStep],
    verbose: bool = False,
) -> None:
    """Record a step's user input either in memory or in the reasoning trace.

    The first step of a task (``step.step_state["is_first"]`` truthy) carries
    the user's original message: it is stored in ``memory`` as a USER chat
    message and the flag is cleared so later calls take the other branch.
    The input of any other step is appended to ``current_reasoning`` as an
    ``ObservationReasoningStep``.

    Args:
        step: Task step whose ``input`` is recorded; its ``step_state`` may be
            mutated (``"is_first"`` is set to False).
        memory: Memory that receives the first user message.
        current_reasoning: Reasoning trace that receives follow-up inputs;
            mutated in place.
        verbose: If True, print where the input was recorded.
    """
    if step.step_state.get("is_first"):
        # add to new memory
        memory.put(ChatMessage(content=step.input, role=MessageRole.USER))
        step.step_state["is_first"] = False
        if verbose:
            print(f"Added user message to memory: {step.input}")
    else:
        # Follow-up input is shown to the LLM as an observation in the trace.
        current_reasoning.append(ObservationReasoningStep(observation=step.input))
        if verbose:
            print(f"Added user message to current reasoning: {step.input}")
76

77

78
class ReActAgentWorker(BaseAgentWorker):
    """ReAct agent worker.

    Runs the ReAct (reason + act) loop one task step at a time. Each step
    formats the chat history plus the reasoning accumulated so far into a
    prompt, asks the LLM for its next reasoning step, and then either calls
    the tool the LLM chose (recording the tool output as an observation) or
    finishes the task with the LLM's answer.

    Per-task state lives in ``task.extra_state`` (created by
    ``initialize_step``): ``sources`` (tool outputs), ``current_reasoning``
    (the reasoning trace) and ``new_memory`` (messages produced by this task,
    merged into ``task.memory`` by ``finalize_task``).
    """

    def __init__(
        self,
        tools: Sequence[BaseTool],
        llm: LLM,
        max_iterations: int = 10,
        react_chat_formatter: Optional[ReActChatFormatter] = None,
        output_parser: Optional[ReActOutputParser] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ) -> None:
        """Create a worker.

        Args:
            tools: Fixed set of tools offered on every step.
            llm: LLM used to produce reasoning steps.
            max_iterations: Limit on the length of a task's reasoning trace;
                an unfinished task whose trace reaches it raises ValueError.
            react_chat_formatter: Prompt formatter; defaults to
                ``ReActChatFormatter()``.
            output_parser: Parser for LLM replies; defaults to
                ``ReActOutputParser()``.
            callback_manager: Defaults to ``llm.callback_manager``.
            verbose: Print reasoning steps and observations as they happen.
            tool_retriever: Retrieves tools per task input instead of using a
                fixed ``tools`` list.

        Raises:
            ValueError: If both a non-empty ``tools`` and a ``tool_retriever``
                are given.
        """
        self._llm = llm
        self.callback_manager = callback_manager or llm.callback_manager
        self._max_iterations = max_iterations
        self._react_chat_formatter = react_chat_formatter or ReActChatFormatter()
        self._output_parser = output_parser or ReActOutputParser()
        self._verbose = verbose
        # Strong references to fire-and-forget asyncio tasks: the event loop
        # keeps only weak references, so an untracked task can be collected
        # before it finishes.
        self._background_tasks: "set[asyncio.Task[Any]]" = set()

        if len(tools) > 0 and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")
        elif len(tools) > 0:
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            self._get_tools = lambda _: []

    @classmethod
    def from_tools(
        cls,
        tools: Optional[Sequence[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        llm: Optional[LLM] = None,
        max_iterations: int = 10,
        react_chat_formatter: Optional[ReActChatFormatter] = None,
        output_parser: Optional[ReActOutputParser] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "ReActAgentWorker":
        """Convenience constructor from a set of BaseTools (optional).

        When ``llm`` is omitted an ``OpenAI(model=DEFAULT_MODEL_NAME)`` LLM is
        created. A given ``callback_manager`` is also assigned onto the LLM.

        NOTE: kwargs should have been exhausted by this point. In other words
        the various upstream components such as BaseSynthesizer (response
        synthesizer) or BaseRetriever should have picked up their respective
        kwargs in their constructions. Any remaining kwargs are ignored.

        Returns:
            ReActAgentWorker
        """
        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if callback_manager is not None:
            llm.callback_manager = callback_manager
        return cls(
            tools=tools or [],
            tool_retriever=tool_retriever,
            llm=llm,
            max_iterations=max_iterations,
            react_chat_formatter=react_chat_formatter,
            output_parser=output_parser,
            callback_manager=callback_manager,
            verbose=verbose,
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        # TODO: the ReAct formatter does not explicitly specify PromptTemplate
        # objects, but wrap it in this to obey the interface
        sys_header = self._react_chat_formatter.system_header
        return {"system_prompt": PromptTemplate(sys_header)}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts (only ``"system_prompt"`` is recognised)."""
        if "system_prompt" in prompts:
            sys_prompt = cast(PromptTemplate, prompts["system_prompt"])
            self._react_chat_formatter.system_header = sys_prompt.template

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Seed ``task.extra_state`` and return the task's first step."""
        sources: List[ToolOutput] = []
        current_reasoning: List[BaseReasoningStep] = []
        # temporary memory for new messages
        new_memory = ChatMemoryBuffer.from_defaults()

        # initialize task state
        task_state = {
            "sources": sources,
            "current_reasoning": current_reasoning,
            "new_memory": new_memory,
        }
        task.extra_state.update(task_state)

        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
            step_state={"is_first": True},
        )

    def get_tools(self, input: str) -> List[AsyncBaseTool]:
        """Get the tools for ``input``, adapted to the async tool interface."""
        return [adapt_to_async_tool(t) for t in self._get_tools(input)]

    def _prepare_step(
        self, step: TaskStep, task: Task
    ) -> Tuple[List[AsyncBaseTool], List[ChatMessage]]:
        """Record the step's input and build the tools and prompt for the LLM.

        Shared preamble of all four run methods.
        """
        if step.input is not None:
            add_user_step_to_reasoning(
                step,
                task.extra_state["new_memory"],
                task.extra_state["current_reasoning"],
                verbose=self._verbose,
            )
        # TODO: see if we want to do step-based inputs
        tools = self.get_tools(task.input)

        input_chat = self._react_chat_formatter.format(
            tools,
            chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
            current_reasoning=task.extra_state["current_reasoning"],
        )
        return tools, input_chat

    def _extract_reasoning_step(
        self, output: ChatResponse, is_streaming: bool = False
    ) -> Tuple[str, List[BaseReasoningStep], bool]:
        """Parse one LLM reply into a reasoning step.

        Args:
            output: The LLM reply.
            is_streaming: Forwarded to the output parser.

        Returns:
            The raw message content, a one-element list holding the parsed
            step, and whether that step is a final answer.

        Raises:
            ValueError: If the reply is empty, cannot be parsed, or is neither
                a final answer nor an ``ActionReasoningStep``.
        """
        if output.message.content is None:
            raise ValueError("Got empty message.")
        message_content = output.message.content
        try:
            reasoning_step = self._output_parser.parse(message_content, is_streaming)
        except Exception as exc:
            # Only ordinary errors are rewrapped; KeyboardInterrupt,
            # SystemExit etc. must propagate untouched.
            raise ValueError(f"Could not parse output: {message_content}") from exc
        if self._verbose:
            print_text(f"{reasoning_step.get_content()}\n", color="pink")
        current_reasoning: List[BaseReasoningStep] = [reasoning_step]

        if reasoning_step.is_done:
            return message_content, current_reasoning, True

        if not isinstance(reasoning_step, ActionReasoningStep):
            raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")

        return message_content, current_reasoning, False

    @staticmethod
    def _build_tools_dict(
        tools: Sequence[AsyncBaseTool],
    ) -> Dict[str, AsyncBaseTool]:
        """Index tools by ``metadata.get_name()`` (same key for sync and async)."""
        return {tool.metadata.get_name(): tool for tool in tools}

    def _handle_nonexistent_tool_name(
        self, reasoning_step: ActionReasoningStep
    ) -> ToolOutput:
        """Build an error tool output for an action naming an unknown tool.

        The error is fed back to the LLM as an observation so it can correct
        itself on the next step, instead of the task dying with a KeyError.
        """
        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={EventPayload.FUNCTION_CALL: reasoning_step.action_input},
        ) as event:
            content = f"Error: No such tool named `{reasoning_step.action}`."
            tool_output = ToolOutput(
                content=content,
                tool_name=reasoning_step.action,
                raw_input={"kwargs": reasoning_step.action_input},
                raw_output=content,
            )
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
        return tool_output

    def _record_observation(
        self,
        task: Task,
        current_reasoning: List[BaseReasoningStep],
        tool_output: ToolOutput,
    ) -> Tuple[List[BaseReasoningStep], bool]:
        """Store ``tool_output`` as a source and append it as an observation."""
        task.extra_state["sources"].append(tool_output)

        observation_step = ObservationReasoningStep(observation=str(tool_output))
        current_reasoning.append(observation_step)
        if self._verbose:
            print_text(f"{observation_step.get_content()}\n", color="blue")
        return current_reasoning, False

    def _process_actions(
        self,
        task: Task,
        tools: Sequence[AsyncBaseTool],
        output: ChatResponse,
        is_streaming: bool = False,
    ) -> Tuple[List[BaseReasoningStep], bool]:
        """Parse an LLM reply and, if it is an action, call the tool (sync).

        Returns:
            The new reasoning steps (the parsed step, plus an observation when
            an action was taken) and whether the reply was a final answer.
        """
        tools_dict = self._build_tools_dict(tools)
        _, current_reasoning, is_done = self._extract_reasoning_step(
            output, is_streaming
        )

        if is_done:
            return current_reasoning, True

        # call tool with input
        reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
        tool = tools_dict.get(reasoning_step.action)
        if tool is None:
            tool_output = self._handle_nonexistent_tool_name(reasoning_step)
        else:
            with self.callback_manager.event(
                CBEventType.FUNCTION_CALL,
                payload={
                    EventPayload.FUNCTION_CALL: reasoning_step.action_input,
                    EventPayload.TOOL: tool.metadata,
                },
            ) as event:
                tool_output = tool.call(**reasoning_step.action_input)
                event.on_end(
                    payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}
                )

        return self._record_observation(task, current_reasoning, tool_output)

    async def _aprocess_actions(
        self,
        task: Task,
        tools: Sequence[AsyncBaseTool],
        output: ChatResponse,
        is_streaming: bool = False,
    ) -> Tuple[List[BaseReasoningStep], bool]:
        """Async counterpart of ``_process_actions`` (awaits ``tool.acall``)."""
        tools_dict = self._build_tools_dict(tools)
        _, current_reasoning, is_done = self._extract_reasoning_step(
            output, is_streaming
        )

        if is_done:
            return current_reasoning, True

        # call tool with input
        reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
        tool = tools_dict.get(reasoning_step.action)
        if tool is None:
            tool_output = self._handle_nonexistent_tool_name(reasoning_step)
        else:
            with self.callback_manager.event(
                CBEventType.FUNCTION_CALL,
                payload={
                    EventPayload.FUNCTION_CALL: reasoning_step.action_input,
                    EventPayload.TOOL: tool.metadata,
                },
            ) as event:
                tool_output = await tool.acall(**reasoning_step.action_input)
                event.on_end(
                    payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}
                )

        return self._record_observation(task, current_reasoning, tool_output)

    def _get_response(
        self,
        current_reasoning: List[BaseReasoningStep],
        sources: List[ToolOutput],
    ) -> AgentChatResponse:
        """Build the agent response from the latest reasoning step.

        Raises:
            ValueError: If there are no reasoning steps, or the task is not
                finished and its trace has reached ``max_iterations`` steps.
        """
        if len(current_reasoning) == 0:
            raise ValueError("No reasoning steps were taken.")
        last_step = current_reasoning[-1]
        # ``>=`` rather than ``==``: the trace grows by two steps per tool call
        # (action + observation) plus one per follow-up input, so it can jump
        # past the limit. A finished answer is never rejected.
        if not last_step.is_done and len(current_reasoning) >= self._max_iterations:
            raise ValueError("Reached max iterations.")

        if isinstance(last_step, ResponseReasoningStep):
            response_str = last_step.response
        else:
            response_str = last_step.get_content()

        # TODO: add sources from reasoning steps
        return AgentChatResponse(response=response_str, sources=sources)

    def _get_task_step_response(
        self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
    ) -> TaskStepOutput:
        """Wrap a response as step output, scheduling a next step unless done."""
        if is_done:
            new_steps = []
        else:
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]

        return TaskStepOutput(
            output=agent_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    def _finish_step(
        self,
        step: TaskStep,
        task: Task,
        reasoning_steps: List[BaseReasoningStep],
        is_done: bool,
    ) -> TaskStepOutput:
        """Shared tail of the non-streaming run methods.

        Extends the trace, builds the response, and stores the final answer
        in ``new_memory`` when the task is done.
        """
        task.extra_state["current_reasoning"].extend(reasoning_steps)
        agent_response = self._get_response(
            task.extra_state["current_reasoning"], task.extra_state["sources"]
        )
        if is_done:
            task.extra_state["new_memory"].put(
                ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
            )

        return self._get_task_step_response(agent_response, step, is_done)

    def _infer_stream_chunk_is_final(self, chunk: ChatResponse) -> bool:
        """Infer whether a live stream chunk starts the final answer.

        The chunk's ``message.content`` is inspected, ignoring leading
        whitespace. Content that is still a prefix of ``"Thought"`` (e.g.
        ``"Th"``) is undecided and returns False. Content that does not start
        with ``"Thought"`` is treated as a direct answer. Otherwise the answer
        starts once ``"Answer: "`` appears. Turning the answer into a
        ResponseReasoningStep is not part of this function's logic.

        Args:
            chunk (ChatResponse): the current chunk stream to check

        Returns:
            bool: Boolean on whether the chunk is the start of the final response
        """
        latest_content = (chunk.message.content or "").lstrip()
        if not latest_content or "Thought".startswith(latest_content):
            # Nothing decisive yet (empty, whitespace, or a partial "Thought").
            return False
        if not latest_content.startswith("Thought"):
            # doesn't follow thought-action format
            return True
        return "Answer: " in latest_content

    def _add_back_chunk_to_stream(
        self, chunk: ChatResponse, chat_stream: Generator[ChatResponse, None, None]
    ) -> Generator[ChatResponse, None, None]:
        """Put ``chunk`` back at the front of the rest of ``chat_stream``.

        Args:
            chunk (ChatResponse): the chunk to add back to the beginning of the
                                    chat_stream.

        Return:
            Generator[ChatResponse, None, None]: the updated chat_stream
        """
        updated_stream = chain.from_iterable(  # need to add back partial response chunk
            [
                unit_generator(chunk),
                chat_stream,
            ]
        )
        # use cast to avoid mypy issue with chain and Generator
        updated_stream_c: Generator[ChatResponse, None, None] = cast(
            Generator[ChatResponse, None, None], updated_stream
        )
        return updated_stream_c

    async def _async_add_back_chunk_to_stream(
        self, chunk: ChatResponse, chat_stream: AsyncGenerator[ChatResponse, None]
    ) -> AsyncGenerator[ChatResponse, None]:
        """Yield ``chunk`` and then every item of ``chat_stream``.

        NOTE: this is an async *generator*: calling it returns the generator
        immediately (do not await the call); iterate it with ``async for``.

        Args:
            chunk (ChatResponse): the chunk to add back to the beginning of the
                                    chat_stream.

        Return:
            AsyncGenerator[ChatResponse, None]: the updated async chat_stream
        """
        yield chunk
        async for item in chat_stream:
            yield item

    @staticmethod
    async def _async_unit_generator(
        chunk: ChatResponse,
    ) -> AsyncGenerator[ChatResponse, None]:
        """Async counterpart of ``unit_generator``: yield ``chunk`` once."""
        yield chunk

    @staticmethod
    def _answer_chunk(text: str) -> ChatResponse:
        """Wrap an already-complete answer as a single assistant stream chunk."""
        return ChatResponse(
            message=ChatMessage(content=text, role=MessageRole.ASSISTANT),
            delta=text,
        )

    def _start_streaming_response(
        self, response_stream: Generator[ChatResponse, None, None], task: Task
    ) -> StreamingAgentChatResponse:
        """Wrap a sync stream and write it to ``new_memory`` on a thread."""
        agent_response = StreamingAgentChatResponse(
            chat_stream=response_stream,
            sources=task.extra_state["sources"],
        )
        # Write the response in a separate thread so the caller can consume
        # the stream concurrently.
        thread = Thread(
            target=agent_response.write_response_to_history,
            args=(task.extra_state["new_memory"],),
        )
        thread.start()
        return agent_response

    async def _astart_streaming_response(
        self, response_stream: AsyncGenerator[ChatResponse, None], task: Task
    ) -> StreamingAgentChatResponse:
        """Wrap an async stream and write it to ``new_memory`` in a task."""
        agent_response = StreamingAgentChatResponse(
            achat_stream=response_stream,
            sources=task.extra_state["sources"],
        )
        # create task to write chat response to history; keep a reference so
        # it is not garbage-collected before it finishes
        write_task = asyncio.create_task(
            agent_response.awrite_response_to_history(task.extra_state["new_memory"])
        )
        self._background_tasks.add(write_task)
        write_task.add_done_callback(self._background_tasks.discard)
        # NOTE(review): presumably set by the writer once it knows the reply
        # is not a function call — confirm in StreamingAgentChatResponse.
        await agent_response._is_function_false_event.wait()
        return agent_response

    def _run_step(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run step."""
        tools, input_chat = self._prepare_step(step, task)
        # send prompt
        chat_response = self._llm.chat(input_chat)
        # given react prompt outputs, call tools or return response
        reasoning_steps, is_done = self._process_actions(
            task, tools, output=chat_response
        )
        return self._finish_step(step, task, reasoning_steps, is_done)

    async def _arun_step(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run step (async)."""
        tools, input_chat = self._prepare_step(step, task)
        # send prompt
        chat_response = await self._llm.achat(input_chat)
        # given react prompt outputs, call tools or return response
        reasoning_steps, is_done = await self._aprocess_actions(
            task, tools, output=chat_response
        )
        return self._finish_step(step, task, reasoning_steps, is_done)

    def _run_step_stream(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run step, streaming the final answer.

        The LLM stream is consumed until ``_infer_stream_chunk_is_final``
        flags the start of a final answer. The rest of the stream, prefixed
        with the flagging chunk, is returned as a StreamingAgentChatResponse.
        If the stream ends unflagged, the whole reply is parsed like a
        non-streaming step; a final answer found that way is replayed as a
        one-chunk stream.
        """
        tools, input_chat = self._prepare_step(step, task)
        chat_stream = self._llm.stream_chat(input_chat)

        # iterate over stream, break out if is final answer after the "Answer: "
        full_response = ChatResponse(
            message=ChatMessage(content=None, role="assistant")
        )
        is_done = False
        for latest_chunk in chat_stream:
            full_response = latest_chunk
            is_done = self._infer_stream_chunk_is_final(latest_chunk)
            if is_done:
                break

        agent_response: AGENT_CHAT_RESPONSE_TYPE
        if is_done:
            # full_response is the chunk that revealed the answer; put it back
            # in front of the remaining stream so it is not lost.
            response_stream = self._add_back_chunk_to_stream(
                chunk=full_response, chat_stream=chat_stream
            )
            agent_response = self._start_streaming_response(response_stream, task)
        else:
            # given react prompt outputs, call tools or return response
            reasoning_steps, is_done = self._process_actions(
                task, tools=tools, output=full_response, is_streaming=True
            )
            task.extra_state["current_reasoning"].extend(reasoning_steps)
            # use _get_response to return intermediate response
            final_response = self._get_response(
                task.extra_state["current_reasoning"], task.extra_state["sources"]
            )
            agent_response = final_response
            if is_done:
                # A final answer the live check missed: replay it as a stream
                # so the task ends and the answer reaches memory.
                agent_response = self._start_streaming_response(
                    unit_generator(self._answer_chunk(final_response.response)),
                    task,
                )

        return self._get_task_step_response(agent_response, step, is_done)

    async def _arun_step_stream(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Async counterpart of ``_run_step_stream`` (tools via ``acall``)."""
        tools, input_chat = self._prepare_step(step, task)
        chat_stream = await self._llm.astream_chat(input_chat)

        # iterate over stream, break out if is final answer after the "Answer: "
        full_response = ChatResponse(
            message=ChatMessage(content=None, role="assistant")
        )
        is_done = False
        async for latest_chunk in chat_stream:
            full_response = latest_chunk
            is_done = self._infer_stream_chunk_is_final(latest_chunk)
            if is_done:
                break

        agent_response: AGENT_CHAT_RESPONSE_TYPE
        if is_done:
            response_stream = self._async_add_back_chunk_to_stream(
                chunk=full_response, chat_stream=chat_stream
            )
            agent_response = await self._astart_streaming_response(
                response_stream, task
            )
        else:
            # given react prompt outputs, call tools or return response;
            # awaited so tool calls do not block the event loop
            reasoning_steps, is_done = await self._aprocess_actions(
                task, tools=tools, output=full_response, is_streaming=True
            )
            task.extra_state["current_reasoning"].extend(reasoning_steps)
            # use _get_response to return intermediate response
            final_response = self._get_response(
                task.extra_state["current_reasoning"], task.extra_state["sources"]
            )
            agent_response = final_response
            if is_done:
                agent_response = await self._astart_streaming_response(
                    self._async_unit_generator(
                        self._answer_chunk(final_response.response)
                    ),
                    task,
                )

        return self._get_task_step_response(agent_response, step, is_done)

    @trace_method("run_step")
    def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step."""
        return self._run_step(step, task)

    @trace_method("run_step")
    async def arun_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async)."""
        return await self._arun_step(step, task)

    @trace_method("run_step")
    def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (stream)."""
        # TODO: figure out if we need a different type for TaskStepOutput
        return self._run_step_stream(step, task)

    @trace_method("run_step")
    async def astream_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        return await self._arun_step_stream(step, task)

    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed."""
        # add new messages to memory
        task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
        # reset new memory
        task.extra_state["new_memory"].reset()

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        # TODO: make this abstractmethod (right now will break some agent impls)
        self.callback_manager = callback_manager
641

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.