llama-index
640 lines · 23.4 KB
1"""ReAct agent worker."""
2
3import asyncio4import uuid5from itertools import chain6from threading import Thread7from typing import (8Any,9AsyncGenerator,10Dict,11Generator,12List,13Optional,14Sequence,15Tuple,16cast,17)
18
19from llama_index.legacy.agent.react.formatter import ReActChatFormatter20from llama_index.legacy.agent.react.output_parser import ReActOutputParser21from llama_index.legacy.agent.react.types import (22ActionReasoningStep,23BaseReasoningStep,24ObservationReasoningStep,25ResponseReasoningStep,26)
27from llama_index.legacy.agent.types import (28BaseAgentWorker,29Task,30TaskStep,31TaskStepOutput,32)
33from llama_index.legacy.callbacks import (34CallbackManager,35CBEventType,36EventPayload,37trace_method,38)
39from llama_index.legacy.chat_engine.types import (40AGENT_CHAT_RESPONSE_TYPE,41AgentChatResponse,42StreamingAgentChatResponse,43)
44from llama_index.legacy.core.llms.types import MessageRole45from llama_index.legacy.llms.base import ChatMessage, ChatResponse46from llama_index.legacy.llms.llm import LLM47from llama_index.legacy.llms.openai import OpenAI48from llama_index.legacy.memory.chat_memory_buffer import ChatMemoryBuffer49from llama_index.legacy.memory.types import BaseMemory50from llama_index.legacy.objects.base import ObjectRetriever51from llama_index.legacy.prompts.base import PromptTemplate52from llama_index.legacy.prompts.mixin import PromptDictType53from llama_index.legacy.tools import BaseTool, ToolOutput, adapt_to_async_tool54from llama_index.legacy.tools.types import AsyncBaseTool55from llama_index.legacy.utils import print_text, unit_generator56
# Default OpenAI model used when no LLM is supplied to ``from_tools``.
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
59
def add_user_step_to_reasoning(
    step: TaskStep,
    memory: BaseMemory,
    current_reasoning: List[BaseReasoningStep],
    verbose: bool = False,
) -> None:
    """Record the user's input for this step.

    On the very first step the input becomes a USER chat message in
    ``memory``; on subsequent steps it is appended to ``current_reasoning``
    as an observation instead.
    """
    if step.step_state.get("is_first"):
        # First step: store the raw user input as a chat message.
        memory.put(ChatMessage(content=step.input, role=MessageRole.USER))
        step.step_state["is_first"] = False
    else:
        # Follow-up step: fold the input into the ReAct trace as an observation.
        current_reasoning.append(ObservationReasoningStep(observation=step.input))
        if verbose:
            print(f"Added user message to memory: {step.input}")
