llama-index

Форк
0
107 строк · 3.5 Кб
1
"""Retriever tool."""
2

3
from typing import TYPE_CHECKING, Any, Optional
4

5
from llama_index.legacy.core.base_retriever import BaseRetriever
6

7
if TYPE_CHECKING:
8
    from llama_index.legacy.langchain_helpers.agents.tools import LlamaIndexTool
9
from llama_index.legacy.schema import MetadataMode
10
from llama_index.legacy.tools.types import AsyncBaseTool, ToolMetadata, ToolOutput
11

12
DEFAULT_NAME = "retriever_tool"
13
DEFAULT_DESCRIPTION = """Useful for running a natural language query
14
against a knowledge base and retrieving a set of relevant documents.
15
"""
16

17

18
class RetrieverTool(AsyncBaseTool):
19
    """Retriever tool.
20

21
    A tool making use of a retriever.
22

23
    Args:
24
        retriever (BaseRetriever): A retriever.
25
        metadata (ToolMetadata): The associated metadata of the query engine.
26
    """
27

28
    def __init__(
29
        self,
30
        retriever: BaseRetriever,
31
        metadata: ToolMetadata,
32
    ) -> None:
33
        self._retriever = retriever
34
        self._metadata = metadata
35

36
    @classmethod
37
    def from_defaults(
38
        cls,
39
        retriever: BaseRetriever,
40
        name: Optional[str] = None,
41
        description: Optional[str] = None,
42
    ) -> "RetrieverTool":
43
        name = name or DEFAULT_NAME
44
        description = description or DEFAULT_DESCRIPTION
45

46
        metadata = ToolMetadata(name=name, description=description)
47
        return cls(retriever=retriever, metadata=metadata)
48

49
    @property
50
    def retriever(self) -> BaseRetriever:
51
        return self._retriever
52

53
    @property
54
    def metadata(self) -> ToolMetadata:
55
        return self._metadata
56

57
    def call(self, *args: Any, **kwargs: Any) -> ToolOutput:
58
        query_str = ""
59
        if args is not None:
60
            query_str += ", ".join([str(arg) for arg in args]) + "\n"
61
        if kwargs is not None:
62
            query_str += (
63
                ", ".join([f"{k!s} is {v!s}" for k, v in kwargs.items()]) + "\n"
64
            )
65
        if query_str == "":
66
            raise ValueError("Cannot call query engine without inputs")
67

68
        docs = self._retriever.retrieve(query_str)
69
        content = ""
70
        for doc in docs:
71
            node_copy = doc.node.copy()
72
            node_copy.text_template = "{metadata_str}\n{content}"
73
            node_copy.metadata_template = "{key} = {value}"
74
            content += node_copy.get_content(MetadataMode.LLM) + "\n\n"
75
        return ToolOutput(
76
            content=content,
77
            tool_name=self.metadata.name,
78
            raw_input={"input": input},
79
            raw_output=docs,
80
        )
81

82
    async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput:
83
        query_str = ""
84
        if args is not None:
85
            query_str += ", ".join([str(arg) for arg in args]) + "\n"
86
        if kwargs is not None:
87
            query_str += (
88
                ", ".join([f"{k!s} is {v!s}" for k, v in kwargs.items()]) + "\n"
89
            )
90
        if query_str == "":
91
            raise ValueError("Cannot call query engine without inputs")
92
        docs = await self._retriever.aretrieve(query_str)
93
        content = ""
94
        for doc in docs:
95
            node_copy = doc.node.copy()
96
            node_copy.text_template = "{metadata_str}\n{content}"
97
            node_copy.metadata_template = "{key} = {value}"
98
            content += node_copy.get_content(MetadataMode.LLM) + "\n\n"
99
        return ToolOutput(
100
            content=content,
101
            tool_name=self.metadata.name,
102
            raw_input={"input": input},
103
            raw_output=docs,
104
        )
105

106
    def as_langchain_tool(self) -> "LlamaIndexTool":
107
        raise NotImplementedError("`as_langchain_tool` not implemented here.")
108

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.