llama-index

Форк
0
108 строк · 3.7 Кб
1
"""Tool runner component."""
2

3
from typing import Any, Dict, Optional, Sequence, cast
4

5
from llama_index.legacy.bridge.pydantic import Field
6
from llama_index.legacy.callbacks import (
7
    CallbackManager,
8
    CBEventType,
9
    EventPayload,
10
)
11
from llama_index.legacy.callbacks.base import CallbackManager
12
from llama_index.legacy.core.query_pipeline.query_component import (
13
    InputKeys,
14
    OutputKeys,
15
    QueryComponent,
16
    validate_and_convert_stringable,
17
)
18
from llama_index.legacy.tools import AsyncBaseTool, adapt_to_async_tool
19

20

21
class ToolRunnerComponent(QueryComponent):
22
    """Tool runner component that takes in a set of tools."""
23

24
    tool_dict: Dict[str, AsyncBaseTool] = Field(
25
        ..., description="Dictionary of tool names to tools."
26
    )
27
    callback_manager: CallbackManager = Field(
28
        default_factory=lambda: CallbackManager([]), exclude=True
29
    )
30

31
    def __init__(
32
        self,
33
        tools: Sequence[AsyncBaseTool],
34
        callback_manager: Optional[CallbackManager] = None,
35
        **kwargs: Any,
36
    ) -> None:
37
        """Initialize."""
38
        # determine parameters
39
        tool_dict = {tool.metadata.name: adapt_to_async_tool(tool) for tool in tools}
40
        callback_manager = callback_manager or CallbackManager([])
41
        super().__init__(
42
            tool_dict=tool_dict, callback_manager=callback_manager, **kwargs
43
        )
44

45
    class Config:
46
        arbitrary_types_allowed = True
47

48
    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
49
        """Set callback manager."""
50
        self.callback_manager = callback_manager
51

52
    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
53
        """Validate component inputs during run_component."""
54
        if "tool_name" not in input:
55
            raise ValueError("tool_name must be provided in input")
56

57
        input["tool_name"] = validate_and_convert_stringable(input["tool_name"])
58

59
        if "tool_input" not in input:
60
            raise ValueError("tool_input must be provided in input")
61
        # make sure tool_input is a dictionary
62
        if not isinstance(input["tool_input"], dict):
63
            raise ValueError("tool_input must be a dictionary")
64

65
        return input
66

67
    def _run_component(self, **kwargs: Any) -> Dict:
68
        """Run component."""
69
        tool_name = kwargs["tool_name"]
70
        tool_input = kwargs["tool_input"]
71
        tool = cast(AsyncBaseTool, self.tool_dict[tool_name])
72
        with self.callback_manager.event(
73
            CBEventType.FUNCTION_CALL,
74
            payload={
75
                EventPayload.FUNCTION_CALL: tool_input,
76
                EventPayload.TOOL: tool.metadata,
77
            },
78
        ) as event:
79
            tool_output = tool(**tool_input)
80
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
81

82
        return {"output": tool_output}
83

84
    async def _arun_component(self, **kwargs: Any) -> Any:
85
        """Run component (async)."""
86
        tool_name = kwargs["tool_name"]
87
        tool_input = kwargs["tool_input"]
88
        tool = cast(AsyncBaseTool, self.tool_dict[tool_name])
89
        with self.callback_manager.event(
90
            CBEventType.FUNCTION_CALL,
91
            payload={
92
                EventPayload.FUNCTION_CALL: tool_input,
93
                EventPayload.TOOL: tool.metadata,
94
            },
95
        ) as event:
96
            tool_output = await tool.acall(**tool_input)
97
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
98
        return {"output": tool_output}
99

100
    @property
101
    def input_keys(self) -> InputKeys:
102
        """Input keys."""
103
        return InputKeys.from_keys({"tool_name", "tool_input"})
104

105
    @property
106
    def output_keys(self) -> OutputKeys:
107
        """Output keys."""
108
        return OutputKeys.from_keys({"output"})
109

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.