77
class ReActAgentWorker(BaseAgentWorker):
    """ReAct Agent worker.

    Drives a ReAct (reason + act) loop: formats a prompt from the chat
    history and reasoning trace, asks the LLM for a thought/action, runs
    the selected tool, and repeats until a final answer is produced.
    """

    def __init__(
        self,
        tools: Sequence[BaseTool],
        llm: LLM,
        max_iterations: int = 10,
        react_chat_formatter: Optional[ReActChatFormatter] = None,
        output_parser: Optional[ReActOutputParser] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ) -> None:
        """Init params.

        Args:
            tools: explicit tool list (mutually exclusive with ``tool_retriever``).
            llm: LLM used for each reasoning step.
            max_iterations: maximum number of reasoning iterations.
            react_chat_formatter: prompt formatter (defaults to ReActChatFormatter).
            output_parser: parser for LLM output (defaults to ReActOutputParser).
            callback_manager: callback manager; falls back to the LLM's.
            verbose: whether to print reasoning steps as they happen.
            tool_retriever: retriever that selects tools based on the message.

        Raises:
            ValueError: if both ``tools`` and ``tool_retriever`` are given.
        """
        self._llm = llm
        # Fall back to the LLM's callback manager when none is provided.
        self.callback_manager = callback_manager or llm.callback_manager
        self._max_iterations = max_iterations
        self._react_chat_formatter = react_chat_formatter or ReActChatFormatter()
        self._output_parser = output_parser or ReActOutputParser()
        self._verbose = verbose

        if len(tools) > 0 and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")
        elif len(tools) > 0:
            # Fixed tool list: the message argument is ignored.
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            # Dynamic tools: retrieve per-message.
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            # No tools configured at all.
            self._get_tools = lambda _: []
    @classmethod
    def from_tools(
        cls,
        tools: Optional[Sequence[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        llm: Optional[LLM] = None,
        max_iterations: int = 10,
        react_chat_formatter: Optional[ReActChatFormatter] = None,
        output_parser: Optional[ReActOutputParser] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "ReActAgentWorker":
        """Convenience constructor method from a set of BaseTools (Optional).

        NOTE: kwargs should have been exhausted by this point. In other words
        the various upstream components such as BaseSynthesizer (response synthesizer)
        or BaseRetriever should have picked up off their respective kwargs in their
        constructions.

        Returns:
            ReActAgentWorker
        """
        # Default to an OpenAI LLM when none is supplied.
        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if callback_manager is not None:
            llm.callback_manager = callback_manager
        return cls(
            tools=tools or [],
            tool_retriever=tool_retriever,
            llm=llm,
            max_iterations=max_iterations,
            react_chat_formatter=react_chat_formatter,
            output_parser=output_parser,
            callback_manager=callback_manager,
            verbose=verbose,
        )
146def _get_prompts(self) -> PromptDictType:147"""Get prompts."""148# TODO: the ReAct formatter does not explicitly specify PromptTemplate149# objects, but wrap it in this to obey the interface150sys_header = self._react_chat_formatter.system_header151return {"system_prompt": PromptTemplate(sys_header)}152
153def _update_prompts(self, prompts: PromptDictType) -> None:154"""Update prompts."""155if "system_prompt" in prompts:156sys_prompt = cast(PromptTemplate, prompts["system_prompt"])157self._react_chat_formatter.system_header = sys_prompt.template158
159def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:160"""Initialize step from task."""161sources: List[ToolOutput] = []162current_reasoning: List[BaseReasoningStep] = []163# temporary memory for new messages164new_memory = ChatMemoryBuffer.from_defaults()165
166# initialize task state167task_state = {168"sources": sources,169"current_reasoning": current_reasoning,170"new_memory": new_memory,171}172task.extra_state.update(task_state)173
174return TaskStep(175task_id=task.task_id,176step_id=str(uuid.uuid4()),177input=task.input,178step_state={"is_first": True},179)180
181def get_tools(self, input: str) -> List[AsyncBaseTool]:182"""Get tools."""183return [adapt_to_async_tool(t) for t in self._get_tools(input)]184
185def _extract_reasoning_step(186self, output: ChatResponse, is_streaming: bool = False187) -> Tuple[str, List[BaseReasoningStep], bool]:188"""189Extracts the reasoning step from the given output.
190
191This method parses the message content from the output,
192extracts the reasoning step, and determines whether the processing is
193complete. It also performs validation checks on the output and
194handles possible errors.
195"""
196if output.message.content is None:197raise ValueError("Got empty message.")198message_content = output.message.content199current_reasoning = []200try:201reasoning_step = self._output_parser.parse(message_content, is_streaming)202except BaseException as exc:203raise ValueError(f"Could not parse output: {message_content}") from exc204if self._verbose:205print_text(f"{reasoning_step.get_content()}\n", color="pink")206current_reasoning.append(reasoning_step)207
208if reasoning_step.is_done:209return message_content, current_reasoning, True210
211reasoning_step = cast(ActionReasoningStep, reasoning_step)212if not isinstance(reasoning_step, ActionReasoningStep):213raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")214
215return message_content, current_reasoning, False216
217def _process_actions(218self,219task: Task,220tools: Sequence[AsyncBaseTool],221output: ChatResponse,222is_streaming: bool = False,223) -> Tuple[List[BaseReasoningStep], bool]:224tools_dict: Dict[str, AsyncBaseTool] = {225tool.metadata.get_name(): tool for tool in tools226}227_, current_reasoning, is_done = self._extract_reasoning_step(228output, is_streaming229)230
231if is_done:232return current_reasoning, True233
234# call tool with input235reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])236tool = tools_dict[reasoning_step.action]237with self.callback_manager.event(238CBEventType.FUNCTION_CALL,239payload={240EventPayload.FUNCTION_CALL: reasoning_step.action_input,241EventPayload.TOOL: tool.metadata,242},243) as event:244tool_output = tool.call(**reasoning_step.action_input)245event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})246
247task.extra_state["sources"].append(tool_output)248
249observation_step = ObservationReasoningStep(observation=str(tool_output))250current_reasoning.append(observation_step)251if self._verbose:252print_text(f"{observation_step.get_content()}\n", color="blue")253return current_reasoning, False254
255async def _aprocess_actions(256self,257task: Task,258tools: Sequence[AsyncBaseTool],259output: ChatResponse,260is_streaming: bool = False,261) -> Tuple[List[BaseReasoningStep], bool]:262tools_dict = {tool.metadata.name: tool for tool in tools}263_, current_reasoning, is_done = self._extract_reasoning_step(264output, is_streaming265)266
267if is_done:268return current_reasoning, True269
270# call tool with input271reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])272tool = tools_dict[reasoning_step.action]273with self.callback_manager.event(274CBEventType.FUNCTION_CALL,275payload={276EventPayload.FUNCTION_CALL: reasoning_step.action_input,277EventPayload.TOOL: tool.metadata,278},279) as event:280tool_output = await tool.acall(**reasoning_step.action_input)281event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})282
283task.extra_state["sources"].append(tool_output)284
285observation_step = ObservationReasoningStep(observation=str(tool_output))286current_reasoning.append(observation_step)287if self._verbose:288print_text(f"{observation_step.get_content()}\n", color="blue")289return current_reasoning, False290
291def _get_response(292self,293current_reasoning: List[BaseReasoningStep],294sources: List[ToolOutput],295) -> AgentChatResponse:296"""Get response from reasoning steps."""297if len(current_reasoning) == 0:298raise ValueError("No reasoning steps were taken.")299elif len(current_reasoning) == self._max_iterations:300raise ValueError("Reached max iterations.")301
302if isinstance(current_reasoning[-1], ResponseReasoningStep):303response_step = cast(ResponseReasoningStep, current_reasoning[-1])304response_str = response_step.response305else:306response_str = current_reasoning[-1].get_content()307
308# TODO: add sources from reasoning steps309return AgentChatResponse(response=response_str, sources=sources)310
311def _get_task_step_response(312self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool313) -> TaskStepOutput:314"""Get task step response."""315if is_done:316new_steps = []317else:318new_steps = [319step.get_next_step(320step_id=str(uuid.uuid4()),321# NOTE: input is unused322input=None,323)324]325
326return TaskStepOutput(327output=agent_response,328task_step=step,329is_last=is_done,330next_steps=new_steps,331)332
333def _infer_stream_chunk_is_final(self, chunk: ChatResponse) -> bool:334"""Infers if a chunk from a live stream is the start of the final335reasoning step. (i.e., and should eventually become
336ResponseReasoningStep — not part of this function's logic tho.).
337
338Args:
339chunk (ChatResponse): the current chunk stream to check
340
341Returns:
342bool: Boolean on whether the chunk is the start of the final response
343"""
344latest_content = chunk.message.content345if latest_content:346if not latest_content.startswith(347"Thought"348): # doesn't follow thought-action format349return True350else:351if "Answer: " in latest_content:352return True353return False354
355def _add_back_chunk_to_stream(356self, chunk: ChatResponse, chat_stream: Generator[ChatResponse, None, None]357) -> Generator[ChatResponse, None, None]:358"""Helper method for adding back initial chunk stream of final response359back to the rest of the chat_stream.
360
361Args:
362chunk (ChatResponse): the chunk to add back to the beginning of the
363chat_stream.
364
365Return:
366Generator[ChatResponse, None, None]: the updated chat_stream
367"""
368updated_stream = chain.from_iterable( # need to add back partial response chunk369[370unit_generator(chunk),371chat_stream,372]373)374# use cast to avoid mypy issue with chain and Generator375updated_stream_c: Generator[ChatResponse, None, None] = cast(376Generator[ChatResponse, None, None], updated_stream377)378return updated_stream_c379
380async def _async_add_back_chunk_to_stream(381self, chunk: ChatResponse, chat_stream: AsyncGenerator[ChatResponse, None]382) -> AsyncGenerator[ChatResponse, None]:383"""Helper method for adding back initial chunk stream of final response384back to the rest of the chat_stream.
385
386NOTE: this itself is not an async function.
387
388Args:
389chunk (ChatResponse): the chunk to add back to the beginning of the
390chat_stream.
391
392Return:
393AsyncGenerator[ChatResponse, None]: the updated async chat_stream
394"""
395yield chunk396async for item in chat_stream:397yield item398
399def _run_step(400self,401step: TaskStep,402task: Task,403) -> TaskStepOutput:404"""Run step."""405if step.input is not None:406add_user_step_to_reasoning(407step,408task.extra_state["new_memory"],409task.extra_state["current_reasoning"],410verbose=self._verbose,411)412# TODO: see if we want to do step-based inputs413tools = self.get_tools(task.input)414
415input_chat = self._react_chat_formatter.format(416tools,417chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),418current_reasoning=task.extra_state["current_reasoning"],419)420
421# send prompt422chat_response = self._llm.chat(input_chat)423# given react prompt outputs, call tools or return response424reasoning_steps, is_done = self._process_actions(425task, tools, output=chat_response426)427task.extra_state["current_reasoning"].extend(reasoning_steps)428agent_response = self._get_response(429task.extra_state["current_reasoning"], task.extra_state["sources"]430)431if is_done:432task.extra_state["new_memory"].put(433ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)434)435
436return self._get_task_step_response(agent_response, step, is_done)437
438async def _arun_step(439self,440step: TaskStep,441task: Task,442) -> TaskStepOutput:443"""Run step."""444if step.input is not None:445add_user_step_to_reasoning(446step,447task.extra_state["new_memory"],448task.extra_state["current_reasoning"],449verbose=self._verbose,450)451# TODO: see if we want to do step-based inputs452tools = self.get_tools(task.input)453
454input_chat = self._react_chat_formatter.format(455tools,456chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),457current_reasoning=task.extra_state["current_reasoning"],458)459# send prompt460chat_response = await self._llm.achat(input_chat)461# given react prompt outputs, call tools or return response462reasoning_steps, is_done = await self._aprocess_actions(463task, tools, output=chat_response464)465task.extra_state["current_reasoning"].extend(reasoning_steps)466agent_response = self._get_response(467task.extra_state["current_reasoning"], task.extra_state["sources"]468)469if is_done:470task.extra_state["new_memory"].put(471ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)472)473
474return self._get_task_step_response(agent_response, step, is_done)475
    def _run_step_stream(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run step in streaming mode.

        Streams chunks from the LLM until the final answer is detected.
        Intermediate (action) outputs are dispatched to tools; a final
        answer is returned as a StreamingAgentChatResponse whose remaining
        chunks are written to memory on a background thread.
        """
        if step.input is not None:
            add_user_step_to_reasoning(
                step,
                task.extra_state["new_memory"],
                task.extra_state["current_reasoning"],
                verbose=self._verbose,
            )
        # TODO: see if we want to do step-based inputs
        tools = self.get_tools(task.input)

        input_chat = self._react_chat_formatter.format(
            tools,
            chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
            current_reasoning=task.extra_state["current_reasoning"],
        )

        chat_stream = self._llm.stream_chat(input_chat)

        # iterate over stream, break out if is final answer after the "Answer: "
        full_response = ChatResponse(
            message=ChatMessage(content=None, role="assistant")
        )
        is_done = False
        for latest_chunk in chat_stream:
            # NOTE: each chunk carries the cumulative response so far, so
            # only the latest one needs to be kept.
            full_response = latest_chunk
            is_done = self._infer_stream_chunk_is_final(latest_chunk)
            if is_done:
                break

        if not is_done:
            # given react prompt outputs, call tools or return response
            reasoning_steps, _ = self._process_actions(
                task, tools=tools, output=full_response, is_streaming=True
            )
            task.extra_state["current_reasoning"].extend(reasoning_steps)
            # use _get_response to return intermediate response
            agent_response: AGENT_CHAT_RESPONSE_TYPE = self._get_response(
                task.extra_state["current_reasoning"], task.extra_state["sources"]
            )
        else:
            # Get the response in a separate thread so we can yield the response
            # (the probed chunk is re-attached so consumers see the full answer)
            response_stream = self._add_back_chunk_to_stream(
                chunk=latest_chunk, chat_stream=chat_stream
            )

            agent_response = StreamingAgentChatResponse(
                chat_stream=response_stream,
                sources=task.extra_state["sources"],
            )
            # Drain the stream into memory on a daemonless background thread
            # so the caller can consume chunks concurrently.
            thread = Thread(
                target=agent_response.write_response_to_history,
                args=(task.extra_state["new_memory"],),
            )
            thread.start()

        return self._get_task_step_response(agent_response, step, is_done)
539async def _arun_step_stream(540self,541step: TaskStep,542task: Task,543) -> TaskStepOutput:544"""Run step."""545if step.input is not None:546add_user_step_to_reasoning(547step,548task.extra_state["new_memory"],549task.extra_state["current_reasoning"],550verbose=self._verbose,551)552# TODO: see if we want to do step-based inputs553tools = self.get_tools(task.input)554
555input_chat = self._react_chat_formatter.format(556tools,557chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),558current_reasoning=task.extra_state["current_reasoning"],559)560
561chat_stream = await self._llm.astream_chat(input_chat)562
563# iterate over stream, break out if is final answer after the "Answer: "564full_response = ChatResponse(565message=ChatMessage(content=None, role="assistant")566)567is_done = False568async for latest_chunk in chat_stream:569full_response = latest_chunk570is_done = self._infer_stream_chunk_is_final(latest_chunk)571if is_done:572break573
574if not is_done:575# given react prompt outputs, call tools or return response576reasoning_steps, _ = self._process_actions(577task, tools=tools, output=full_response, is_streaming=True578)579task.extra_state["current_reasoning"].extend(reasoning_steps)580# use _get_response to return intermediate response581agent_response: AGENT_CHAT_RESPONSE_TYPE = self._get_response(582task.extra_state["current_reasoning"], task.extra_state["sources"]583)584else:585# Get the response in a separate thread so we can yield the response586response_stream = self._async_add_back_chunk_to_stream(587chunk=latest_chunk, chat_stream=chat_stream588)589
590agent_response = StreamingAgentChatResponse(591achat_stream=response_stream,592sources=task.extra_state["sources"],593)594# create task to write chat response to history595asyncio.create_task(596agent_response.awrite_response_to_history(597task.extra_state["new_memory"]598)599)600# wait until response writing is done601await agent_response._is_function_false_event.wait()602
603return self._get_task_step_response(agent_response, step, is_done)604
605@trace_method("run_step")606def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:607"""Run step."""608return self._run_step(step, task)609
610@trace_method("run_step")611async def arun_step(612self, step: TaskStep, task: Task, **kwargs: Any613) -> TaskStepOutput:614"""Run step (async)."""615return await self._arun_step(step, task)616
617@trace_method("run_step")618def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:619"""Run step (stream)."""620# TODO: figure out if we need a different type for TaskStepOutput621return self._run_step_stream(step, task)622
623@trace_method("run_step")624async def astream_step(625self, step: TaskStep, task: Task, **kwargs: Any626) -> TaskStepOutput:627"""Run step (async stream)."""628return await self._arun_step_stream(step, task)629
630def finalize_task(self, task: Task, **kwargs: Any) -> None:631"""Finalize task, after all the steps are completed."""632# add new messages to memory633task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())634# reset new memory635task.extra_state["new_memory"].reset()636
637def set_callback_manager(self, callback_manager: CallbackManager) -> None:638"""Set callback manager."""639# TODO: make this abstractmethod (right now will break some agent impls)640self.callback_manager = callback_manager641