llama-index
631 lines · 20.7 KB
1from abc import abstractmethod
2from collections import deque
3from typing import Any, Deque, Dict, List, Optional, Union, cast
4
5from llama_index.legacy.agent.types import (
6BaseAgent,
7BaseAgentWorker,
8Task,
9TaskStep,
10TaskStepOutput,
11)
12from llama_index.legacy.bridge.pydantic import BaseModel, Field
13from llama_index.legacy.callbacks import (
14CallbackManager,
15CBEventType,
16EventPayload,
17trace_method,
18)
19from llama_index.legacy.chat_engine.types import (
20AGENT_CHAT_RESPONSE_TYPE,
21AgentChatResponse,
22ChatResponseMode,
23StreamingAgentChatResponse,
24)
25from llama_index.legacy.llms.base import ChatMessage
26from llama_index.legacy.llms.llm import LLM
27from llama_index.legacy.memory import BaseMemory, ChatMemoryBuffer
28from llama_index.legacy.memory.types import BaseMemory
29from llama_index.legacy.tools.types import BaseTool
30
31
class BaseAgentRunner(BaseAgent):
    """Base agent runner.

    Interface for orchestrators that manage tasks and drive their
    execution step by step via an agent worker.
    """

    @abstractmethod
    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create task."""

    @abstractmethod
    def delete_task(self, task_id: str) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.

        """

    @abstractmethod
    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List tasks."""

    @abstractmethod
    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get task."""

    @abstractmethod
    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Get upcoming steps."""

    @abstractmethod
    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Get completed steps."""

    def get_completed_step(
        self, task_id: str, step_id: str, **kwargs: Any
    ) -> TaskStepOutput:
        """Get a single completed step output by its step id.

        Raises:
            ValueError: if no completed step with ``step_id`` exists.
        """
        # Scan the task's completed outputs lazily; return the first id match.
        matches = (
            candidate
            for candidate in self.get_completed_steps(task_id, **kwargs)
            if candidate.task_step.step_id == step_id
        )
        try:
            return next(matches)
        except StopIteration:
            raise ValueError(f"Could not find step_id: {step_id}") from None

    @abstractmethod
    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step."""

    @abstractmethod
    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async)."""

    @abstractmethod
    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream)."""

    @abstractmethod
    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream)."""

    @abstractmethod
    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize response."""

    @abstractmethod
    def undo_step(self, task_id: str) -> None:
        """Undo previous step."""
        raise NotImplementedError("undo_step not implemented")
129
130
def validate_step_from_args(
    task_id: str, input: Optional[str] = None, step: Optional[Any] = None, **kwargs: Any
) -> Optional[TaskStep]:
    """Validate step from args.

    Returns the validated ``step`` (or None when not given); raises if
    both ``step`` and ``input`` are specified or ``step`` has the wrong type.
    """
    if step is None:
        return None
    # A step was supplied explicitly: it must be exclusive and well-typed.
    if input is not None:
        raise ValueError("Cannot specify both `step` and `input`")
    if not isinstance(step, TaskStep):
        raise ValueError(f"step must be TaskStep: {step}")
    return step
143
144
class TaskState(BaseModel):
    """Task state.

    Mutable per-task bookkeeping: the task itself, the FIFO queue of
    steps still to be executed, and the outputs of steps already run.
    """

    # The task being executed.
    task: Task = Field(..., description="Task.")
    # Steps waiting to be executed for this task, consumed from the left.
    step_queue: Deque[TaskStep] = Field(
        default_factory=deque, description="Task step queue."
    )
    # Outputs of steps that have already run, in execution order.
    completed_steps: List[TaskStepOutput] = Field(
        default_factory=list, description="Completed step outputs."
    )
155
156
class AgentState(BaseModel):
    """Agent state.

    Registry of all in-flight tasks, keyed by task id.
    """

    task_dict: Dict[str, TaskState] = Field(
        default_factory=dict, description="Task dictionary."
    )

    def get_task(self, task_id: str) -> Task:
        """Return the task registered under ``task_id``."""
        task_state = self.task_dict[task_id]
        return task_state.task

    def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
        """Return the completed step outputs for ``task_id``."""
        task_state = self.task_dict[task_id]
        return task_state.completed_steps

    def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
        """Return the pending step queue for ``task_id``."""
        task_state = self.task_dict[task_id]
        return task_state.step_queue

    def reset(self) -> None:
        """Drop all task state by rebinding to an empty registry."""
        self.task_dict = {}
179
180
class AgentRunner(BaseAgentRunner):
    """Agent runner.

    Top-level agent orchestrator that can create tasks, run each step in a task,
    or run a task e2e. Stores state and keeps track of tasks.

    Args:
        agent_worker (BaseAgentWorker): step executor
        chat_history (Optional[List[ChatMessage]], optional): chat history. Defaults to None.
        state (Optional[AgentState], optional): agent state. Defaults to None.
        memory (Optional[BaseMemory], optional): memory. Defaults to None.
        llm (Optional[LLM], optional): LLM. Defaults to None.
        callback_manager (Optional[CallbackManager], optional): callback manager. Defaults to None.
        init_task_state_kwargs (Optional[dict], optional): init task state kwargs. Defaults to None.

    """

    # # TODO: implement this in Pydantic

    def __init__(
        self,
        agent_worker: BaseAgentWorker,
        chat_history: Optional[List[ChatMessage]] = None,
        state: Optional[AgentState] = None,
        memory: Optional[BaseMemory] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        init_task_state_kwargs: Optional[dict] = None,
        delete_task_on_finish: bool = False,
        default_tool_choice: str = "auto",
        verbose: bool = False,
    ) -> None:
        """Initialize."""
        self.agent_worker = agent_worker
        self.state = state or AgentState()
        self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)

        # get and set callback manager
        if callback_manager is not None:
            self.agent_worker.set_callback_manager(callback_manager)
            self.callback_manager = callback_manager
        else:
            # TODO: This is *temporary*
            # Stopgap before having a callback on the BaseAgentWorker interface.
            # Doing that requires a bit more refactoring to make sure existing code
            # doesn't break.
            if hasattr(self.agent_worker, "callback_manager"):
                self.callback_manager = (
                    self.agent_worker.callback_manager or CallbackManager()
                )
            else:
                self.callback_manager = CallbackManager()

        self.init_task_state_kwargs = init_task_state_kwargs or {}
        self.delete_task_on_finish = delete_task_on_finish
        self.default_tool_choice = default_tool_choice
        self.verbose = verbose

    @staticmethod
    def from_llm(
        tools: Optional[List[BaseTool]] = None,
        llm: Optional[LLM] = None,
        **kwargs: Any,
    ) -> "AgentRunner":
        """Construct an agent from tools and an LLM.

        Returns an OpenAIAgent for function-calling-capable OpenAI models,
        otherwise falls back to a ReActAgent.
        """
        # Imports are local to avoid a circular dependency with the agent package.
        from llama_index.legacy.llms.openai import OpenAI
        from llama_index.legacy.llms.openai_utils import is_function_calling_model

        if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
            from llama_index.legacy.agent import OpenAIAgent

            return OpenAIAgent.from_tools(
                tools=tools,
                llm=llm,
                **kwargs,
            )
        else:
            from llama_index.legacy.agent import ReActAgent

            return ReActAgent.from_tools(
                tools=tools,
                llm=llm,
                **kwargs,
            )

    @property
    def chat_history(self) -> List[ChatMessage]:
        """Full chat history stored in memory."""
        return self.memory.get_all()

    def reset(self) -> None:
        """Clear conversation memory and all task state."""
        self.memory.reset()
        self.state.reset()

    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create task.

        Builds a Task sharing the runner's memory, seeds its step queue with
        the worker's initial step, and registers it in the runner state.

        Raises:
            ValueError: if `extra_state` is passed while the runner was
                constructed with `init_task_state_kwargs`.
        """
        if not self.init_task_state_kwargs:
            extra_state = kwargs.pop("extra_state", {})
        else:
            if "extra_state" in kwargs:
                raise ValueError(
                    "Cannot specify both `extra_state` and `init_task_state_kwargs`"
                )
            else:
                extra_state = self.init_task_state_kwargs

        callback_manager = kwargs.pop("callback_manager", self.callback_manager)
        task = Task(
            input=input,
            memory=self.memory,
            extra_state=extra_state,
            callback_manager=callback_manager,
            **kwargs,
        )
        # # put input into memory
        # self.memory.put(ChatMessage(content=input, role=MessageRole.USER))

        # get initial step from task, and put it in the step queue
        initial_step = self.agent_worker.initialize_step(task)
        task_state = TaskState(
            task=task,
            step_queue=deque([initial_step]),
        )
        # add it to state
        self.state.task_dict[task.task_id] = task_state

        return task

    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.

        """
        self.state.task_dict.pop(task_id)

    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List tasks."""
        # Fix: unwrap TaskState -> Task so the return value matches the
        # declared List[Task] (previously this returned the TaskState
        # wrapper objects themselves).
        return [task_state.task for task_state in self.state.task_dict.values()]

    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get task."""
        return self.state.get_task(task_id)

    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Get upcoming steps."""
        # Snapshot the deque as a list so callers can't mutate the queue.
        return list(self.state.get_step_queue(task_id))

    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Get completed steps."""
        return self.state.get_completed_steps(task_id)

    def _run_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        input: Optional[str] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute step.

        Pops the next step from the task's queue (unless `step` is given),
        runs it via the agent worker in the requested mode, enqueues any
        follow-up steps, and records the output as completed.
        """
        task = self.state.get_task(task_id)
        step_queue = self.state.get_step_queue(task_id)
        step = step or step_queue.popleft()
        if input is not None:
            step.input = input

        if self.verbose:
            print(f"> Running step {step.step_id}. Step input: {step.input}")

        # TODO: figure out if you can dynamically swap in different step executors
        # not clear when you would do that, but theoretically possible

        if mode == ChatResponseMode.WAIT:
            cur_step_output = self.agent_worker.run_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")
        # append cur_step_output next steps to queue
        next_steps = cur_step_output.next_steps
        step_queue.extend(next_steps)

        # add cur_step_output to completed steps
        completed_steps = self.state.get_completed_steps(task_id)
        completed_steps.append(cur_step_output)

        return cur_step_output

    async def _arun_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        input: Optional[str] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute step (async).

        Async mirror of `_run_step`: pops/receives a step, awaits the worker,
        enqueues follow-ups, and records the output.
        """
        task = self.state.get_task(task_id)
        step_queue = self.state.get_step_queue(task_id)
        step = step or step_queue.popleft()
        if input is not None:
            step.input = input

        if self.verbose:
            print(f"> Running step {step.step_id}. Step input: {step.input}")

        # TODO: figure out if you can dynamically swap in different step executors
        # not clear when you would do that, but theoretically possible
        if mode == ChatResponseMode.WAIT:
            cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")
        # append cur_step_output next steps to queue
        next_steps = cur_step_output.next_steps
        step_queue.extend(next_steps)

        # add cur_step_output to completed steps
        completed_steps = self.state.get_completed_steps(task_id)
        completed_steps.append(cur_step_output)

        return cur_step_output

    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return self._run_step(
            task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
        )

    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return await self._arun_step(
            task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
        )

    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return self._run_step(
            task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
        )

    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return await self._arun_step(
            task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
        )

    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize response.

        Validates that the (given or most recent) step output is the final
        one, lets the worker finalize the task, optionally deletes the task,
        and returns the chat response.

        Raises:
            ValueError: if the step output is not the last step, or its
                output is not an agent chat response.
        """
        if step_output is None:
            # Default to the most recently completed step.
            step_output = self.state.get_completed_steps(task_id)[-1]
        if not step_output.is_last:
            raise ValueError(
                "finalize_response can only be called on the last step output"
            )

        if not isinstance(
            step_output.output,
            (AgentChatResponse, StreamingAgentChatResponse),
        ):
            raise ValueError(
                "When `is_last` is True, cur_step_output.output must be "
                f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
            )

        # finalize task
        self.agent_worker.finalize_task(self.state.get_task(task_id))

        if self.delete_task_on_finish:
            self.delete_task(task_id)

        return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)

    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor.

        Creates a task from the message and runs steps until one reports
        `is_last`, then finalizes and returns the response.
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_output = self._run_step(
                task.task_id, mode=mode, tool_choice=tool_choice
            )

            if cur_step_output.is_last:
                result_output = cur_step_output
                break

            # ensure tool_choice does not cause endless loops
            tool_choice = "auto"

        return self.finalize_response(task.task_id, result_output)

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor (async).

        Async mirror of `_chat`: runs steps until the last one, then
        finalizes and returns the response.
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_output = await self._arun_step(
                task.task_id, mode=mode, tool_choice=tool_choice
            )

            if cur_step_output.is_last:
                result_output = cur_step_output
                break

            # ensure tool_choice does not cause endless loops
            tool_choice = "auto"

        return self.finalize_response(task.task_id, result_output)

    @trace_method("chat")
    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> AgentChatResponse:
        # Override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> AgentChatResponse:
        # Override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    def stream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> StreamingAgentChatResponse:
        # Override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def astream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> StreamingAgentChatResponse:
        # Override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    def undo_step(self, task_id: str) -> None:
        """Undo previous step."""
        raise NotImplementedError("undo_step not implemented")
632