llama-index

Форк
0
199 строк · 6.6 Кб
1
"""Agent worker that takes in a query pipeline."""
2

3
import uuid
4
from typing import (
5
    Any,
6
    List,
7
    Optional,
8
    cast,
9
)
10

11
from llama_index.legacy.agent.types import (
12
    BaseAgentWorker,
13
    Task,
14
    TaskStep,
15
    TaskStepOutput,
16
)
17
from llama_index.legacy.bridge.pydantic import BaseModel, Field
18
from llama_index.legacy.callbacks import (
19
    CallbackManager,
20
    trace_method,
21
)
22
from llama_index.legacy.chat_engine.types import (
23
    AGENT_CHAT_RESPONSE_TYPE,
24
)
25
from llama_index.legacy.core.query_pipeline.query_component import QueryComponent
26
from llama_index.legacy.memory.chat_memory_buffer import ChatMemoryBuffer
27
from llama_index.legacy.query_pipeline.components.agent import (
28
    AgentFnComponent,
29
    AgentInputComponent,
30
    BaseAgentComponent,
31
)
32
from llama_index.legacy.query_pipeline.query import QueryPipeline
33
from llama_index.legacy.tools import ToolOutput
34

35
# Model name constant.
# NOTE(review): not referenced anywhere in this module — presumably kept for
# parity / backward compatibility with other agent worker modules; confirm
# before removing.
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
36

37

38
def _get_agent_components(query_component: QueryComponent) -> List[BaseAgentComponent]:
    """Collect every agent component nested inside ``query_component``.

    The component tree is walked depth-first in pre-order: a matching child is
    emitted before anything found beneath it, and siblings are visited in
    ``sub_query_components`` order.

    Args:
        query_component: Root of the tree to search. The root itself is not
            type-checked; only its descendants are.

    Returns:
        Every ``BaseAgentComponent`` descendant, in traversal order.
    """

    def _walk(parent):
        # Yield matches lazily; recurse only into children that have their
        # own sub-components.
        for child in parent.sub_query_components:
            if isinstance(child, BaseAgentComponent):
                yield cast(BaseAgentComponent, child)
            if len(child.sub_query_components) > 0:
                yield from _walk(child)

    return list(_walk(query_component))
49

50

