llama-index

Форк
0
261 строка · 8.7 Кб
1
"""Custom agent worker."""
2

3
import uuid
4
from abc import abstractmethod
5
from typing import (
6
    Any,
7
    Callable,
8
    Dict,
9
    List,
10
    Optional,
11
    Sequence,
12
    Tuple,
13
    cast,
14
)
15

16
from llama_index.legacy.agent.types import (
17
    BaseAgentWorker,
18
    Task,
19
    TaskStep,
20
    TaskStepOutput,
21
)
22
from llama_index.legacy.bridge.pydantic import BaseModel, Field, PrivateAttr
23
from llama_index.legacy.callbacks import (
24
    CallbackManager,
25
    trace_method,
26
)
27
from llama_index.legacy.chat_engine.types import (
28
    AGENT_CHAT_RESPONSE_TYPE,
29
    AgentChatResponse,
30
)
31
from llama_index.legacy.llms.llm import LLM
32
from llama_index.legacy.llms.openai import OpenAI
33
from llama_index.legacy.memory.chat_memory_buffer import ChatMemoryBuffer
34
from llama_index.legacy.objects.base import ObjectRetriever
35
from llama_index.legacy.tools import BaseTool, ToolOutput, adapt_to_async_tool
36
from llama_index.legacy.tools.types import AsyncBaseTool
37

38
# Default OpenAI model used by `from_tools` when no explicit LLM is supplied.
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
39

40

41
class CustomSimpleAgentWorker(BaseModel, BaseAgentWorker):
    """Custom simple agent worker.

    This is "simple" in the sense that some of the scaffolding is setup already.
    Assumptions:
    - assumes that the agent has tools, llm, callback manager, and tool retriever
    - has a `from_tools` convenience function
    - assumes that the agent is sequential, and doesn't take in any additional
    intermediate inputs.

    Subclasses must implement `_initialize_state`, `_run_step`, and
    `_finalize_task`; `_arun_step` may additionally be overridden to add
    async support.

    Args:
        tools (Sequence[BaseTool]): Tools to use for reasoning
        llm (LLM): LLM to use
        callback_manager (CallbackManager): Callback manager
        tool_retriever (Optional[ObjectRetriever[BaseTool]]): Tool retriever
        verbose (bool): Whether to print out reasoning steps

    """

    tools: Sequence[BaseTool] = Field(..., description="Tools to use for reasoning")
    llm: LLM = Field(..., description="LLM to use")
    callback_manager: CallbackManager = Field(
        default_factory=lambda: CallbackManager([]), exclude=True
    )
    tool_retriever: Optional[ObjectRetriever[BaseTool]] = Field(
        default=None, description="Tool retriever"
    )
    verbose: bool = Field(False, description="Whether to print out reasoning steps")

    # Resolves the tools available for a given input string; bound in
    # `__init__` depending on whether explicit tools or a retriever was given.
    _get_tools: Callable[[str], Sequence[BaseTool]] = PrivateAttr()

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        tools: Sequence[BaseTool],
        llm: LLM,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ) -> None:
        """Init params.

        Raises:
            ValueError: If both `tools` and `tool_retriever` are specified.
        """
        # `tools` and `tool_retriever` are mutually exclusive tool sources.
        if len(tools) > 0 and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")
        elif len(tools) > 0:
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            self._get_tools = lambda _: []

        # BUGFIX: the `callback_manager` field is non-Optional and its
        # `default_factory` only applies when the key is *absent*, not when it
        # is explicitly None — substitute the same empty-manager default here
        # so pydantic validation cannot fail on None.
        callback_manager = callback_manager or CallbackManager([])

        super().__init__(
            tools=tools,
            llm=llm,
            callback_manager=callback_manager,
            tool_retriever=tool_retriever,
            verbose=verbose,
        )

    @classmethod
    def from_tools(
        cls,
        tools: Optional[Sequence[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "CustomSimpleAgentWorker":
        """Convenience constructor method from set of BaseTools (Optional)."""
        # Fall back to a default OpenAI LLM when none is supplied.
        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if callback_manager is not None:
            llm.callback_manager = callback_manager
        return cls(
            tools=tools or [],
            tool_retriever=tool_retriever,
            llm=llm,
            callback_manager=callback_manager,
            verbose=verbose,
        )

    @abstractmethod
    def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:
        """Initialize custom state for a task.

        Returned keys are merged into the step state; the keys "sources" and
        "memory" are reserved and must not be used.
        """

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Initialize step from task.

        Seeds the step state with the reserved "sources"/"memory" entries plus
        whatever `_initialize_state` returns.

        Raises:
            ValueError: If `_initialize_state` uses a reserved key.
        """
        sources: List[ToolOutput] = []
        # temporary memory for new messages; flushed into the task's main
        # memory in `finalize_task`
        new_memory = ChatMemoryBuffer.from_defaults()

        # initialize initial state
        initial_state = {
            "sources": sources,
            "memory": new_memory,
        }

        step_state = self._initialize_state(task, **kwargs)
        # if intersecting keys, error — custom state must not shadow the
        # reserved entries above
        if set(step_state.keys()).intersection(set(initial_state.keys())):
            raise ValueError(
                f"Step state keys {step_state.keys()} and initial state keys {initial_state.keys()} intersect. "
                f"*NOTE*: initial state keys {initial_state.keys()} are reserved."
            )
        step_state.update(initial_state)

        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
            step_state=step_state,
        )

    def get_tools(self, input: str) -> List[AsyncBaseTool]:
        """Get tools for the given input, adapted to the async tool interface."""
        return [adapt_to_async_tool(t) for t in self._get_tools(input)]

    def _get_task_step_response(
        self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
    ) -> TaskStepOutput:
        """Wrap an agent response into a step output.

        If the task is not done, queue exactly one follow-up step so the agent
        keeps running sequentially.
        """
        if is_done:
            new_steps = []
        else:
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]

        return TaskStepOutput(
            output=agent_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    @abstractmethod
    def _run_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        """Run step.

        Returns:
            Tuple of (agent_response, is_done)

        """

    async def _arun_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        """Run step (async).

        Can override this method if you want to run the step asynchronously.

        Returns:
            Tuple of (agent_response, is_done)

        Raises:
            NotImplementedError: If the subclass does not override this method.

        """
        raise NotImplementedError(
            "This agent does not support async. Please implement _arun_step."
        )

    @trace_method("run_step")
    def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step."""
        agent_response, is_done = self._run_step(
            step.step_state, task, input=step.input
        )
        response = self._get_task_step_response(agent_response, step, is_done)
        # sync step state with task state
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    async def arun_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async)."""
        agent_response, is_done = await self._arun_step(
            step.step_state, task, input=step.input
        )
        response = self._get_task_step_response(agent_response, step, is_done)
        # sync step state with task state
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (stream)."""
        raise NotImplementedError("This agent does not support streaming.")

    @trace_method("run_step")
    async def astream_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        raise NotImplementedError("This agent does not support streaming.")

    @abstractmethod
    def _finalize_task(self, state: Dict[str, Any], **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed.

        State is all the step states.

        """

    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed."""
        # add new messages to memory
        task.memory.set(task.memory.get() + task.extra_state["memory"].get_all())
        # reset new memory
        task.extra_state["memory"].reset()
        self._finalize_task(task.extra_state, **kwargs)

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        # TODO: make this abstractmethod (right now will break some agent impls)
        self.callback_manager = callback_manager
262

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.