llama-index

Форк
0
120 строк · 4.5 Кб
1
"""Base tool spec class."""
2

3
import asyncio
4
from inspect import signature
5
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
6

7
from llama_index.legacy.bridge.pydantic import BaseModel
8
from llama_index.legacy.tools.function_tool import FunctionTool
9
from llama_index.legacy.tools.types import ToolMetadata
10
from llama_index.legacy.tools.utils import create_schema_from_function
11

12
AsyncCallable = Callable[..., Awaitable[Any]]
13

14

15
# TODO: deprecate the Tuple (there's no use for it)
16
SPEC_FUNCTION_TYPE = Union[str, Tuple[str, str]]
17

18

19
class BaseToolSpec:
20
    """Base tool spec class."""
21

22
    # list of functions that you'd want to convert to spec
23
    spec_functions: List[SPEC_FUNCTION_TYPE]
24

25
    def get_fn_schema_from_fn_name(
26
        self, fn_name: str, spec_functions: Optional[List[SPEC_FUNCTION_TYPE]] = None
27
    ) -> Optional[Type[BaseModel]]:
28
        """Return map from function name.
29

30
        Return type is Optional, meaning that the schema can be None.
31
        In this case, it's up to the downstream tool implementation to infer the schema.
32

33
        """
34
        spec_functions = spec_functions or self.spec_functions
35
        for fn in spec_functions:
36
            if fn == fn_name:
37
                return create_schema_from_function(fn_name, getattr(self, fn_name))
38

39
        raise ValueError(f"Invalid function name: {fn_name}")
40

41
    def get_metadata_from_fn_name(
42
        self, fn_name: str, spec_functions: Optional[List[SPEC_FUNCTION_TYPE]] = None
43
    ) -> Optional[ToolMetadata]:
44
        """Return map from function name.
45

46
        Return type is Optional, meaning that the schema can be None.
47
        In this case, it's up to the downstream tool implementation to infer the schema.
48

49
        """
50
        try:
51
            func = getattr(self, fn_name)
52
        except AttributeError:
53
            return None
54
        name = fn_name
55
        docstring = func.__doc__ or ""
56
        description = f"{name}{signature(func)}\n{docstring}"
57
        fn_schema = self.get_fn_schema_from_fn_name(
58
            fn_name, spec_functions=spec_functions
59
        )
60
        return ToolMetadata(name=name, description=description, fn_schema=fn_schema)
61

62
    def to_tool_list(
63
        self,
64
        spec_functions: Optional[List[SPEC_FUNCTION_TYPE]] = None,
65
        func_to_metadata_mapping: Optional[Dict[str, ToolMetadata]] = None,
66
    ) -> List[FunctionTool]:
67
        """Convert tool spec to list of tools."""
68
        spec_functions = spec_functions or self.spec_functions
69
        func_to_metadata_mapping = func_to_metadata_mapping or {}
70
        tool_list = []
71
        for func_spec in spec_functions:
72
            func_sync = None
73
            func_async = None
74
            if isinstance(func_spec, str):
75
                func = getattr(self, func_spec)
76
                if asyncio.iscoroutinefunction(func):
77
                    func_async = func
78
                else:
79
                    func_sync = func
80
                metadata = func_to_metadata_mapping.get(func_spec, None)
81
                if metadata is None:
82
                    metadata = self.get_metadata_from_fn_name(func_spec)
83
            elif isinstance(func_spec, tuple) and len(func_spec) == 2:
84
                func_sync = getattr(self, func_spec[0])
85
                func_async = getattr(self, func_spec[1])
86
                metadata = func_to_metadata_mapping.get(func_spec[0], None)
87
                if metadata is None:
88
                    metadata = func_to_metadata_mapping.get(func_spec[1], None)
89
                    if metadata is None:
90
                        metadata = self.get_metadata_from_fn_name(func_spec[0])
91
            else:
92
                raise ValueError(
93
                    "spec_functions must be of type: List[Union[str, Tuple[str, str]]]"
94
                )
95

96
            if func_sync is None:
97
                if func_async is not None:
98
                    func_sync = patch_sync(func_async)
99
                else:
100
                    raise ValueError(
101
                        f"Could not retrieve a function for spec: {func_spec}"
102
                    )
103

104
            tool = FunctionTool.from_defaults(
105
                fn=func_sync,
106
                async_fn=func_async,
107
                tool_metadata=metadata,
108
            )
109
            tool_list.append(tool)
110
        return tool_list
111

112

113
def patch_sync(func_async: AsyncCallable) -> Callable:
114
    """Patch sync function from async function."""
115

116
    def patched_sync(*args: Any, **kwargs: Any) -> Any:
117
        loop = asyncio.get_event_loop()
118
        return loop.run_until_complete(func_async(*args, **kwargs))
119

120
    return patched_sync
121

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.