llama-index

Форк
0
147 строк · 4.8 Кб
1
"""Tool mapping."""
2

3
from typing import Any, Dict, Optional, Sequence
4

5
from llama_index.legacy.objects.base_node_mapping import (
6
    DEFAULT_PERSIST_DIR,
7
    DEFAULT_PERSIST_FNAME,
8
    BaseObjectNodeMapping,
9
)
10
from llama_index.legacy.schema import BaseNode, TextNode
11
from llama_index.legacy.tools.query_engine import QueryEngineTool
12
from llama_index.legacy.tools.types import BaseTool
13

14

15
def convert_tool_to_node(tool: BaseTool) -> TextNode:
    """Render a tool's metadata as a ``TextNode``.

    The node text contains the tool name and description, followed by the
    JSON schema of the tool's function when ``fn_schema`` is set. The tool
    name is also stored in the node metadata under ``"name"``, so the
    mapping classes in this module can look the tool back up from the node.
    The name is excluded from both the embedding and the LLM views of the
    metadata.

    Args:
        tool: The tool to describe.

    Returns:
        A ``TextNode`` describing ``tool``.
    """
    meta = tool.metadata
    pieces = [
        f"Tool name: {meta.name}\n",
        f"Tool description: {meta.description}\n",
    ]
    schema_model = meta.fn_schema
    if schema_model is not None:
        # Include the argument schema so retrieval can match on parameters.
        pieces.append(f"Tool schema: {schema_model.schema()}\n")
    return TextNode(
        text="".join(pieces),
        metadata={"name": meta.name},
        excluded_embed_metadata_keys=["name"],
        excluded_llm_metadata_keys=["name"],
    )
29

30

31
class BaseToolNodeMapping(BaseObjectNodeMapping[BaseTool]):
    """Base object-node mapping whose objects are ``BaseTool`` instances.

    Provides type validation for tools. ``obj_node_mapping`` and
    ``persist`` raise ``NotImplementedError`` here and must be supplied by
    subclasses that need them. Loading via ``from_persist_dir`` is not
    supported.
    """

    def validate_object(self, obj: BaseTool) -> None:
        """Check that ``obj`` is a tool.

        Raises:
            ValueError: If ``obj`` is not a ``BaseTool`` instance.
        """
        if not isinstance(obj, BaseTool):
            raise ValueError(f"Object must be of type {BaseTool}")

    @property
    def obj_node_mapping(self) -> Dict[int, Any]:
        """The mapping data structure between node and object."""
        raise NotImplementedError("Subclasses should implement this!")

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist objs.

        Defaults match ``from_persist_dir`` so both signatures agree. A bare
        ``...`` is only a placeholder in stub files and is not a ``str``.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Subclasses should implement this!")

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "BaseToolNodeMapping":
        """Load a mapping from disk.

        Raises:
            NotImplementedError: Always; persistence is not supported.
        """
        raise NotImplementedError(
            "This object node mapping does not support persist method."
        )
58

59

60
class SimpleToolNodeMapping(BaseToolNodeMapping):
    """Simple Tool mapping.

    In this setup, we assume that the tool name is unique, and that all
    tools are stored in memory. Tools are kept in a dict keyed by
    ``tool.metadata.name``. Registering a tool under a name that is already
    present replaces the earlier tool.
    """

    def __init__(self, objs: Optional[Sequence[BaseTool]] = None) -> None:
        """Index ``objs`` (if any) by tool name."""
        objs = objs or []
        self._tools = {tool.metadata.name: tool for tool in objs}

    @classmethod
    def from_objects(
        cls, objs: Sequence[BaseTool], *args: Any, **kwargs: Any
    ) -> "BaseObjectNodeMapping":
        """Build a mapping from ``objs``; extra arguments are ignored."""
        return cls(objs)

    def _add_object(self, tool: BaseTool) -> None:
        """Register ``tool`` under its name.

        Raises:
            ValueError: If the tool has no name. Nameless tools would all
                share the ``None`` key and silently overwrite each other.
                This matches ``SimpleQueryToolNodeMapping._add_object``.
        """
        name = tool.metadata.name
        if name is None:
            raise ValueError("Tool name must be set")
        self._tools[name] = tool

    def to_node(self, tool: BaseTool) -> TextNode:
        """Convert ``tool`` to a ``TextNode`` via ``convert_tool_to_node``."""
        return convert_tool_to_node(tool)

    def _from_node(self, node: BaseNode) -> BaseTool:
        """Return the registered tool named by ``node.metadata["name"]``.

        Raises:
            ValueError: If the node has no metadata.
            KeyError: If the metadata lacks ``"name"`` or no tool with that
                name is registered.
        """
        if node.metadata is None:
            raise ValueError("Metadata must be set")
        return self._tools[node.metadata["name"]]