51
class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker):
    """Query Pipeline agent worker.

    Barebones agent worker that takes in a query pipeline.

    Assumes that the first component in the query pipeline is an
    `AgentInputComponent` and last is `AgentFnComponent`. Only the first
    (root) component is validated at construction time.

    Args:
        pipeline (QueryPipeline): Query pipeline
        callback_manager (Optional[CallbackManager]): Callback manager. If
            given, it is also installed on ``pipeline``; otherwise the
            pipeline's own callback manager is reused.

    Raises:
        ValueError: If the pipeline has no root component, or its first root
            component is not an ``AgentInputComponent``.

    """

    pipeline: QueryPipeline = Field(..., description="Query pipeline")
    # Excluded from model serialization.
    callback_manager: CallbackManager = Field(..., exclude=True)

    class Config:
        # Permit field types that are not pydantic-native.
        arbitrary_types_allowed = True

    def __init__(
        self,
        pipeline: QueryPipeline,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Initialize.

        Args:
            pipeline: Query pipeline executed on every agent step.
            callback_manager: Optional callback manager. When provided it is
                pushed down to ``pipeline``; when omitted, the pipeline's
                existing callback manager is adopted.

        Raises:
            ValueError: If the pipeline does not start with an
                ``AgentInputComponent``.
        """
        if callback_manager is not None:
            # set query pipeline callback so worker and pipeline agree
            pipeline.set_callback_manager(callback_manager)
        else:
            callback_manager = pipeline.callback_manager
        super().__init__(
            pipeline=pipeline,
            callback_manager=callback_manager,
        )
        # Validate the query pipeline eagerly: both properties walk the
        # pipeline, and ``agent_input_component`` raises on a bad layout.
        self.agent_input_component
        self.agent_components

    @property
    def agent_input_component(self) -> AgentInputComponent:
        """Get agent input component.

        Returns:
            The module registered under the pipeline's first root key.

        Raises:
            ValueError: If the pipeline has no root component, or its first
                root component is not an ``AgentInputComponent``.
        """
        root_keys = self.pipeline.get_root_keys()
        if not root_keys:
            # Previously surfaced as a bare IndexError from ``[0]``.
            raise ValueError(
                "Query pipeline first component must be AgentInputComponent, "
                "but the pipeline has no root component."
            )
        root_component = self.pipeline.module_dict[root_keys[0]]
        if not isinstance(root_component, AgentInputComponent):
            raise ValueError(
                "Query pipeline first component must be AgentInputComponent, got "
                f"{root_component}"
            )

        return cast(AgentInputComponent, root_component)

    @property
    def agent_components(self) -> List[BaseAgentComponent]:
        """Get all agent components nested anywhere in the pipeline.

        The pipeline is re-walked on every access.
        """
        return _get_agent_components(self.pipeline)

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Initialize step from task.

        The first step's state holds an empty ``sources`` list and a fresh
        ``ChatMemoryBuffer`` under ``"memory"`` that buffers new messages
        until ``finalize_task`` merges them into ``task.memory``.
        """
        sources: List[ToolOutput] = []
        # temporary memory for new messages
        new_memory = ChatMemoryBuffer.from_defaults()

        # initialize initial state
        initial_state = {
            "sources": sources,
            "memory": new_memory,
        }

        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
            step_state=initial_state,
        )

    def _get_task_step_response(
        self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
    ) -> TaskStepOutput:
        """Wrap a pipeline result into a ``TaskStepOutput``.

        When ``is_done`` is false, a single follow-up step (with a fresh id)
        is scheduled; otherwise no next steps are produced.
        """
        if is_done:
            new_steps = []
        else:
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]

        return TaskStepOutput(
            output=agent_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    def _partial_agent_components(self, step: TaskStep, task: Task) -> None:
        """Bind ``task`` and the step's state onto every agent component."""
        for agent_component in self.agent_components:
            agent_component.partial(task=task, state=step.step_state)

    # NOTE: every step entry point below shares the "run_step" trace id.
    @trace_method("run_step")
    def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step.

        Args:
            step: Current step; its ``step_state`` is passed to the pipeline.
            task: Owning task; its ``extra_state`` is updated with the step
                state after the run.

        Returns:
            The step output, including the next step unless the pipeline
            reported completion.
        """
        # partial agent components with task and step state
        self._partial_agent_components(step, task)

        # the pipeline is expected to yield (agent_response, is_done)
        agent_response, is_done = self.pipeline.run(state=step.step_state, task=task)
        response = self._get_task_step_response(agent_response, step, is_done)
        # sync step state with task state
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    async def arun_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async). Async counterpart of ``run_step``."""
        # partial agent components with task and step state
        self._partial_agent_components(step, task)

        agent_response, is_done = await self.pipeline.arun(
            state=step.step_state, task=task
        )
        response = self._get_task_step_response(agent_response, step, is_done)
        # sync step state with task state
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (stream). Not supported; always raises NotImplementedError."""
        raise NotImplementedError("This agent does not support streaming.")

    @trace_method("run_step")
    async def astream_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async stream). Not supported; always raises NotImplementedError."""
        raise NotImplementedError("This agent does not support streaming.")

    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed.

        Appends the messages buffered in ``task.extra_state["memory"]`` during
        the run to ``task.memory``, then clears that buffer.
        """
        new_messages = task.extra_state["memory"].get_all()
        # Use get_all() rather than get(): get() on token-limited memories
        # (e.g. ChatMemoryBuffer) returns only the most recent window, and
        # set() would then permanently drop the older chat history.
        task.memory.set(task.memory.get_all() + new_messages)
        # reset new memory
        task.extra_state["memory"].reset()

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager on both the worker and its pipeline."""
        # TODO: make this abstractmethod (right now will break some agent impls)
        self.callback_manager = callback_manager
        self.pipeline.set_callback_manager(callback_manager)
200

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.