llama-index

Форк
0
472 строки · 15.3 Кб
1
"""Agent executor."""
2

3
import asyncio
4
from collections import deque
5
from typing import Any, Deque, Dict, List, Optional, Union, cast
6

7
from llama_index.legacy.agent.runner.base import BaseAgentRunner
8
from llama_index.legacy.agent.types import (
9
    BaseAgentWorker,
10
    Task,
11
    TaskStep,
12
    TaskStepOutput,
13
)
14
from llama_index.legacy.bridge.pydantic import BaseModel, Field
15
from llama_index.legacy.callbacks import (
16
    CallbackManager,
17
    CBEventType,
18
    EventPayload,
19
    trace_method,
20
)
21
from llama_index.legacy.chat_engine.types import (
22
    AGENT_CHAT_RESPONSE_TYPE,
23
    AgentChatResponse,
24
    ChatResponseMode,
25
    StreamingAgentChatResponse,
26
)
27
from llama_index.legacy.llms.base import ChatMessage
28
from llama_index.legacy.llms.llm import LLM
29
from llama_index.legacy.memory import BaseMemory, ChatMemoryBuffer
30
from llama_index.legacy.memory.types import BaseMemory
31

32

33
class DAGTaskState(BaseModel):
    """DAG task state.

    Bundles one ``Task`` with the bookkeeping needed to run it as a DAG of
    steps: the root step, the queue of steps that are ready to execute, and
    the outputs of steps that have already completed.
    """

    task: Task = Field(..., description="Task.")
    root_step: TaskStep = Field(..., description="Root step.")
    step_queue: Deque[TaskStep] = Field(
        default_factory=deque, description="Task step queue."
    )
    completed_steps: List[TaskStepOutput] = Field(
        default_factory=list, description="Completed step outputs."
    )

    @property
    def task_id(self) -> str:
        """Task id (convenience passthrough to ``self.task.task_id``)."""
        return self.task.task_id
49

50

51
class DAGAgentState(BaseModel):
    """Agent state: registry mapping task ids to their DAG task state."""

    task_dict: Dict[str, DAGTaskState] = Field(
        default_factory=dict, description="Task dictionary."
    )

    def get_task(self, task_id: str) -> Task:
        """Return the ``Task`` registered under ``task_id``."""
        task_state = self.task_dict[task_id]
        return task_state.task

    def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
        """Return the completed step outputs for ``task_id``."""
        task_state = self.task_dict[task_id]
        return task_state.completed_steps

    def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
        """Return the pending step queue for ``task_id``."""
        task_state = self.task_dict[task_id]
        return task_state.step_queue
69

70

