llama-index

Форк
0
130 строк · 4.0 Кб
1
# ReAct agent formatter
2

3
import logging
4
from abc import abstractmethod
5
from typing import List, Optional, Sequence
6

7
from llama_index.legacy.agent.react.prompts import (
8
    CONTEXT_REACT_CHAT_SYSTEM_HEADER,
9
    REACT_CHAT_SYSTEM_HEADER,
10
)
11
from llama_index.legacy.agent.react.types import (
12
    BaseReasoningStep,
13
    ObservationReasoningStep,
14
)
15
from llama_index.legacy.bridge.pydantic import BaseModel
16
from llama_index.legacy.core.llms.types import ChatMessage, MessageRole
17
from llama_index.legacy.tools import BaseTool
18

19
logger = logging.getLogger(__name__)
20

21

22
def get_react_tool_descriptions(tools: Sequence[BaseTool]) -> List[str]:
23
    """Tool."""
24
    tool_descs = []
25
    for tool in tools:
26
        tool_desc = (
27
            f"> Tool Name: {tool.metadata.name}\n"
28
            f"Tool Description: {tool.metadata.description}\n"
29
            f"Tool Args: {tool.metadata.fn_schema_str}\n"
30
        )
31
        tool_descs.append(tool_desc)
32
    return tool_descs
33

34

35
# TODO: come up with better name
36
class BaseAgentChatFormatter(BaseModel):
37
    """Base chat formatter."""
38

39
    class Config:
40
        arbitrary_types_allowed = True
41

42
    @abstractmethod
43
    def format(
44
        self,
45
        tools: Sequence[BaseTool],
46
        chat_history: List[ChatMessage],
47
        current_reasoning: Optional[List[BaseReasoningStep]] = None,
48
    ) -> List[ChatMessage]:
49
        """Format chat history into list of ChatMessage."""
50

51

52
class ReActChatFormatter(BaseAgentChatFormatter):
53
    """ReAct chat formatter."""
54

55
    system_header: str = REACT_CHAT_SYSTEM_HEADER  # default
56
    context: str = ""  # not needed w/ default
57

58
    def format(
59
        self,
60
        tools: Sequence[BaseTool],
61
        chat_history: List[ChatMessage],
62
        current_reasoning: Optional[List[BaseReasoningStep]] = None,
63
    ) -> List[ChatMessage]:
64
        """Format chat history into list of ChatMessage."""
65
        current_reasoning = current_reasoning or []
66

67
        format_args = {
68
            "tool_desc": "\n".join(get_react_tool_descriptions(tools)),
69
            "tool_names": ", ".join([tool.metadata.get_name() for tool in tools]),
70
        }
71
        if self.context:
72
            format_args["context"] = self.context
73

74
        fmt_sys_header = self.system_header.format(**format_args)
75

76
        # format reasoning history as alternating user and assistant messages
77
        # where the assistant messages are thoughts and actions and the user
78
        # messages are observations
79
        reasoning_history = []
80
        for reasoning_step in current_reasoning:
81
            if isinstance(reasoning_step, ObservationReasoningStep):
82
                message = ChatMessage(
83
                    role=MessageRole.USER,
84
                    content=reasoning_step.get_content(),
85
                )
86
            else:
87
                message = ChatMessage(
88
                    role=MessageRole.ASSISTANT,
89
                    content=reasoning_step.get_content(),
90
                )
91
            reasoning_history.append(message)
92

93
        return [
94
            ChatMessage(role=MessageRole.SYSTEM, content=fmt_sys_header),
95
            *chat_history,
96
            *reasoning_history,
97
        ]
98

99
    @classmethod
100
    def from_defaults(
101
        cls,
102
        system_header: Optional[str] = None,
103
        context: Optional[str] = None,
104
    ) -> "ReActChatFormatter":
105
        """Create ReActChatFormatter from defaults."""
106
        if not system_header:
107
            system_header = (
108
                REACT_CHAT_SYSTEM_HEADER
109
                if not context
110
                else CONTEXT_REACT_CHAT_SYSTEM_HEADER
111
            )
112

113
        return ReActChatFormatter(
114
            system_header=system_header,
115
            context=context or "",
116
        )
117

118
    @classmethod
119
    def from_context(cls, context: str) -> "ReActChatFormatter":
120
        """Create ReActChatFormatter from context.
121

122
        NOTE: deprecated
123

124
        """
125
        logger.warning(
126
            "ReActChatFormatter.from_context is deprecated, please use `from_defaults` instead."
127
        )
128
        return ReActChatFormatter.from_defaults(
129
            system_header=CONTEXT_REACT_CHAT_SYSTEM_HEADER, context=context
130
        )
131

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.