# llama-index
# 472 lines · 15.3 KB
"""Agent executor.

Defines a parallel agent runner that drains a task's step queue and runs every
queued step concurrently via ``asyncio.gather``, together with the pydantic
state models it keeps for each task.
"""

import asyncio
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Union, cast

from llama_index.legacy.agent.runner.base import BaseAgentRunner
from llama_index.legacy.agent.types import (
    BaseAgentWorker,
    Task,
    TaskStep,
    TaskStepOutput,
)
from llama_index.legacy.bridge.pydantic import BaseModel, Field
from llama_index.legacy.callbacks import (
    CallbackManager,
    CBEventType,
    EventPayload,
    trace_method,
)
from llama_index.legacy.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    AgentChatResponse,
    ChatResponseMode,
    StreamingAgentChatResponse,
)
from llama_index.legacy.llms.base import ChatMessage
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.memory import BaseMemory, ChatMemoryBuffer
from llama_index.legacy.memory.types import BaseMemory


class DAGTaskState(BaseModel):
    """Per-task state kept by the parallel runner.

    Holds the task itself, the step the task was initialized with, the queue
    of steps still waiting to run, and the outputs of steps that already ran.
    """

    # The task this state belongs to.
    task: Task = Field(..., description="Task.")
    # First step produced for this task when it was created.
    root_step: TaskStep = Field(..., description="Root step.")
    # Steps waiting to be executed; drained and refilled by the runner.
    step_queue: Deque[TaskStep] = Field(
        default_factory=deque, description="Task step queue."
    )
    # Outputs of executed steps, appended as each step finishes.
    completed_steps: List[TaskStepOutput] = Field(
        default_factory=list, description="Completed step outputs."
    )

    @property
    def task_id(self) -> str:
        """Return the id of the wrapped task."""
        return self.task.task_id


class DAGAgentState(BaseModel):
    """Agent state: maps each task id to its ``DAGTaskState``."""

    # task_id -> per-task state.
    task_dict: Dict[str, DAGTaskState] = Field(
        default_factory=dict, description="Task dictionary."
    )

    def get_task(self, task_id: str) -> Task:
        """Return the ``Task`` registered under ``task_id``.

        Raises:
            KeyError: If no task with that id exists.
        """
        return self.task_dict[task_id].task

    def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
        """Return the (mutable) list of completed step outputs for a task.

        Raises:
            KeyError: If no task with that id exists.
        """
        return self.task_dict[task_id].completed_steps

    def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
        """Return the (mutable) queue of pending steps for a task.

        Raises:
            KeyError: If no task with that id exists.
        """
        return self.task_dict[task_id].step_queue


