llama-index
147 строк · 4.8 Кб
1"""Tool mapping."""
2
3from typing import Any, Dict, Optional, Sequence4
5from llama_index.legacy.objects.base_node_mapping import (6DEFAULT_PERSIST_DIR,7DEFAULT_PERSIST_FNAME,8BaseObjectNodeMapping,9)
10from llama_index.legacy.schema import BaseNode, TextNode11from llama_index.legacy.tools.query_engine import QueryEngineTool12from llama_index.legacy.tools.types import BaseTool13
14
def convert_tool_to_node(tool: BaseTool) -> TextNode:
    """Build a text node describing a tool.

    The node text lists the tool's name and description and, when the
    tool's metadata carries a ``fn_schema``, the result of calling
    ``schema()`` on it. The tool name is also stored under the ``"name"``
    metadata key, and that key is listed in both the embedding and the LLM
    metadata exclusion lists.

    Args:
        tool: Tool whose ``metadata`` is rendered into the node.

    Returns:
        A ``TextNode`` built from the tool's metadata.
    """
    info = tool.metadata
    segments = [
        f"Tool name: {info.name}\n",
        f"Tool description: {info.description}\n",
    ]

    # The schema line is optional: it is only emitted when a schema is set.
    schema_model = info.fn_schema
    if schema_model is not None:
        segments.append(f"Tool schema: {schema_model.schema()}\n")

    return TextNode(
        text="".join(segments),
        metadata={"name": info.name},
        excluded_embed_metadata_keys=["name"],
        excluded_llm_metadata_keys=["name"],
    )

30
class BaseToolNodeMapping(BaseObjectNodeMapping[BaseTool]):
    """Base Tool node mapping.

    Provides type validation for ``BaseTool`` objects. The persistence
    related members (``obj_node_mapping``, ``persist`` and
    ``from_persist_dir``) all raise ``NotImplementedError`` here.
    """

    def validate_object(self, obj: BaseTool) -> None:
        """Check that ``obj`` is a ``BaseTool``.

        Raises:
            ValueError: If ``obj`` is not an instance of ``BaseTool``.
        """
        if not isinstance(obj, BaseTool):
            raise ValueError(f"Object must be of type {BaseTool}")

    @property
    def obj_node_mapping(self) -> Dict[int, Any]:
        """The mapping data structure between node and object.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this!")

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist objs.

        The defaults are the same constants used by ``from_persist_dir`` so
        that both signatures agree and no parameter defaults to ``Ellipsis``.

        Args:
            persist_dir: Directory to persist into.
            obj_node_mapping_fname: File name for the persisted mapping.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this!")

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "BaseToolNodeMapping":
        """Load a mapping from a persist directory.

        Args:
            persist_dir: Directory to load from.
            obj_node_mapping_fname: File name of the persisted mapping.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "This object node mapping does not support persist method."
        )

59
class SimpleToolNodeMapping(BaseToolNodeMapping):
    """Simple Tool mapping.

    Tools are kept in an in-memory dictionary keyed by
    ``tool.metadata.name``; the tool name is therefore assumed to be unique.
    A later tool with the same name replaces an earlier one.
    """

    def __init__(self, objs: Optional[Sequence[BaseTool]] = None) -> None:
        """Index the given tools by name; ``None`` means no tools."""
        registry: Dict[Any, BaseTool] = {}
        for entry in objs or []:
            registry[entry.metadata.name] = entry
        self._tools = registry

    @classmethod
    def from_objects(
        cls, objs: Sequence[BaseTool], *args: Any, **kwargs: Any
    ) -> "BaseObjectNodeMapping":
        """Create a mapping from ``objs``.

        Extra positional and keyword arguments are accepted but not used.
        """
        mapping = cls(objs)
        return mapping

    def _add_object(self, tool: BaseTool) -> None:
        """Register ``tool`` under its metadata name."""
        key = tool.metadata.name
        self._tools[key] = tool

    def to_node(self, tool: BaseTool) -> TextNode:
        """To node."""
        node = convert_tool_to_node(tool)
        return node

    def _from_node(self, node: BaseNode) -> BaseTool:
        """Return the tool registered under the node's ``"name"`` metadata.

        Raises:
            ValueError: If the node's metadata is ``None``.
        """
        node_meta = node.metadata
        if node_meta is None:
            raise ValueError("Metadata must be set")
        tool_name = node_meta["name"]
        return self._tools[tool_name]

91
class BaseQueryToolNodeMapping(BaseObjectNodeMapping[QueryEngineTool]):
    """Base query tool node mapping.

    The persistence related members (``from_persist_dir``,
    ``obj_node_mapping`` and ``persist``) all raise ``NotImplementedError``
    here.
    """

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "BaseQueryToolNodeMapping":
        """Load a mapping from a persist directory.

        Args:
            persist_dir: Directory to load from.
            obj_node_mapping_fname: File name of the persisted mapping.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "This object node mapping does not support persist method."
        )

    @property
    def obj_node_mapping(self) -> Dict[int, Any]:
        """The mapping data structure between node and object.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this!")

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist objs.

        The defaults are the same constants used by ``from_persist_dir`` so
        that both signatures agree and no parameter defaults to ``Ellipsis``.

        Args:
            persist_dir: Directory to persist into.
            obj_node_mapping_fname: File name for the persisted mapping.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this!")

116
class SimpleQueryToolNodeMapping(BaseQueryToolNodeMapping):
    """Simple query tool mapping.

    Query engine tools are kept in an in-memory dictionary keyed by
    ``tool.metadata.name``.
    """

    def __init__(self, objs: Optional[Sequence[QueryEngineTool]] = None) -> None:
        """Index the given tools by name; ``None`` means no tools."""
        registry: Dict[Any, QueryEngineTool] = {}
        for entry in objs or []:
            registry[entry.metadata.name] = entry
        self._tools = registry

    def validate_object(self, obj: QueryEngineTool) -> None:
        """Check that ``obj`` is a ``QueryEngineTool``.

        Raises:
            ValueError: If ``obj`` is not an instance of ``QueryEngineTool``.
        """
        if isinstance(obj, QueryEngineTool):
            return
        raise ValueError(f"Object must be of type {QueryEngineTool}")

    @classmethod
    def from_objects(
        cls, objs: Sequence[QueryEngineTool], *args: Any, **kwargs: Any
    ) -> "BaseObjectNodeMapping":
        """Create a mapping from ``objs``.

        Extra positional and keyword arguments are accepted but not used.
        """
        mapping = cls(objs)
        return mapping

    def _add_object(self, tool: QueryEngineTool) -> None:
        """Register ``tool`` under its metadata name.

        Raises:
            ValueError: If the tool's metadata name is ``None``.
        """
        key = tool.metadata.name
        if key is None:
            raise ValueError("Tool name must be set")
        self._tools[key] = tool

    def to_node(self, obj: QueryEngineTool) -> TextNode:
        """To node."""
        node = convert_tool_to_node(obj)
        return node

    def _from_node(self, node: BaseNode) -> QueryEngineTool:
        """Return the tool registered under the node's ``"name"`` metadata.

        Raises:
            ValueError: If the node's metadata is ``None``.
        """
        node_meta = node.metadata
        if node_meta is None:
            raise ValueError("Metadata must be set")
        tool_name = node_meta["name"]
        return self._tools[tool_name]
