llama-index

Форк
0
644 строки · 22.5 Кб
1
"""OpenAI agent worker."""
2

3
import asyncio
import json
import logging
import uuid
from threading import Thread
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast, get_args

from llama_index.legacy.agent.openai.utils import resolve_tool_choice
from llama_index.legacy.agent.types import (
    BaseAgentWorker,
    Task,
    TaskStep,
    TaskStepOutput,
)
from llama_index.legacy.agent.utils import add_user_step_to_memory
from llama_index.legacy.callbacks import (
    CallbackManager,
    CBEventType,
    EventPayload,
    trace_method,
)
from llama_index.legacy.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    AgentChatResponse,
    ChatResponseMode,
    StreamingAgentChatResponse,
)
from llama_index.legacy.core.llms.types import MessageRole
from llama_index.legacy.llms.base import ChatMessage, ChatResponse
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.llms.openai import OpenAI
from llama_index.legacy.llms.openai_utils import OpenAIToolCall
from llama_index.legacy.memory import BaseMemory, ChatMemoryBuffer
from llama_index.legacy.memory.types import BaseMemory
from llama_index.legacy.objects.base import ObjectRetriever
from llama_index.legacy.tools import BaseTool, ToolOutput, adapt_to_async_tool
39

40
logger = logging.getLogger(__name__)
# NOTE(review): setting a level on a library module's logger takes precedence
# over the application's logging configuration for this logger; kept as-is for
# backward compatibility.
logger.setLevel(logging.WARNING)

# Default upper bound passed to OpenAIAgentWorker as ``max_function_calls``.
DEFAULT_MAX_FUNCTION_CALLS = 5
# Model used by ``OpenAIAgentWorker.from_tools`` when no LLM is supplied.
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
45

46

47
def get_function_by_name(tools: List[BaseTool], name: str) -> BaseTool:
    """Return the tool whose ``metadata.name`` equals ``name``.

    If several tools share a name, the last one in ``tools`` wins (the lookup
    table is built with a dict comprehension).

    Args:
        tools: Candidate tools.
        name: Tool name requested by the model.

    Raises:
        ValueError: If no tool has that name.
    """
    name_to_tool = {tool.metadata.name: tool for tool in tools}
    if name not in name_to_tool:
        raise ValueError(f"Tool with name {name} not found")
    return name_to_tool[name]
53

54

55
def call_tool_with_error_handling(
    tool: BaseTool,
    input_dict: Dict,
    error_message: Optional[str] = None,
    raise_error: bool = False,
) -> ToolOutput:
    """Call ``tool(**input_dict)``, turning failures into an error output.

    Args:
        tool: Tool to invoke.
        input_dict: Keyword arguments forwarded to the tool.
        error_message: Content used for the error output. Defaults to
            ``"Error: <exception text>"``.
        raise_error: Re-raise the tool's exception instead of converting it.

    Returns:
        The tool's own result on success. On failure (when ``raise_error`` is
        false), a ``ToolOutput`` whose ``raw_output`` is the exception and
        whose ``raw_input`` is ``{"kwargs": input_dict}``.
    """
    try:
        return tool(**input_dict)
    except Exception as e:
        if raise_error:
            raise
        # Report the failure as data so the agent can show it to the model.
        error_message = error_message or f"Error: {e!s}"
        return ToolOutput(
            content=error_message,
            tool_name=tool.metadata.name,
            raw_input={"kwargs": input_dict},
            raw_output=e,
        )
78

79

