# llama-index (fork)
"""Agent components.

Query components for agent pipelines: a component that turns agent inputs
(``task``/``state``) into arbitrary outputs, a function-backed component for
manipulating agent state, and a base class for custom agent components.
"""

from inspect import signature
from typing import Any, Callable, Dict, Optional, Set, Tuple, cast

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.core.query_pipeline.query_component import (
    InputKeys,
    OutputKeys,
    QueryComponent,
)


def get_parameters(fn: Callable) -> Tuple[Set[str], Set[str]]:
    """Split a callable's parameters into required and optional names.

    A parameter is "required" when it has no default value and "optional"
    when it has one. Variadic parameters (``*args`` / ``**kwargs``) are
    skipped: they cannot be supplied as a single named keyword argument, so
    reporting them would make components demand input keys literally named
    e.g. ``"args"`` or ``"kwargs"``.

    Args:
        fn: The callable to inspect.

    Returns:
        Tuple[Set[str], Set[str]]: required and optional parameter names.

    Raises:
        ValueError: propagated from ``inspect.signature`` when no signature
            can be determined for ``fn``.
        TypeError: propagated from ``inspect.signature`` when ``fn`` is not
            a supported callable.

    """
    required_params: Set[str] = set()
    optional_params: Set[str] = set()
    for name, param in signature(fn).parameters.items():
        # *args / **kwargs carry no default but are never passed by name.
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        if param.default is param.empty:
            required_params.add(name)
        else:
            optional_params.add(name)
    return required_params, optional_params


def default_agent_input_fn(task: Any, state: dict) -> dict:
    """Build the default agent input payload from a task.

    Args:
        task: The agent task. It is only narrowed to ``Task`` via ``cast``
            (no runtime type check); its ``input`` attribute is read.
        state: The agent state. Accepted for interface compatibility; it is
            not read here.

    Returns:
        dict: A mapping with the single key ``"input"`` holding ``task.input``.
    """
    from llama_index.legacy.agent.types import Task

    typed_task = cast(Task, task)
    payload = {"input": typed_task.input}
    return payload


