llama-index

Форк
0
217 строк · 7.5 Кб
1
"""Query plan tool."""
2

3
from typing import Any, Dict, List, Optional
4

5
from llama_index.legacy.bridge.pydantic import BaseModel, Field
6
from llama_index.legacy.response_synthesizers import (
7
    BaseSynthesizer,
8
    get_response_synthesizer,
9
)
10
from llama_index.legacy.schema import NodeWithScore, TextNode
11
from llama_index.legacy.tools.types import BaseTool, ToolMetadata, ToolOutput
12
from llama_index.legacy.utils import print_text
13

14
# Name given to the tool when the caller does not supply one.
DEFAULT_NAME = "query_plan_tool"

# Schema description attached to `QueryNode.query_str`.
# The backslash-newline pairs join the physical lines, so the value is a
# single line; the trailing space before the closing quotes is kept as-is.
QUERYNODE_QUERY_STR_DESC = """\
Question we are asking. This is the query string that will be executed. \
"""

# Longer description of a node's tool name. It refers to the `dependencies`
# field, which is the field a QueryNode uses to list its sub-questions.
QUERYNODE_TOOL_NAME_DESC = """\
Name of the tool to execute the `query_str`. \
Should NOT be specified if there are subquestions to be specified, in which \
case `dependencies` should be nonempty instead.\
"""

# Schema description attached to `QueryNode.dependencies`.
# Each continued line ends with a space so sentences do not run together.
QUERYNODE_DEPENDENCIES_DESC = """\
List of sub-questions that need to be answered in order \
to answer the question given by `query_str`. \
Should be blank if there are no sub-questions to be specified, in which case \
`tool_name` is specified.\
"""
32

33

34
class QueryNode(BaseModel):
    """Query node.

    A query node represents a query (query_str) that must be answered.
    It can either be answered by a tool (tool_name), or by combining the
    answers of the nodes it depends on (dependencies, a list of node IDs).
    The tool_name and dependencies fields are meant to be mutually exclusive.

    """

    # NOTE: inspired from https://github.com/jxnl/openai_function_call/pull/3/files
    # NOTE(review): the exclusivity described above is not validated in this
    # class; a node that sets both fields is accepted as-is.

    # Identifier other nodes use in their `dependencies` list.
    id: int = Field(..., description="ID of the query node.")
    # The question this node answers.
    query_str: str = Field(..., description=QUERYNODE_QUERY_STR_DESC)
    # Tool used to answer `query_str`; defaults to None when not given.
    tool_name: Optional[str] = Field(
        default=None, description="Name of the tool to execute the `query_str`."
    )
    # IDs of the nodes whose answers are needed first; empty by default.
    dependencies: List[int] = Field(
        default_factory=list, description=QUERYNODE_DEPENDENCIES_DESC
    )
54

55

56
class QueryPlan(BaseModel):
    """Query plan.

    Contains a flat list of QueryNode objects. Nodes refer to each other by
    ID through their `dependencies` field rather than by nesting.
    Out of the list of QueryNode objects, one of them must be the root node.
    The root node is the one that isn't a dependency of any other node.

    """

    # Required field: every node of the plan, in any order.
    nodes: List[QueryNode] = Field(
        ...,
        description="The list of query nodes that make up the query plan.",
    )
69

70

71
# Default text placed ahead of the per-tool listing in the tool description.
# A trailing backslash joins a line with the next one; the blank lines are
# real newlines, and the text ends with a single newline.
DEFAULT_DESCRIPTION_PREFIX = """\
This is a query plan tool that takes in a list of tools and executes a \
query plan over these tools to answer a query. The query plan is a DAG of query nodes.

Given a list of tool names and the query plan schema, you \
can choose to generate a query plan to answer a question.

The tool names and descriptions are as follows:
"""
80

81