80
def call_function(
    tools: List[BaseTool],
    tool_call: OpenAIToolCall,
    verbose: bool = False,
) -> Tuple[ChatMessage, ToolOutput]:
    """Execute an OpenAI tool call synchronously.

    Looks up the tool named in ``tool_call``, decodes its JSON arguments and
    calls it through ``call_tool_with_error_handling``. Exceptions raised by
    the tool itself therefore come back as an error ``ToolOutput`` instead of
    propagating.

    Args:
        tools: Candidate tools, matched by ``tool.metadata.name``.
        tool_call: Tool call from the model. Its ``id`` and ``function``
            (name and JSON-encoded arguments) must be set.
        verbose: Print the call and its output to stdout.

    Returns:
        A ``TOOL``-role ``ChatMessage`` holding the stringified output, with
        ``name`` and ``tool_call_id`` in ``additional_kwargs``, plus the raw
        ``ToolOutput``.

    Raises:
        ValueError: If no tool matches the requested name, or the arguments
            are not valid JSON (``json.JSONDecodeError``).
    """
    # validations to get passed mypy
    assert tool_call.id is not None
    assert tool_call.function is not None
    assert tool_call.function.name is not None
    assert tool_call.function.arguments is not None

    id_ = tool_call.id
    name = tool_call.function.name
    arguments_str = tool_call.function.arguments
    if verbose:
        print("=== Calling Function ===")
        print(f"Calling function: {name} with args: {arguments_str}")
    tool = get_function_by_name(tools, name)
    argument_dict = json.loads(arguments_str)

    # Call tool
    # Use default error message
    output = call_tool_with_error_handling(tool, argument_dict, error_message=None)
    if verbose:
        print(f"Got output: {output!s}")
        print("========================\n")
    return (
        ChatMessage(
            content=str(output),
            role=MessageRole.TOOL,
            additional_kwargs={
                "name": name,
                "tool_call_id": id_,
            },
        ),
        output,
    )
119

120