90

91

92
class BaseQueryToolNodeMapping(BaseObjectNodeMapping[QueryEngineTool]):
    """Base object-node mapping whose objects are ``QueryEngineTool``s.

    ``obj_node_mapping`` and ``persist`` raise ``NotImplementedError`` here
    and must be supplied by subclasses that need them. Loading via
    ``from_persist_dir`` is not supported.
    """

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "BaseQueryToolNodeMapping":
        """Load a mapping from disk.

        Raises:
            NotImplementedError: Always; persistence is not supported.
        """
        raise NotImplementedError(
            "This object node mapping does not support persist method."
        )

    @property
    def obj_node_mapping(self) -> Dict[int, Any]:
        """The mapping data structure between node and object."""
        raise NotImplementedError("Subclasses should implement this!")

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist objs.

        Defaults match ``from_persist_dir`` so both signatures agree. A bare
        ``...`` is only a placeholder in stub files and is not a ``str``.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Subclasses should implement this!")
115

116

117
class SimpleQueryToolNodeMapping(BaseQueryToolNodeMapping):
    """In-memory mapping between ``QueryEngineTool`` objects and nodes.

    Tools are indexed by ``tool.metadata.name``, so names are expected to be
    unique. A later tool with the same name replaces the earlier one.
    """

    def __init__(self, objs: Optional[Sequence[QueryEngineTool]] = None) -> None:
        """Index ``objs`` (if any) by tool name."""
        registry = {}
        for qe_tool in objs or []:
            registry[qe_tool.metadata.name] = qe_tool
        self._tools = registry

    def validate_object(self, obj: QueryEngineTool) -> None:
        """Check that ``obj`` is a query engine tool.

        Raises:
            ValueError: If ``obj`` is not a ``QueryEngineTool`` instance.
        """
        if isinstance(obj, QueryEngineTool):
            return
        raise ValueError(f"Object must be of type {QueryEngineTool}")

    @classmethod
    def from_objects(
        cls, objs: Sequence[QueryEngineTool], *args: Any, **kwargs: Any
    ) -> "BaseObjectNodeMapping":
        """Build a mapping from ``objs``; extra arguments are ignored."""
        return cls(objs)

    def _add_object(self, tool: QueryEngineTool) -> None:
        """Register ``tool`` under its name.

        Raises:
            ValueError: If the tool has no name.
        """
        tool_name = tool.metadata.name
        if tool_name is None:
            raise ValueError("Tool name must be set")
        self._tools[tool_name] = tool

    def to_node(self, obj: QueryEngineTool) -> TextNode:
        """Convert ``obj`` to a ``TextNode`` via ``convert_tool_to_node``."""
        node = convert_tool_to_node(obj)
        return node

    def _from_node(self, node: BaseNode) -> QueryEngineTool:
        """Return the registered tool named by ``node.metadata["name"]``.

        Raises:
            ValueError: If the node has no metadata.
            KeyError: If the metadata lacks ``"name"`` or no tool with that
                name is registered.
        """
        node_meta = node.metadata
        if node_meta is None:
            raise ValueError("Metadata must be set")
        tool_name = node_meta["name"]
        return self._tools[tool_name]
148

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.