82
class QueryPlanTool(BaseTool):
    """Query plan tool.

    A tool that takes in a list of tools and executes a query plan.

    The plan arrives as the keyword arguments of a `QueryPlan`. A node
    without dependencies is answered by calling the tool named in its
    `tool_name`; a node with dependencies is answered by first executing
    those nodes and then synthesizing their responses.

    """

    def __init__(
        self,
        query_engine_tools: List[BaseTool],
        response_synthesizer: BaseSynthesizer,
        name: str,
        description_prefix: str,
    ) -> None:
        """Initialize.

        Args:
            query_engine_tools: Tools that plan nodes may call. They are
                indexed by `metadata.name`.
            response_synthesizer: Combines the responses of a node's
                dependencies into the node's own response.
            name: Name reported in this tool's metadata.
            description_prefix: Text placed before the per-tool listing in
                this tool's description.
        """
        # If two tools share a name, the later one replaces the earlier one.
        self._query_tools_dict = {t.metadata.name: t for t in query_engine_tools}
        self._response_synthesizer = response_synthesizer
        self._name = name
        self._description_prefix = description_prefix

    @classmethod
    def from_defaults(
        cls,
        query_engine_tools: List[BaseTool],
        response_synthesizer: Optional[BaseSynthesizer] = None,
        name: Optional[str] = None,
        description_prefix: Optional[str] = None,
    ) -> "QueryPlanTool":
        """Initialize from defaults.

        Any argument left as None (or otherwise falsy) is replaced by its
        default: `DEFAULT_NAME`, `DEFAULT_DESCRIPTION_PREFIX`, and the result
        of `get_response_synthesizer()`.
        """
        name = name or DEFAULT_NAME
        description_prefix = description_prefix or DEFAULT_DESCRIPTION_PREFIX
        response_synthesizer = response_synthesizer or get_response_synthesizer()

        return cls(
            query_engine_tools=query_engine_tools,
            response_synthesizer=response_synthesizer,
            name=name,
            description_prefix=description_prefix,
        )

    @property
    def metadata(self) -> ToolMetadata:
        """Metadata.

        The description is the configured prefix followed by the name and
        description of every registered tool; the argument schema is
        `QueryPlan`.
        """
        tools_description = "\n\n".join(
            [
                f"Tool Name: {tool.metadata.name}\n"
                + f"Tool Description: {tool.metadata.description} "
                for tool in self._query_tools_dict.values()
            ]
        )
        # The indentation inside this literal is part of the resulting
        # string; it is left untouched so the emitted description is unchanged.
        description = f"""\
        {self._description_prefix}\n\n
        {tools_description}
        """
        return ToolMetadata(description, self._name, fn_schema=QueryPlan)

    def _execute_node(
        self, node: QueryNode, nodes_dict: Dict[int, QueryNode]
    ) -> ToolOutput:
        """Execute node.

        Args:
            node: The node to answer.
            nodes_dict: Every node of the plan, keyed by node ID.

        Returns:
            The tool's output for a leaf node, or a ToolOutput wrapping the
            synthesized response for a node with dependencies.

        Raises:
            KeyError: If a dependency ID is not in `nodes_dict`, or a leaf
                node's `tool_name` is not a registered tool.
        """
        print_text(f"Executing node {node.json()}\n", color="blue")
        if node.dependencies:
            print_text(
                f"Executing {len(node.dependencies)} child nodes\n", color="pink"
            )
            child_query_nodes: List[QueryNode] = [
                nodes_dict[dep] for dep in node.dependencies
            ]
            # execute the child nodes first
            # NOTE: nothing is cached here, so a node that several nodes
            # depend on is executed once per dependent.
            child_responses: List[ToolOutput] = [
                self._execute_node(child, nodes_dict) for child in child_query_nodes
            ]
            # wrap each (query, response) pair as a scored text node so the
            # response synthesizer can combine the results
            child_nodes_with_scores = [
                NodeWithScore(
                    node=TextNode(
                        text=(
                            f"Query: {child_query_node.query_str}\n"
                            f"Response: {child_response!s}\n"
                        )
                    ),
                    score=1.0,
                )
                for child_query_node, child_response in zip(
                    child_query_nodes, child_responses
                )
            ]
            response_obj = self._response_synthesizer.synthesize(
                query=node.query_str,
                nodes=child_nodes_with_scores,
            )
            # no single tool produced this answer; the query string is
            # recorded in the tool_name slot
            response = ToolOutput(
                content=str(response_obj),
                tool_name=node.query_str,
                raw_input={"query": node.query_str},
                raw_output=response_obj,
            )

        else:
            # this is a leaf request, execute the query string using the specified tool
            tool = self._query_tools_dict[node.tool_name]
            print_text(f"Selected Tool: {tool.metadata}\n", color="pink")
            response = tool(node.query_str)
        print_text(
            "Executed query, got response.\n"
            f"Query: {node.query_str}\n"
            f"Response: {response!s}\n",
            color="blue",
        )
        return response

    def _find_root_nodes(self, nodes_dict: Dict[int, QueryNode]) -> List[QueryNode]:
        """Find root nodes.

        Args:
            nodes_dict: Every node of the plan, keyed by node ID.

        Returns:
            The nodes that no other node lists as a dependency, in the
            iteration order of `nodes_dict`.

        Raises:
            KeyError: If a node depends on an ID that is not in `nodes_dict`.
        """
        # the root node is the one that isn't a dependency of any other node
        node_counts = {node_id: 0 for node_id in nodes_dict}
        for node in nodes_dict.values():
            for dep in node.dependencies:
                node_counts[dep] += 1
        return [
            nodes_dict[node_id] for node_id, count in node_counts.items() if count == 0
        ]

    def __call__(self, *args: Any, **kwargs: Any) -> ToolOutput:
        """Call.

        Args:
            *args: Ignored.
            **kwargs: The fields of a `QueryPlan`.

        Returns:
            The output of executing the plan's root node.

        Raises:
            ValueError: If the plan does not have exactly one root node. This
                includes an empty plan and a plan in which every node is a
                dependency of some other node.
        """
        # the kwargs represented as a JSON object
        # should be a QueryPlan object
        query_plan = QueryPlan(**kwargs)

        # if two nodes share an ID, the later one replaces the earlier one
        nodes_dict = {node.id: node for node in query_plan.nodes}
        root_nodes = self._find_root_nodes(nodes_dict)
        # Reject zero roots as well as several: indexing an empty list below
        # would otherwise fail with an unrelated IndexError.
        if len(root_nodes) != 1:
            raise ValueError("Query plan should have exactly one root node.")

        return self._execute_node(root_nodes[0], nodes_dict)
218

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.