121
async def acall_function(
    tools: List[BaseTool], tool_call: OpenAIToolCall, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
    """Execute an OpenAI tool call asynchronously.

    Async counterpart of ``call_function``. The tool is adapted with
    ``adapt_to_async_tool`` and awaited via ``acall``. As on the sync path,
    an exception raised by the tool is returned as an error ``ToolOutput``
    (content ``"Error: <exception text>"``) instead of propagating.

    Args:
        tools: Candidate tools, matched by ``tool.metadata.name``.
        tool_call: Tool call from the model. Its ``id`` and ``function``
            (name and JSON-encoded arguments) must be set.
        verbose: Print the call and its output to stdout.

    Returns:
        A ``TOOL``-role ``ChatMessage`` holding the stringified output, with
        ``name`` and ``tool_call_id`` in ``additional_kwargs``, plus the raw
        ``ToolOutput``.

    Raises:
        ValueError: If no tool matches the requested name, or the arguments
            are not valid JSON (``json.JSONDecodeError``).
    """
    # validations to get passed mypy
    assert tool_call.id is not None
    assert tool_call.function is not None
    assert tool_call.function.name is not None
    assert tool_call.function.arguments is not None

    id_ = tool_call.id
    name = tool_call.function.name
    arguments_str = tool_call.function.arguments
    if verbose:
        print("=== Calling Function ===")
        print(f"Calling function: {name} with args: {arguments_str}")
    tool = get_function_by_name(tools, name)
    async_tool = adapt_to_async_tool(tool)
    argument_dict = json.loads(arguments_str)
    try:
        output = await async_tool.acall(**argument_dict)
    except Exception as e:
        # Mirror call_tool_with_error_handling (used by the sync path): report
        # the tool failure to the model instead of aborting the agent step.
        output = ToolOutput(
            content=f"Error: {e!s}",
            tool_name=tool.metadata.name,
            raw_input={"kwargs": argument_dict},
            raw_output=e,
        )
    if verbose:
        print(f"Got output: {output!s}")
        print("========================\n")
    return (
        ChatMessage(
            content=str(output),
            role=MessageRole.TOOL,
            additional_kwargs={
                "name": name,
                "tool_call_id": id_,
            },
        ),
        output,
    )
156

157

158
class OpenAIAgentWorker(BaseAgentWorker):
    """OpenAI Agent agent worker.

    Drives a tool-calling loop against an OpenAI chat model. Each step:

    1. sends the prefix messages, the task memory and the step-local scratch
       memory (``task.extra_state["new_memory"]``) to the LLM;
    2. executes any tool calls in the reply;
    3. queues another step while tool calls keep arriving, bounded by
       ``max_function_calls``.
    """

    def __init__(
        self,
        tools: List[BaseTool],
        llm: OpenAI,
        prefix_messages: List[ChatMessage],
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ):
        """Init params.

        Args:
            tools: Fixed set of tools offered on every step. Mutually
                exclusive with ``tool_retriever``.
            llm: OpenAI LLM used for the chat calls.
            prefix_messages: Messages prepended to every LLM call.
            verbose: Print tool calls and their outputs.
            max_function_calls: Limit checked by ``_should_continue``.
            callback_manager: Defaults to the LLM's callback manager.
            tool_retriever: Retrieves tools from the task input. Mutually
                exclusive with ``tools``.

        Raises:
            ValueError: If both ``tools`` and ``tool_retriever`` are given.
        """
        self._llm = llm
        self._verbose = verbose
        self._max_function_calls = max_function_calls
        self.prefix_messages = prefix_messages
        self.callback_manager = callback_manager or self._llm.callback_manager
        # Strong references to fire-and-forget asyncio tasks: the event loop
        # keeps only weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._background_tasks: Set["asyncio.Task[Any]"] = set()

        if tools and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")
        elif tools:
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            # no tools
            self._get_tools = lambda _: []

    @classmethod
    def from_tools(
        cls,
        tools: Optional[List[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        llm: Optional[LLM] = None,
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        prefix_messages: Optional[List[ChatMessage]] = None,
        **kwargs: Any,
    ) -> "OpenAIAgentWorker":
        """Create an OpenAIAgentWorker from a list of tools.

        Similar to `from_defaults` in other classes, this method will
        infer defaults for a variety of parameters, including the LLM,
        if they are not specified.

        When no LLM is given, ``OpenAI(model=DEFAULT_MODEL_NAME)`` is used.
        A given ``callback_manager`` is also assigned to the LLM.
        ``system_prompt`` becomes a single system prefix message. Extra
        ``**kwargs`` are accepted and ignored.

        Raises:
            ValueError: If ``llm`` is not an ``OpenAI`` instance, if the model
                does not support function calling, or if both
                ``system_prompt`` and ``prefix_messages`` are given.
        """
        tools = tools or []

        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if not isinstance(llm, OpenAI):
            raise ValueError("llm must be a OpenAI instance")

        if callback_manager is not None:
            llm.callback_manager = callback_manager

        if not llm.metadata.is_function_calling_model:
            raise ValueError(
                f"Model name {llm.model} does not support function calling API. "
            )

        if system_prompt is not None:
            if prefix_messages is not None:
                raise ValueError(
                    "Cannot specify both system_prompt and prefix_messages"
                )
            prefix_messages = [ChatMessage(content=system_prompt, role="system")]

        prefix_messages = prefix_messages or []

        return cls(
            tools=tools,
            tool_retriever=tool_retriever,
            llm=llm,
            prefix_messages=prefix_messages,
            verbose=verbose,
            max_function_calls=max_function_calls,
            callback_manager=callback_manager,
        )

    def get_all_messages(self, task: Task) -> List[ChatMessage]:
        """Return prefix messages + task memory + this task's scratch memory."""
        return (
            self.prefix_messages
            + task.memory.get()
            + task.extra_state["new_memory"].get_all()
        )

    def get_latest_tool_calls(self, task: Task) -> Optional[List[OpenAIToolCall]]:
        """Return the ``tool_calls`` of the newest scratch message, if any."""
        chat_history: List[ChatMessage] = task.extra_state["new_memory"].get_all()
        if not chat_history:
            return None
        return chat_history[-1].additional_kwargs.get("tool_calls", None)

    def _get_llm_chat_kwargs(
        self,
        task: Task,
        openai_tools: List[dict],
        tool_choice: Union[str, dict] = "auto",
    ) -> Dict[str, Any]:
        """Build the kwargs for the LLM chat call.

        ``tools``/``tool_choice`` are only included when at least one tool is
        available.
        """
        llm_chat_kwargs: dict = {"messages": self.get_all_messages(task)}
        if openai_tools:
            llm_chat_kwargs.update(
                tools=openai_tools, tool_choice=resolve_tool_choice(tool_choice)
            )
        return llm_chat_kwargs

    def _process_message(
        self, task: Task, chat_response: ChatResponse
    ) -> AgentChatResponse:
        """Store the assistant reply in scratch memory and wrap it."""
        ai_message = chat_response.message
        task.extra_state["new_memory"].put(ai_message)
        return AgentChatResponse(
            response=str(ai_message.content), sources=task.extra_state["sources"]
        )

    def _get_stream_ai_response(
        self, task: Task, **llm_chat_kwargs: Any
    ) -> StreamingAgentChatResponse:
        """Start a streaming LLM call whose reply is written to scratch memory.

        If the reply turns out to be a function call, this blocks until the
        writer thread has finished, so the tool calls are in memory on return.
        """
        chat_stream_response = StreamingAgentChatResponse(
            chat_stream=self._llm.stream_chat(**llm_chat_kwargs),
            sources=task.extra_state["sources"],
        )
        # Write the response to history on a separate thread so the caller can
        # consume the stream meanwhile.
        thread = Thread(
            target=chat_stream_response.write_response_to_history,
            args=(task.extra_state["new_memory"],),
        )
        thread.start()
        # Wait for the writer to signal before reading ``_is_function``.
        chat_stream_response._is_function_not_none_thread_event.wait()
        # If it is executing an openAI function, wait for the thread to finish
        if chat_stream_response._is_function:
            thread.join()

        # if it's false, return the answer (to stream)
        return chat_stream_response

    async def _get_async_stream_ai_response(
        self, task: Task, **llm_chat_kwargs: Any
    ) -> StreamingAgentChatResponse:
        """Async counterpart of ``_get_stream_ai_response``."""
        chat_stream_response = StreamingAgentChatResponse(
            achat_stream=await self._llm.astream_chat(**llm_chat_kwargs),
            sources=task.extra_state["sources"],
        )
        # create task to write chat response to history
        write_task = asyncio.create_task(
            chat_stream_response.awrite_response_to_history(
                task.extra_state["new_memory"]
            )
        )
        # Hold a reference until the task is done so it cannot be collected
        # mid-execution.
        self._background_tasks.add(write_task)
        write_task.add_done_callback(self._background_tasks.discard)
        # wait until openAI functions stop executing
        await chat_stream_response._is_function_false_event.wait()
        # return response stream
        return chat_stream_response

    def _get_agent_response(
        self, task: Task, mode: ChatResponseMode, **llm_chat_kwargs: Any
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Call the LLM in WAIT or STREAM mode.

        Raises:
            NotImplementedError: For any other mode.
        """
        if mode == ChatResponseMode.WAIT:
            chat_response: ChatResponse = self._llm.chat(**llm_chat_kwargs)
            return self._process_message(task, chat_response)
        elif mode == ChatResponseMode.STREAM:
            return self._get_stream_ai_response(task, **llm_chat_kwargs)
        else:
            raise NotImplementedError

    async def _get_async_agent_response(
        self, task: Task, mode: ChatResponseMode, **llm_chat_kwargs: Any
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Async counterpart of ``_get_agent_response``."""
        if mode == ChatResponseMode.WAIT:
            chat_response: ChatResponse = await self._llm.achat(**llm_chat_kwargs)
            return self._process_message(task, chat_response)
        elif mode == ChatResponseMode.STREAM:
            return await self._get_async_stream_ai_response(task, **llm_chat_kwargs)
        else:
            raise NotImplementedError

    @staticmethod
    def _get_function_call_payload(
        tools: List[BaseTool], tool_call: OpenAIToolCall
    ) -> Dict[str, Any]:
        """Build the FUNCTION_CALL callback-event payload for ``tool_call``."""
        function_call = tool_call.function
        # validations to get passed mypy
        assert function_call is not None
        assert function_call.name is not None
        assert function_call.arguments is not None
        return {
            EventPayload.FUNCTION_CALL: function_call.arguments,
            EventPayload.TOOL: get_function_by_name(
                tools, function_call.name
            ).metadata,
        }

    def _call_function(
        self,
        tools: List[BaseTool],
        tool_call: OpenAIToolCall,
        memory: BaseMemory,
        sources: List[ToolOutput],
    ) -> None:
        """Run one tool call inside a callback event.

        Appends the tool output to ``sources`` and the tool message to
        ``memory``.
        """
        payload = self._get_function_call_payload(tools, tool_call)
        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL, payload=payload
        ) as event:
            function_message, tool_output = call_function(
                tools, tool_call, verbose=self._verbose
            )
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
        sources.append(tool_output)
        memory.put(function_message)

    async def _acall_function(
        self,
        tools: List[BaseTool],
        tool_call: OpenAIToolCall,
        memory: BaseMemory,
        sources: List[ToolOutput],
    ) -> None:
        """Async counterpart of ``_call_function``."""
        payload = self._get_function_call_payload(tools, tool_call)
        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL, payload=payload
        ) as event:
            function_message, tool_output = await acall_function(
                tools, tool_call, verbose=self._verbose
            )
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
        sources.append(tool_output)
        memory.put(function_message)

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Initialize step from task.

        Resets ``task.extra_state`` with:

        - an empty ``sources`` list;
        - ``n_function_calls`` set to 0;
        - a fresh scratch ``new_memory`` buffer.

        Returns the first step, whose input is the task input.
        """
        task.extra_state.update(
            {
                "sources": [],
                "n_function_calls": 0,
                # temporary memory for new messages
                "new_memory": ChatMemoryBuffer.from_defaults(),
            }
        )
        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
        )

    def _should_continue(
        self, tool_calls: Optional[List[OpenAIToolCall]], n_function_calls: int
    ) -> bool:
        """Return True when there are tool calls and the call budget allows it."""
        # NOTE(review): the check is strict (``>``), so tool calls still run when
        # n_function_calls equals max_function_calls; kept for compatibility.
        if n_function_calls > self._max_function_calls:
            return False
        return bool(tool_calls)

    def get_tools(self, input: str) -> List[BaseTool]:
        """Return the tools offered for ``input`` (fixed, retrieved, or none)."""
        return self._get_tools(input)

    def _prepare_step(
        self, step: TaskStep, task: Task, tool_choice: Union[str, dict]
    ) -> Tuple[List[BaseTool], Dict[str, Any]]:
        """Record the step input and build the tools and LLM kwargs for the step."""
        if step.input is not None:
            add_user_step_to_memory(
                step, task.extra_state["new_memory"], verbose=self._verbose
            )
        # TODO: see if we want to do step-based inputs
        tools = self.get_tools(task.input)
        openai_tools = [tool.metadata.to_openai_tool() for tool in tools]
        llm_chat_kwargs = self._get_llm_chat_kwargs(task, openai_tools, tool_choice)
        return tools, llm_chat_kwargs

    @staticmethod
    def _validate_tool_call(tool_call: Any) -> None:
        """Raise ValueError unless ``tool_call`` is an OpenAI function tool call."""
        if not isinstance(tool_call, get_args(OpenAIToolCall)):
            raise ValueError("Invalid tool_call object")
        if tool_call.type != "function":
            raise ValueError("Invalid tool type. Unsupported by OpenAI")

    @staticmethod
    def _get_step_output(
        step: TaskStep,
        agent_chat_response: AGENT_CHAT_RESPONSE_TYPE,
        is_done: bool,
    ) -> TaskStepOutput:
        """Wrap the step result, queuing a follow-up step unless ``is_done``."""
        if is_done:
            new_steps = []
        else:
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]
        return TaskStepOutput(
            output=agent_chat_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    def _run_step(
        self,
        step: TaskStep,
        task: Task,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        tool_choice: Union[str, dict] = "auto",
    ) -> TaskStepOutput:
        """Run one step: call the LLM, then execute its tool calls if allowed.

        ``tool_choice`` only shapes this step's LLM call. A forced choice is
        not carried over; callers pass ``tool_choice`` again on each
        ``run_step``.
        """
        tools, llm_chat_kwargs = self._prepare_step(step, task, tool_choice)
        agent_chat_response = self._get_agent_response(
            task, mode=mode, **llm_chat_kwargs
        )

        latest_tool_calls = self.get_latest_tool_calls(task) or []
        is_done = not self._should_continue(
            latest_tool_calls, task.extra_state["n_function_calls"]
        )
        if not is_done:
            for tool_call in latest_tool_calls:
                self._validate_tool_call(tool_call)
                # TODO: maybe execute this with multi-threading
                self._call_function(
                    tools,
                    tool_call,
                    task.extra_state["new_memory"],
                    task.extra_state["sources"],
                )
                task.extra_state["n_function_calls"] += 1

        return self._get_step_output(step, agent_chat_response, is_done)

    async def _arun_step(
        self,
        step: TaskStep,
        task: Task,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        tool_choice: Union[str, dict] = "auto",
    ) -> TaskStepOutput:
        """Async counterpart of ``_run_step``."""
        tools, llm_chat_kwargs = self._prepare_step(step, task, tool_choice)
        agent_chat_response = await self._get_async_agent_response(
            task, mode=mode, **llm_chat_kwargs
        )

        latest_tool_calls = self.get_latest_tool_calls(task) or []
        is_done = not self._should_continue(
            latest_tool_calls, task.extra_state["n_function_calls"]
        )
        if not is_done:
            for tool_call in latest_tool_calls:
                self._validate_tool_call(tool_call)
                # TODO: maybe execute this with multi-threading
                await self._acall_function(
                    tools,
                    tool_call,
                    task.extra_state["new_memory"],
                    task.extra_state["sources"],
                )
                task.extra_state["n_function_calls"] += 1

        return self._get_step_output(step, agent_chat_response, is_done)

    @trace_method("run_step")
    def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (``kwargs["tool_choice"]`` defaults to ``"auto"``)."""
        tool_choice = kwargs.get("tool_choice", "auto")
        return self._run_step(
            step, task, mode=ChatResponseMode.WAIT, tool_choice=tool_choice
        )

    @trace_method("run_step")
    async def arun_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async)."""
        tool_choice = kwargs.get("tool_choice", "auto")
        return await self._arun_step(
            step, task, mode=ChatResponseMode.WAIT, tool_choice=tool_choice
        )

    @trace_method("run_step")
    def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (stream)."""
        # TODO: figure out if we need a different type for TaskStepOutput
        tool_choice = kwargs.get("tool_choice", "auto")
        return self._run_step(
            step, task, mode=ChatResponseMode.STREAM, tool_choice=tool_choice
        )

    @trace_method("run_step")
    async def astream_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        tool_choice = kwargs.get("tool_choice", "auto")
        return await self._arun_step(
            step, task, mode=ChatResponseMode.STREAM, tool_choice=tool_choice
        )

    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Move the scratch messages into task memory, then clear the scratch."""
        new_memory = task.extra_state["new_memory"]
        task.memory.set(task.memory.get() + new_memory.get_all())
        new_memory.reset()

    def undo_step(self, task: Task, **kwargs: Any) -> Optional[TaskStep]:
        """Undo step from task.

        If this cannot be implemented, return None.

        Raises:
            NotImplementedError: Always; undo is not supported yet.
        """
        # TODO: an implementation would pop the last completed step, requeue it
        # at the front of the step queue, roll back ``n_function_calls``, and
        # pop the step's messages from memory. The last part needs memory-pop
        # support that does not exist yet.
        raise NotImplementedError("Undo is not yet implemented")

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        # TODO: make this abstractmethod (right now will break some agent impls)
        self.callback_manager = callback_manager
645

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.