class AgentInputComponent(QueryComponent):
    """Component that turns agent inputs into user-defined outputs.

    Wraps a synchronous function ``fn`` and, optionally, a coroutine
    function ``async_fn``. The ``task`` and ``state`` input keys are always
    required; additional required/optional input keys come from ``fn``'s
    signature unless passed explicitly. The wrapped function must return a
    ``dict``, which becomes the component output as-is.
    """

    fn: Callable = Field(..., description="Function to run.")
    async_fn: Optional[Callable] = Field(
        None, description="Async function to run. If not provided, will run `fn`."
    )

    # Extra input key names (inferred from ``fn`` unless given explicitly).
    _req_params: Set[str] = PrivateAttr()
    _opt_params: Set[str] = PrivateAttr()

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        fn: Callable,
        async_fn: Optional[Callable] = None,
        req_params: Optional[Set[str]] = None,
        opt_params: Optional[Set[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the component.

        Args:
            fn: Function invoked with the component inputs as keyword
                arguments.
            async_fn: Optional coroutine function used for async runs.
            req_params: Required input key names; inferred from ``fn``'s
                signature when ``None``.
            opt_params: Optional input key names; inferred from ``fn``'s
                signature when ``None``.
            **kwargs: Forwarded to the base component constructor.
        """
        inferred_req, inferred_opt = get_parameters(fn)
        self._req_params = inferred_req if req_params is None else req_params
        self._opt_params = inferred_opt if opt_params is None else opt_params
        super().__init__(fn=fn, async_fn=async_fn, **kwargs)

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager (currently a no-op)."""
        # TODO: implement

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Check the agent inputs and return them unchanged.

        Raises:
            ValueError: if ``input`` lacks ``"task"`` or its value is not a
                ``Task``, or lacks ``"state"`` or its value is not a ``dict``
                (checked in that order).
        """
        from llama_index.legacy.agent.types import Task

        # (key, expected type, message when missing, message when wrong type)
        checks = (
            (
                "task",
                Task,
                "Input must have key 'task'",
                "Input must have key 'task' of type Task",
            ),
            (
                "state",
                dict,
                "Input must have key 'state'",
                "Input must have key 'state' of type dict",
            ),
        )
        for key, expected_type, missing_msg, type_msg in checks:
            if key not in input:
                raise ValueError(missing_msg)
            if not isinstance(input[key], expected_type):
                raise ValueError(type_msg)
        return input

    def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``output`` unchanged; outputs are free-form here."""
        # NOTE: deliberately overridden to skip output validation.
        return output

    def _validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Return the outputs unchanged (no validation)."""
        return input

    def _run_component(self, **kwargs: Any) -> Dict:
        """Call ``fn`` with the inputs and return its dict result.

        Raises:
            ValueError: if ``fn`` does not return a ``dict``.
        """
        result = self.fn(**kwargs)
        if isinstance(result, dict):
            return result
        raise ValueError("Output must be a dictionary")

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Async variant: await ``async_fn`` if set, otherwise run ``fn``.

        Raises:
            ValueError: if the called function does not return a ``dict``.
        """
        if self.async_fn is not None:
            result = await self.async_fn(**kwargs)
            if not isinstance(result, dict):
                raise ValueError("Output must be a dictionary")
            return result
        return self._run_component(**kwargs)

    @property
    def input_keys(self) -> InputKeys:
        """Input keys: "task", "state" and the extra required/optional params."""
        required = {"task", "state"}.union(self._req_params)
        return InputKeys.from_keys(
            required_keys=required,
            optional_keys=self._opt_params,
        )

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys: none declared, since output validation is skipped."""
        return OutputKeys.from_keys(set())


class BaseAgentComponent(QueryComponent):
    """Common base for agent query components.

    Adds no attributes or behavior of its own; it serves as a shared parent
    type so agent components can be identified in type checks (e.g.
    ``isinstance``).
    """


class AgentFnComponent(BaseAgentComponent):
    """Function component for agents.

    Designed to let users easily modify state. The wrapped function ``fn``
    must declare ``task`` and ``state`` as required parameters; any other
    parameters become extra input keys. Whatever ``fn`` returns (of any type)
    is emitted under the single output key ``"output"``. An optional
    coroutine function ``async_fn`` is used for async runs.

    """

    fn: Callable = Field(..., description="Function to run.")
    async_fn: Optional[Callable] = Field(
        None, description="Async function to run. If not provided, will run `fn`."
    )

    # Input keys required/accepted in addition to "task" and "state".
    _req_params: Set[str] = PrivateAttr()
    _opt_params: Set[str] = PrivateAttr()

    def __init__(
        self,
        fn: Callable,
        async_fn: Optional[Callable] = None,
        req_params: Optional[Set[str]] = None,
        opt_params: Optional[Set[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize.

        Args:
            fn: Function invoked with the component inputs as keyword
                arguments; must take ``task`` and ``state`` without defaults.
            async_fn: Optional coroutine function used for async runs.
            req_params: Extra required input keys; inferred from ``fn``'s
                signature (minus ``task``/``state``) when ``None``.
            opt_params: Extra optional input keys; inferred from ``fn``'s
                signature (minus ``task``/``state``) when ``None``.
            **kwargs: Forwarded to the base component constructor.

        Raises:
            ValueError: if ``fn`` does not declare both ``task`` and ``state``
                as required (no-default) parameters.

        """
        default_req_params, default_opt_params = get_parameters(fn)
        # ``task`` and ``state`` must be required parameters of ``fn``. They
        # are always included by ``input_keys``, so strip them from the
        # inferred sets to keep only the extra parameters.
        if "task" not in default_req_params or "state" not in default_req_params:
            raise ValueError(
                "AgentFnComponent must have 'task' and 'state' as required parameters"
            )

        default_req_params = default_req_params - {"task", "state"}
        default_opt_params = default_opt_params - {"task", "state"}

        if req_params is None:
            req_params = default_req_params
        if opt_params is None:
            opt_params = default_opt_params

        self._req_params = req_params
        self._opt_params = opt_params
        super().__init__(fn=fn, async_fn=async_fn, **kwargs)

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager (currently a no-op)."""
        # TODO: implement

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component.

        Returns ``input`` unchanged when valid.

        Raises:
            ValueError: if ``"task"`` is missing or not a ``Task``, or
                ``"state"`` is missing or not a ``dict``.

        """
        from llama_index.legacy.agent.types import Task

        if "task" not in input:
            raise ValueError("Input must have key 'task'")
        if not isinstance(input["task"], Task):
            raise ValueError("Input must have key 'task' of type Task")

        if "state" not in input:
            raise ValueError("Input must have key 'state'")
        if not isinstance(input["state"], dict):
            raise ValueError("Input must have key 'state' of type dict")

        return input

    def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component outputs."""
        # NOTE: we override this to do nothing
        return output

    def _validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Return the outputs unchanged (no validation)."""
        return input

    def _run_component(self, **kwargs: Any) -> Dict:
        """Run ``fn`` and wrap its return value as ``{"output": result}``.

        The result may be of any type; it is not required to be a dict.
        """
        output = self.fn(**kwargs)
        return {"output": output}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async).

        Awaits ``async_fn`` when set, otherwise falls back to the synchronous
        ``fn``; either way the result is wrapped as ``{"output": result}``.
        """
        if self.async_fn is None:
            return self._run_component(**kwargs)
        output = await self.async_fn(**kwargs)
        return {"output": output}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys: "task", "state" plus the extra required/optional params."""
        return InputKeys.from_keys(
            required_keys={"task", "state", *self._req_params},
            optional_keys=self._opt_params,
        )

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys: the single key ``"output"``."""
        # output value can be anything; validation is overridden above
        return OutputKeys.from_keys({"output"})


class CustomAgentComponent(BaseAgentComponent):
    """Base class for user-defined agent components.

    Designed to let users easily modify state. Subclasses must override the
    ``_input_keys`` and ``_output_keys`` properties (the defaults raise
    ``NotImplementedError``). They may also override ``_optional_input_keys``
    (default: empty), ``_validate_component_inputs`` (default: pass-through)
    and ``_arun_component`` (default: unsupported). The ``task`` and
    ``state`` input keys are always required in addition to ``_input_keys``.
    """

    callback_manager: CallbackManager = Field(
        default_factory=CallbackManager, description="Callback manager"
    )

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Store ``callback_manager`` and hand it to every sub-component."""
        self.callback_manager = callback_manager
        # TODO: refactor to put this on base class
        for child in self.sub_query_components:
            child.set_callback_manager(callback_manager)

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate inputs during run_component.

        Performs no checks by default and returns ``input`` unchanged;
        subclasses may override this to add validation.
        """
        return input

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Async run; unsupported unless a subclass overrides it."""
        raise NotImplementedError("This component does not support async run.")

    @property
    def _input_keys(self) -> Set[str]:
        """Required input keys besides "task"/"state"; must be overridden."""
        raise NotImplementedError("Not implemented yet. Please override this method.")

    @property
    def _optional_input_keys(self) -> Set[str]:
        """Optional input keys; none by default."""
        return set()

    @property
    def _output_keys(self) -> Set[str]:
        """Output key names; must be overridden."""
        raise NotImplementedError("Not implemented yet. Please override this method.")

    @property
    def input_keys(self) -> InputKeys:
        """Input keys built from ``_input_keys`` plus "task"/"state".

        Subclasses may override this directly, but implementing the
        ``_input_keys`` property is the intended extension point.
        """
        required = self._input_keys.union({"task", "state"})
        optional = self._optional_input_keys
        return InputKeys.from_keys(required_keys=required, optional_keys=optional)

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys built from ``_output_keys``.

        Subclasses may override this directly, but implementing the
        ``_output_keys`` property is the intended extension point.
        """
        keys = self._output_keys
        return OutputKeys.from_keys(keys)

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal
# data in order to improve our website and the GitVerse service, and to make
# them more convenient to use.
#
# You can disable cookies yourself in your browser settings.