from abc import abstractmethod
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Union, cast

from llama_index.legacy.agent.types import (
    BaseAgent,
    BaseAgentWorker,
    Task,
    TaskStep,
    TaskStepOutput,
)
from llama_index.legacy.bridge.pydantic import BaseModel, Field
from llama_index.legacy.callbacks import (
    CallbackManager,
    CBEventType,
    EventPayload,
    trace_method,
)
from llama_index.legacy.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    AgentChatResponse,
    ChatResponseMode,
    StreamingAgentChatResponse,
)
from llama_index.legacy.llms.base import ChatMessage
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.memory import BaseMemory, ChatMemoryBuffer
from llama_index.legacy.tools.types import BaseTool


class BaseAgentRunner(BaseAgent):
    """Base agent runner."""

    @abstractmethod
    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create task."""

    @abstractmethod
    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.

        """

    @abstractmethod
    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List tasks."""

    @abstractmethod
    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get task."""

    @abstractmethod
    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Get upcoming steps."""

    @abstractmethod
    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Get completed steps."""

    def get_completed_step(
        self, task_id: str, step_id: str, **kwargs: Any
    ) -> TaskStepOutput:
        """Get completed step."""
        # call get_completed_steps, and then find the right step output
        completed_steps = self.get_completed_steps(task_id, **kwargs)
        for step_output in completed_steps:
            if step_output.task_step.step_id == step_id:
                return step_output
        raise ValueError(f"Could not find step_id: {step_id}")

    @abstractmethod
    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step."""

    @abstractmethod
    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async)."""

    @abstractmethod
    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream)."""

    @abstractmethod
    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream)."""

    @abstractmethod
    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize response."""

    @abstractmethod
    def undo_step(self, task_id: str) -> None:
        """Undo previous step."""
        raise NotImplementedError("undo_step not implemented")


def validate_step_from_args(
    task_id: str, input: Optional[str] = None, step: Optional[Any] = None, **kwargs: Any
) -> Optional[TaskStep]:
    """Validate step from args."""
    if step is not None:
        if input is not None:
            raise ValueError("Cannot specify both `step` and `input`")
        if not isinstance(step, TaskStep):
            raise ValueError(f"step must be TaskStep: {step}")
        return step
    else:
        return None


class TaskState(BaseModel):
    """Task state."""

    task: Task = Field(..., description="Task.")
    step_queue: Deque[TaskStep] = Field(
        default_factory=deque, description="Task step queue."
    )
    completed_steps: List[TaskStepOutput] = Field(
        default_factory=list, description="Completed step outputs."
    )


class AgentState(BaseModel):
    """Agent state."""

    task_dict: Dict[str, TaskState] = Field(
        default_factory=dict, description="Task dictionary."
    )

    def get_task(self, task_id: str) -> Task:
        """Get task state."""
        return self.task_dict[task_id].task

    def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
        """Get completed steps."""
        return self.task_dict[task_id].completed_steps

    def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
        """Get step queue."""
        return self.task_dict[task_id].step_queue

    def reset(self) -> None:
        """Reset."""
        self.task_dict = {}


class AgentRunner(BaseAgentRunner):
    """Agent runner.

    Top-level agent orchestrator that can create tasks, run each step in a task,
    or run a task e2e. Stores state and keeps track of tasks.

    Args:
        agent_worker (BaseAgentWorker): step executor
        chat_history (Optional[List[ChatMessage]], optional): chat history. Defaults to None.
        state (Optional[AgentState], optional): agent state. Defaults to None.
        memory (Optional[BaseMemory], optional): memory. Defaults to None.
        llm (Optional[LLM], optional): LLM. Defaults to None.
        callback_manager (Optional[CallbackManager], optional): callback manager. Defaults to None.
        init_task_state_kwargs (Optional[dict], optional): init task state kwargs. Defaults to None.

    """

    # TODO: implement this in Pydantic

    def __init__(
        self,
        agent_worker: BaseAgentWorker,
        chat_history: Optional[List[ChatMessage]] = None,
        state: Optional[AgentState] = None,
        memory: Optional[BaseMemory] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        init_task_state_kwargs: Optional[dict] = None,
        delete_task_on_finish: bool = False,
        default_tool_choice: str = "auto",
        verbose: bool = False,
    ) -> None:
        """Initialize."""
        self.agent_worker = agent_worker
        self.state = state or AgentState()
        self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)

        # get and set callback manager
        if callback_manager is not None:
            self.agent_worker.set_callback_manager(callback_manager)
            self.callback_manager = callback_manager
        else:
            # TODO: This is *temporary*
            # Stopgap before having a callback on the BaseAgentWorker interface.
            # Doing that requires a bit more refactoring to make sure existing code
            # doesn't break.
            if hasattr(self.agent_worker, "callback_manager"):
                self.callback_manager = (
                    self.agent_worker.callback_manager or CallbackManager()
                )
            else:
                self.callback_manager = CallbackManager()

        self.init_task_state_kwargs = init_task_state_kwargs or {}
        self.delete_task_on_finish = delete_task_on_finish
        self.default_tool_choice = default_tool_choice
        self.verbose = verbose

    @staticmethod
    def from_llm(
        tools: Optional[List[BaseTool]] = None,
        llm: Optional[LLM] = None,
        **kwargs: Any,
    ) -> "AgentRunner":
        from llama_index.legacy.llms.openai import OpenAI
        from llama_index.legacy.llms.openai_utils import is_function_calling_model

        if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
            from llama_index.legacy.agent import OpenAIAgent

            return OpenAIAgent.from_tools(
                tools=tools,
                llm=llm,
                **kwargs,
            )
        else:
            from llama_index.legacy.agent import ReActAgent

            return ReActAgent.from_tools(
                tools=tools,
                llm=llm,
                **kwargs,
            )

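    # Illustrative sketch (not part of the original module): `from_llm` dispatches
    # to an OpenAIAgent for function-calling OpenAI models and otherwise falls back
    # to a ReActAgent. The specific tool and model below are assumptions made for
    # the example, not values required by this module.
    #
    #     from llama_index.legacy.llms.openai import OpenAI
    #     from llama_index.legacy.tools import FunctionTool
    #
    #     def multiply(a: int, b: int) -> int:
    #         """Multiply two integers."""
    #         return a * b
    #
    #     agent = AgentRunner.from_llm(
    #         tools=[FunctionTool.from_defaults(fn=multiply)],
    #         llm=OpenAI(model="gpt-3.5-turbo"),
    #         verbose=True,
    #     )
    #     print(agent.chat("What is 3 * 7?").response)
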
    @property
    def chat_history(self) -> List[ChatMessage]:
        return self.memory.get_all()

    def reset(self) -> None:
        self.memory.reset()
        self.state.reset()

    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create task."""
        if not self.init_task_state_kwargs:
            extra_state = kwargs.pop("extra_state", {})
        else:
            if "extra_state" in kwargs:
                raise ValueError(
                    "Cannot specify both `extra_state` and `init_task_state_kwargs`"
                )
            else:
                extra_state = self.init_task_state_kwargs

        callback_manager = kwargs.pop("callback_manager", self.callback_manager)
        task = Task(
            input=input,
            memory=self.memory,
            extra_state=extra_state,
            callback_manager=callback_manager,
            **kwargs,
        )
        # # put input into memory
        # self.memory.put(ChatMessage(content=input, role=MessageRole.USER))

        # get initial step from task, and put it in the step queue
        initial_step = self.agent_worker.initialize_step(task)
        task_state = TaskState(
            task=task,
            step_queue=deque([initial_step]),
        )
        # add it to state
        self.state.task_dict[task.task_id] = task_state

        return task

    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.

        """
        self.state.task_dict.pop(task_id)

    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List tasks."""
        return list(self.state.task_dict.values())

    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get task."""
        return self.state.get_task(task_id)

    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Get upcoming steps."""
        return list(self.state.get_step_queue(task_id))

    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Get completed steps."""
        return self.state.get_completed_steps(task_id)

    def _run_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        input: Optional[str] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute step."""
        task = self.state.get_task(task_id)
        step_queue = self.state.get_step_queue(task_id)
        step = step or step_queue.popleft()
        if input is not None:
            step.input = input

        if self.verbose:
            print(f"> Running step {step.step_id}. Step input: {step.input}")

        # TODO: figure out if you can dynamically swap in different step executors
        # not clear when you would do that, but it's theoretically possible

        if mode == ChatResponseMode.WAIT:
            cur_step_output = self.agent_worker.run_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")
        # append cur_step_output next steps to queue
        next_steps = cur_step_output.next_steps
        step_queue.extend(next_steps)

        # add cur_step_output to completed steps
        completed_steps = self.state.get_completed_steps(task_id)
        completed_steps.append(cur_step_output)

        return cur_step_output

    async def _arun_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        input: Optional[str] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute step."""
        task = self.state.get_task(task_id)
        step_queue = self.state.get_step_queue(task_id)
        step = step or step_queue.popleft()
        if input is not None:
            step.input = input

        if self.verbose:
            print(f"> Running step {step.step_id}. Step input: {step.input}")

        # TODO: figure out if you can dynamically swap in different step executors
        # not clear when you would do that, but it's theoretically possible
        if mode == ChatResponseMode.WAIT:
            cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")
        # append cur_step_output next steps to queue
        next_steps = cur_step_output.next_steps
        step_queue.extend(next_steps)

        # add cur_step_output to completed steps
        completed_steps = self.state.get_completed_steps(task_id)
        completed_steps.append(cur_step_output)

        return cur_step_output

    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return self._run_step(
            task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
        )

    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return await self._arun_step(
            task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
        )

    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return self._run_step(
            task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
        )

    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return await self._arun_step(
            task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
        )

    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize response."""
        if step_output is None:
            step_output = self.state.get_completed_steps(task_id)[-1]
        if not step_output.is_last:
            raise ValueError(
                "finalize_response can only be called on the last step output"
            )

        if not isinstance(
            step_output.output,
            (AgentChatResponse, StreamingAgentChatResponse),
        ):
            raise ValueError(
                "When `is_last` is True, cur_step_output.output must be "
                f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
            )

        # finalize task
        self.agent_worker.finalize_task(self.state.get_task(task_id))

        if self.delete_task_on_finish:
            self.delete_task(task_id)

        return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)

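    # Illustrative sketch (not part of the original module): the task/step API
    # above can drive the agent one step at a time instead of calling `chat`.
    # `agent` is assumed to be an instance of an AgentRunner subclass.
    #
    #     task = agent.create_task("Summarize the latest report.")
    #     step_output = agent.run_step(task.task_id)
    #     while not step_output.is_last:
    #         step_output = agent.run_step(task.task_id)
    #     response = agent.finalize_response(task.task_id)
    #     print(str(response))
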
    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor."""
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_output = self._run_step(
                task.task_id, mode=mode, tool_choice=tool_choice
            )

            if cur_step_output.is_last:
                result_output = cur_step_output
                break

            # ensure tool_choice does not cause endless loops
            tool_choice = "auto"

        return self.finalize_response(task.task_id, result_output)

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor."""
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_output = await self._arun_step(
                task.task_id, mode=mode, tool_choice=tool_choice
            )

            if cur_step_output.is_last:
                result_output = cur_step_output
                break

            # ensure tool_choice does not cause endless loops
            tool_choice = "auto"

        return self.finalize_response(task.task_id, result_output)

    @trace_method("chat")
    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> AgentChatResponse:
        # override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> AgentChatResponse:
        # override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    def stream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> StreamingAgentChatResponse:
        # override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

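    # Illustrative sketch (not part of the original module): `stream_chat` returns
    # a StreamingAgentChatResponse; the `response_gen` / `print_response_stream`
    # members used below are assumed from the legacy chat_engine types and are not
    # defined in this file.
    #
    #     streaming_response = agent.stream_chat("Tell me a short story.")
    #     for token in streaming_response.response_gen:
    #         print(token, end="", flush=True)
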
    @trace_method("chat")
609
    async def astream_chat(
610
        self,
611
        message: str,
612
        chat_history: Optional[List[ChatMessage]] = None,
613
        tool_choice: Optional[Union[str, dict]] = None,
614
    ) -> StreamingAgentChatResponse:
615
        # override tool choice is provided as input.
616
        if tool_choice is None:
617
            tool_choice = self.default_tool_choice
618
        with self.callback_manager.event(
619
            CBEventType.AGENT_STEP,
620
            payload={EventPayload.MESSAGES: [message]},
621
        ) as e:
622
            chat_response = await self._achat(
623
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
624
            )
625
            assert isinstance(chat_response, StreamingAgentChatResponse)
626
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
627
        return chat_response
628

629
    def undo_step(self, task_id: str) -> None:
630
        """Undo previous step."""
631
        raise NotImplementedError("undo_step not implemented")
632