class ParallelAgentRunner(BaseAgentRunner):
    """Parallel agent runner.

    Executes every step currently in a task's queue concurrently (via
    ``asyncio.gather``), then repeats with whatever ready steps those runs
    enqueued. Requires the agent worker to implement the async step APIs
    (``arun_step`` / ``astream_step``).

    """

    def __init__(
        self,
        agent_worker: BaseAgentWorker,
        chat_history: Optional[List[ChatMessage]] = None,
        state: Optional[DAGAgentState] = None,
        memory: Optional[BaseMemory] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        init_task_state_kwargs: Optional[dict] = None,
        delete_task_on_finish: bool = False,
    ) -> None:
        """Initialize.

        Args:
            agent_worker: Worker that initializes, runs and finalizes steps.
            chat_history: Seed history for the default memory buffer; unused
                when ``memory`` is given.
            state: Existing runner state to continue from; fresh by default.
            memory: Memory handed to every task created by this runner.
            llm: Passed to ``ChatMemoryBuffer.from_defaults`` when the default
                memory is built.
            callback_manager: Callback manager; an empty one by default.
            init_task_state_kwargs: Initial ``extra_state`` for each new task;
                every task receives its own shallow copy.
            delete_task_on_finish: Drop a task's state once its response has
                been finalized.
        """
        self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)
        self.state = state or DAGAgentState()
        self.callback_manager = callback_manager or CallbackManager([])
        self.init_task_state_kwargs = init_task_state_kwargs or {}
        self.agent_worker = agent_worker
        self.delete_task_on_finish = delete_task_on_finish

    @property
    def chat_history(self) -> List[ChatMessage]:
        """Return all messages currently held in memory."""
        return self.memory.get_all()

    def reset(self) -> None:
        """Reset memory (task state in ``self.state`` is left untouched)."""
        self.memory.reset()

    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create a task for ``input`` and enqueue its initial step.

        Args:
            input: User input for the task.
            **kwargs: Extra fields forwarded to the ``Task`` constructor.

        Returns:
            The new task, also registered in ``self.state``.
        """
        task = Task(
            input=input,
            memory=self.memory,
            # Copy so that tasks never share (and mutate) one extra_state dict.
            extra_state=dict(self.init_task_state_kwargs),
            **kwargs,
        )
        # NOTE(review): the input is deliberately not put into memory here;
        # presumably the agent worker records it — confirm.

        initial_step = self.agent_worker.initialize_step(task)
        self.state.task_dict[task.task_id] = DAGTaskState(
            task=task,
            root_step=initial_step,
            step_queue=deque([initial_step]),
        )
        return task

    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.

        Raises:
            KeyError: If no task with that id exists.
        """
        self.state.task_dict.pop(task_id)

    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List all tasks known to the runner."""
        return [task_state.task for task_state in self.state.task_dict.values()]

    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get task."""
        return self.state.get_task(task_id)

    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Return a snapshot (list copy) of the task's pending steps."""
        return list(self.state.get_step_queue(task_id))

    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Get completed steps."""
        return self.state.get_completed_steps(task_id)

    def run_steps_in_queue(
        self,
        task_id: str,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> List[TaskStepOutput]:
        """Execute steps in queue.

        Run all steps in queue, clearing it out. Assumes that all queued steps
        can be run in parallel. Synchronous wrapper around
        ``arun_steps_in_queue``.

        NOTE: uses ``asyncio.run``, which raises ``RuntimeError`` if an event
        loop is already running in this thread; call ``arun_steps_in_queue``
        from async code instead.
        """
        return asyncio.run(self.arun_steps_in_queue(task_id, mode=mode, **kwargs))

    async def arun_steps_in_queue(
        self,
        task_id: str,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> List[TaskStepOutput]:
        """Execute all steps in queue concurrently.

        All steps in queue are assumed to be ready.

        Returns:
            One output per executed step, in the order the steps were queued.
        """
        step_queue = self.state.get_step_queue(task_id)
        # Snapshot and empty the queue first: running steps append their ready
        # successors to this same deque, and those belong to the next round.
        steps = list(step_queue)
        step_queue.clear()

        return await asyncio.gather(
            *(
                self._arun_step(task_id, step=step, mode=mode, **kwargs)
                for step in steps
            )
        )

    def _run_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute a single step synchronously.

        Args:
            task_id: Task the step belongs to.
            step: Step to run; when omitted, the next step is popped from the
                task's queue.
            mode: ``WAIT`` calls ``agent_worker.run_step``; ``STREAM`` calls
                ``agent_worker.stream_step``.
            input: If given, replaces the step's input before it runs.
            **kwargs: Forwarded to the agent worker.

        Returns:
            The step output, also appended to the task's completed steps.

        Raises:
            IndexError: If ``step`` is omitted and the queue is empty.
            ValueError: If the step is not ready or ``mode`` is unsupported.
        """
        task = self.state.get_task(task_id)
        step_queue = self.state.get_step_queue(task_id)
        step = step or step_queue.popleft()
        if input is not None:
            step.input = input

        if not step.is_ready:
            raise ValueError(f"Step {step.step_id} is not ready")

        if mode == ChatResponseMode.WAIT:
            cur_step_output: TaskStepOutput = self.agent_worker.run_step(
                step, task, **kwargs
            )
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        # Only successors that are already ready get enqueued.
        step_queue.extend(
            next_step for next_step in cur_step_output.next_steps if next_step.is_ready
        )
        self.state.get_completed_steps(task_id).append(cur_step_output)

        return cur_step_output

    async def _arun_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute a single step asynchronously.

        Async counterpart of ``_run_step``: ``WAIT`` awaits
        ``agent_worker.arun_step`` and ``STREAM`` awaits
        ``agent_worker.astream_step``. Same arguments, return value and errors.
        """
        task = self.state.get_task(task_id)
        step_queue = self.state.get_step_queue(task_id)
        step = step or step_queue.popleft()
        if input is not None:
            step.input = input

        if not step.is_ready:
            raise ValueError(f"Step {step.step_id} is not ready")

        if mode == ChatResponseMode.WAIT:
            cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        # Only successors that are already ready get enqueued.
        step_queue.extend(
            next_step for next_step in cur_step_output.next_steps if next_step.is_ready
        )
        self.state.get_completed_steps(task_id).append(cur_step_output)

        return cur_step_output

    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (``input``, if given, replaces the step's input)."""
        return self._run_step(
            task_id, step, mode=ChatResponseMode.WAIT, input=input, **kwargs
        )

    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async; ``input``, if given, replaces the step's input)."""
        return await self._arun_step(
            task_id, step, mode=ChatResponseMode.WAIT, input=input, **kwargs
        )

    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream; ``input``, if given, replaces the step's input)."""
        return self._run_step(
            task_id, step, mode=ChatResponseMode.STREAM, input=input, **kwargs
        )

    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream; ``input``, if given, replaces the step's input)."""
        return await self._arun_step(
            task_id, step, mode=ChatResponseMode.STREAM, input=input, **kwargs
        )

    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize response.

        Args:
            task_id: Task to finalize.
            step_output: Final step output; defaults to the task's most
                recently completed step.

        Returns:
            The chat response carried by ``step_output``.

        Raises:
            ValueError: If ``step_output`` is not marked ``is_last`` or its
                output is not an agent chat response.
        """
        if step_output is None:
            step_output = self.state.get_completed_steps(task_id)[-1]
        if not step_output.is_last:
            raise ValueError(
                "finalize_response can only be called on the last step output"
            )

        if not isinstance(
            step_output.output,
            (AgentChatResponse, StreamingAgentChatResponse),
        ):
            raise ValueError(
                "When `is_last` is True, cur_step_output.output must be "
                f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
            )

        # Finalize in the worker before the task state may be deleted below.
        self.agent_worker.finalize_task(self.state.get_task(task_id))

        if self.delete_task_on_finish:
            self.delete_task(task_id)

        return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)

    @staticmethod
    def _get_final_step_output(
        step_outputs: List[TaskStepOutput],
    ) -> Optional[TaskStepOutput]:
        """Pick the final output of one round, or None if the task continues.

        Raises:
            ValueError: If the round produced no outputs (the queue was empty,
                so looping again could never finish the task), or if a round
                containing a final output produced more than one output.
        """
        if not step_outputs:
            raise ValueError(
                "No steps left in the queue, but no step output was marked "
                "`is_last`; the task cannot make progress."
            )
        if not any(step_output.is_last for step_output in step_outputs):
            return None
        if len(step_outputs) > 1:
            raise ValueError("More than one step output returned in final step.")
        return step_outputs[0]

    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor.

        Creates a task for ``message`` and runs its queue round by round until
        a step output is marked ``is_last``, then finalizes that output.

        NOTE(review): ``tool_choice`` is accepted for interface compatibility
        but is not forwarded to the agent worker.
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        while True:
            # pass step queue in as argument, assume step executor is stateless
            step_outputs = self.run_steps_in_queue(task.task_id, mode=mode)
            result_output = self._get_final_step_output(step_outputs)
            if result_output is not None:
                return self.finalize_response(task.task_id, result_output)

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor (async); see ``_chat``.

        NOTE(review): ``tool_choice`` is accepted for interface compatibility
        but is not forwarded to the agent worker.
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        while True:
            # pass step queue in as argument, assume step executor is stateless
            step_outputs = await self.arun_steps_in_queue(task.task_id, mode=mode)
            result_output = self._get_final_step_output(step_outputs)
            if result_output is not None:
                return self.finalize_response(task.task_id, result_output)

    @trace_method("chat")
    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat and wait for the full response."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat asynchronously and wait for the full response."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    def stream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Chat and return a streaming response."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def astream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Chat asynchronously and return a streaming response."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    def undo_step(self, task_id: str) -> None:
        """Undo previous step (not supported by this runner)."""
        raise NotImplementedError("undo_step not implemented")