71
class ParallelAgentRunner(BaseAgentRunner):
    """Parallel agent runner.

    Executes steps in queue in parallel. Requires async support.

    All steps currently in a task's queue are assumed to be independent and
    are run concurrently via ``asyncio.gather`` (see ``arun_steps_in_queue``).
    Per-task DAG state (root step, pending queue, completed outputs) is kept
    in ``self.state`` (a ``DAGAgentState``).
    """

    def __init__(
        self,
        agent_worker: BaseAgentWorker,
        chat_history: Optional[List[ChatMessage]] = None,
        state: Optional[DAGAgentState] = None,
        memory: Optional[BaseMemory] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        init_task_state_kwargs: Optional[dict] = None,
        delete_task_on_finish: bool = False,
    ) -> None:
        """Initialize.

        Args:
            agent_worker: worker that executes/streams individual task steps.
            chat_history: seed messages for the default memory buffer
                (only consulted when ``memory`` is not provided).
            state: existing DAG agent state to resume from.
            memory: conversation memory; defaults to a ``ChatMemoryBuffer``
                built from ``chat_history`` and ``llm``.
            llm: forwarded to ``ChatMemoryBuffer.from_defaults`` when building
                the default memory.
            callback_manager: callback manager; defaults to an empty one.
            init_task_state_kwargs: extra state injected into each new ``Task``.
            delete_task_on_finish: if True, the task's state is deleted once
                ``finalize_response`` runs for it.
        """
        self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)
        self.state = state or DAGAgentState()
        self.callback_manager = callback_manager or CallbackManager([])
        self.init_task_state_kwargs = init_task_state_kwargs or {}
        self.agent_worker = agent_worker
        self.delete_task_on_finish = delete_task_on_finish

    @property
    def chat_history(self) -> List[ChatMessage]:
        """Full chat history stored in memory."""
        return self.memory.get_all()

    def reset(self) -> None:
        """Reset conversation memory (task state in ``self.state`` is kept)."""
        self.memory.reset()

    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create a task, register its state, and queue its initial step."""
        task = Task(
            input=input,
            memory=self.memory,
            extra_state=self.init_task_state_kwargs,
            **kwargs,
        )
        # # put input into memory
        # self.memory.put(ChatMessage(content=input, role=MessageRole.USER))

        # add it to state
        # get initial step from task, and put it in the step queue
        initial_step = self.agent_worker.initialize_step(task)
        task_state = DAGTaskState(
            task=task,
            root_step=initial_step,
            step_queue=deque([initial_step]),
        )

        self.state.task_dict[task.task_id] = task_state

        return task

    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.

        Raises:
            KeyError: if ``task_id`` is not registered (``dict.pop`` with no
                default).
        """
        self.state.task_dict.pop(task_id)

    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List all registered tasks."""
        task_states = list(self.state.task_dict.values())
        return [task_state.task for task_state in task_states]

    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get task by id."""
        return self.state.get_task(task_id)

    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Get upcoming steps (a snapshot copy of the pending queue)."""
        return list(self.state.get_step_queue(task_id))

    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Get completed steps."""
        return self.state.get_completed_steps(task_id)

    def run_steps_in_queue(
        self,
        task_id: str,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> List[TaskStepOutput]:
        """Execute steps in queue.

        Run all steps in queue, clearing it out.

        Assume that all steps can be run in parallel.

        NOTE(review): ``asyncio.run`` raises RuntimeError when called from
        within an already-running event loop (e.g. Jupyter); use
        ``arun_steps_in_queue`` directly in async contexts.
        """
        return asyncio.run(self.arun_steps_in_queue(task_id, mode=mode, **kwargs))

    async def arun_steps_in_queue(
        self,
        task_id: str,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> List[TaskStepOutput]:
        """Execute all steps in queue.

        All steps in queue are assumed to be ready.

        """
        # first pop all steps from step_queue
        steps: List[TaskStep] = []
        while len(self.state.get_step_queue(task_id)) > 0:
            steps.append(self.state.get_step_queue(task_id).popleft())

        # take every item in the queue, and run it
        tasks = []
        for step in steps:
            tasks.append(self._arun_step(task_id, step=step, mode=mode, **kwargs))

        # run all drained steps concurrently; results keep queue order
        return await asyncio.gather(*tasks)

    def _run_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute a single step synchronously.

        Pops the next step from the queue when ``step`` is None, delegates to
        the agent worker, then enqueues any ready next steps and records the
        output.

        Raises:
            ValueError: if the step is not ready, or ``mode`` is unknown.
        """
        task = self.state.get_task(task_id)
        task_queue = self.state.get_step_queue(task_id)
        step = step or task_queue.popleft()

        if not step.is_ready:
            raise ValueError(f"Step {step.step_id} is not ready")

        if mode == ChatResponseMode.WAIT:
            cur_step_output: TaskStepOutput = self.agent_worker.run_step(
                step, task, **kwargs
            )
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        # only ready successors go back on the queue; others are dropped here
        for next_step in cur_step_output.next_steps:
            if next_step.is_ready:
                task_queue.append(next_step)

        # add cur_step_output to completed steps
        completed_steps = self.state.get_completed_steps(task_id)
        completed_steps.append(cur_step_output)

        return cur_step_output

    async def _arun_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute a single step asynchronously (async twin of ``_run_step``).

        Raises:
            ValueError: if the step is not ready, or ``mode`` is unknown.
        """
        task = self.state.get_task(task_id)
        task_queue = self.state.get_step_queue(task_id)
        step = step or task_queue.popleft()

        if not step.is_ready:
            raise ValueError(f"Step {step.step_id} is not ready")

        if mode == ChatResponseMode.WAIT:
            cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        # only ready successors go back on the queue; others are dropped here
        for next_step in cur_step_output.next_steps:
            if next_step.is_ready:
                task_queue.append(next_step)

        # add cur_step_output to completed steps
        completed_steps = self.state.get_completed_steps(task_id)
        completed_steps.append(cur_step_output)

        return cur_step_output

    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step.

        NOTE(review): ``input`` is accepted for interface compatibility but
        is not used by this implementation.
        """
        return self._run_step(task_id, step, mode=ChatResponseMode.WAIT, **kwargs)

    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async).

        NOTE(review): ``input`` is accepted for interface compatibility but
        is not used by this implementation.
        """
        return await self._arun_step(
            task_id, step, mode=ChatResponseMode.WAIT, **kwargs
        )

    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream).

        NOTE(review): ``input`` is accepted for interface compatibility but
        is not used by this implementation.
        """
        return self._run_step(task_id, step, mode=ChatResponseMode.STREAM, **kwargs)

    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream).

        NOTE(review): ``input`` is accepted for interface compatibility but
        is not used by this implementation.
        """
        return await self._arun_step(
            task_id, step, mode=ChatResponseMode.STREAM, **kwargs
        )

    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize response.

        Uses the last completed step output when ``step_output`` is None,
        validates it is a terminal agent-chat response, finalizes the task on
        the worker, and optionally deletes the task state.

        Raises:
            ValueError: if the output is not the last step, or its payload is
                not an agent chat response.
        """
        if step_output is None:
            # default to the most recently completed step
            step_output = self.state.get_completed_steps(task_id)[-1]
        if not step_output.is_last:
            raise ValueError(
                "finalize_response can only be called on the last step output"
            )

        if not isinstance(
            step_output.output,
            (AgentChatResponse, StreamingAgentChatResponse),
        ):
            raise ValueError(
                "When `is_last` is True, cur_step_output.output must be "
                f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
            )

        # finalize task
        self.agent_worker.finalize_task(self.state.get_task(task_id))

        if self.delete_task_on_finish:
            self.delete_task(task_id)

        return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)

    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor.

        Creates a task for ``message`` and drains the step queue in rounds
        until a step reports ``is_last``.

        NOTE(review): ``tool_choice`` is accepted but not used here. If the
        queue ever drains without an ``is_last`` output, this loop will spin;
        presumably the worker always terminates — verify.
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_outputs = self.run_steps_in_queue(task.task_id, mode=mode)

            # check if a step output is_last
            is_last = any(
                cur_step_output.is_last for cur_step_output in cur_step_outputs
            )
            if is_last:
                if len(cur_step_outputs) > 1:
                    raise ValueError(
                        "More than one step output returned in final step."
                    )
                cur_step_output = cur_step_outputs[0]
                result_output = cur_step_output
                break

        return self.finalize_response(task.task_id, result_output)

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor (async twin of ``_chat``).

        NOTE(review): ``tool_choice`` is accepted but not used here; same
        potential busy-loop caveat as ``_chat``.
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_outputs = await self.arun_steps_in_queue(task.task_id, mode=mode)

            # check if a step output is_last
            is_last = any(
                cur_step_output.is_last for cur_step_output in cur_step_outputs
            )
            if is_last:
                if len(cur_step_outputs) > 1:
                    raise ValueError(
                        "More than one step output returned in final step."
                    )
                cur_step_output = cur_step_outputs[0]
                result_output = cur_step_output
                break

        return self.finalize_response(task.task_id, result_output)

    @trace_method("chat")
    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat (blocking), wrapped in an AGENT_STEP callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat (async), wrapped in an AGENT_STEP callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    def stream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Chat (streaming), wrapped in an AGENT_STEP callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def astream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Chat (async streaming), wrapped in an AGENT_STEP callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    def undo_step(self, task_id: str) -> None:
        """Undo previous step.

        Raises:
            NotImplementedError: always; undo is not supported by this runner.
        """
        raise NotImplementedError("undo_step not implemented")
473

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.