llama-index

Форк
0
555 строк · 18.4 Кб
1
"""OpenAI Assistant Agent."""
2

3
import asyncio
4
import json
5
import logging
6
import time
7
from typing import Any, Dict, List, Optional, Tuple, Union, cast
8

9
from llama_index.legacy.agent.openai.utils import get_function_by_name
10
from llama_index.legacy.agent.types import BaseAgent
11
from llama_index.legacy.callbacks import (
12
    CallbackManager,
13
    CBEventType,
14
    EventPayload,
15
    trace_method,
16
)
17
from llama_index.legacy.chat_engine.types import (
18
    AGENT_CHAT_RESPONSE_TYPE,
19
    AgentChatResponse,
20
    ChatResponseMode,
21
    StreamingAgentChatResponse,
22
)
23
from llama_index.legacy.core.llms.types import ChatMessage, MessageRole
24
from llama_index.legacy.tools import BaseTool, ToolOutput, adapt_to_async_tool
25

26
logger = logging.getLogger(__name__)
27
logger.setLevel(logging.WARNING)
28

29

30
def from_openai_thread_message(thread_message: Any) -> ChatMessage:
    """Convert a single OpenAI thread message into a ``ChatMessage``.

    Only text content parts are kept (there is no way to render images
    here); multiple text parts are joined with a single space.  The raw
    thread message and its identifiers are preserved in
    ``additional_kwargs`` for downstream inspection.
    """
    from openai.types.beta.threads import MessageContentText, ThreadMessage

    thread_message = cast(ThreadMessage, thread_message)

    # Keep only the textual parts of the message content; image parts
    # (and any other non-text content) are silently dropped.
    text_values = [
        part.text.value
        for part in thread_message.content
        if isinstance(part, MessageContentText)
    ]
    merged_text = " ".join(text_values)

    return ChatMessage(
        role=thread_message.role,
        content=merged_text,
        additional_kwargs={
            "thread_message": thread_message,
            "thread_id": thread_message.thread_id,
            "assistant_id": thread_message.assistant_id,
            "id": thread_message.id,
            "metadata": thread_message.metadata,
        },
    )
53

54

55
def from_openai_thread_messages(thread_messages: List[Any]) -> List[ChatMessage]:
    """Convert a sequence of OpenAI thread messages to ``ChatMessage`` objects."""
    return list(map(from_openai_thread_message, thread_messages))
60

61

62
def call_function(
    tools: List[BaseTool], fn_obj: Any, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
    """Call a function and return the output as a string.

    Looks up the tool named by ``fn_obj`` in ``tools``, invokes it with the
    JSON-decoded arguments, and returns both a FUNCTION-role chat message
    (stringified output) and the raw ``ToolOutput``.
    """
    from openai.types.beta.threads.required_action_function_tool_call import Function

    fn_obj = cast(Function, fn_obj)
    # TMP: consolidate with other abstractions
    fn_name = fn_obj.name
    raw_arguments = fn_obj.arguments
    if verbose:
        print("=== Calling Function ===")
        print(f"Calling function: {fn_name} with args: {raw_arguments}")
    # Resolve the tool by name, then invoke it with the decoded kwargs.
    selected_tool = get_function_by_name(tools, fn_name)
    kwargs = json.loads(raw_arguments)
    result = selected_tool(**kwargs)
    if verbose:
        print(f"Got output: {result!s}")
        print("========================")
    function_message = ChatMessage(
        content=str(result),
        role=MessageRole.FUNCTION,
        additional_kwargs={
            "name": fn_obj.name,
        },
    )
    return function_message, result
91

92

93
async def acall_function(
    tools: List[BaseTool], fn_obj: Any, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
    """Call an async function and return the output as a string.

    Async counterpart of ``call_function``: the resolved tool is adapted to
    an async tool and awaited.  Returns a FUNCTION-role chat message
    (stringified output) and the raw ``ToolOutput``.
    """
    from openai.types.beta.threads.required_action_function_tool_call import Function

    fn_obj = cast(Function, fn_obj)
    # TMP: consolidate with other abstractions
    fn_name = fn_obj.name
    raw_arguments = fn_obj.arguments
    if verbose:
        print("=== Calling Function ===")
        print(f"Calling function: {fn_name} with args: {raw_arguments}")
    # Resolve the tool by name and await it through the async adapter.
    selected_tool = get_function_by_name(tools, fn_name)
    kwargs = json.loads(raw_arguments)
    async_tool = adapt_to_async_tool(selected_tool)
    result = await async_tool.acall(**kwargs)
    if verbose:
        print(f"Got output: {result!s}")
        print("========================")
    function_message = ChatMessage(
        content=str(result),
        role=MessageRole.FUNCTION,
        additional_kwargs={
            "name": fn_obj.name,
        },
    )
    return function_message, result
123

124

125
def _process_files(client: Any, files: List[str]) -> Dict[str, str]:
    """Upload local files to OpenAI for use by assistants.

    Args:
        client: OpenAI client instance.
        files: local file paths to upload with purpose ``"assistants"``.

    Returns:
        Mapping of uploaded OpenAI file id -> local file path.
    """
    from openai import OpenAI

    client = cast(OpenAI, client)

    file_dict = {}
    for file in files:
        # BUGFIX: use a context manager so the file handle is always closed;
        # the original passed a bare open(...) and leaked the handle (and
        # never closed it at all if the upload raised).
        with open(file, "rb") as f:
            file_obj = client.files.create(file=f, purpose="assistants")
        file_dict[file_obj.id] = file
    return file_dict
136

137

138
class OpenAIAssistantAgent(BaseAgent):
    """OpenAIAssistant agent.

    Wrapper around OpenAI assistant API: https://platform.openai.com/docs/assistants/overview

    State lives server-side in an OpenAI thread; this class adds messages to
    the thread, polls runs to completion, and services function-calling
    requests with the supplied ``BaseTool`` instances.
    """

    def __init__(
        self,
        client: Any,
        assistant: Any,
        tools: Optional[List[BaseTool]],
        callback_manager: Optional[CallbackManager] = None,
        thread_id: Optional[str] = None,
        instructions_prefix: Optional[str] = None,
        run_retrieve_sleep_time: float = 0.1,
        file_dict: Optional[Dict[str, str]] = None,
        verbose: bool = False,
    ) -> None:
        """Init params.

        Args:
            client: OpenAI client instance.
            assistant: OpenAI ``Assistant`` object to wrap.
            tools: tools available for function calling (may be None).
            callback_manager: callback manager; a fresh empty one if None.
            thread_id: existing thread id; a new thread is created if None.
            instructions_prefix: default instructions passed to each run.
            run_retrieve_sleep_time: poll interval (seconds) while a run
                is queued / in progress.
            file_dict: mapping of uploaded file id -> local file path.
            verbose: print function-calling traces.

        """
        from openai import OpenAI
        from openai.types.beta.assistant import Assistant

        self._client = cast(OpenAI, client)
        self._assistant = cast(Assistant, assistant)
        self._tools = tools or []
        if thread_id is None:
            # No thread supplied: start a fresh conversation thread.
            thread = self._client.beta.threads.create()
            thread_id = thread.id
        self._thread_id = thread_id
        self._instructions_prefix = instructions_prefix
        self._run_retrieve_sleep_time = run_retrieve_sleep_time
        self._verbose = verbose
        # BUGFIX: the original signature used a mutable default argument
        # ({}), which is shared by every instance constructed without an
        # explicit dict. Use a None sentinel instead.
        self.file_dict = file_dict if file_dict is not None else {}

        self.callback_manager = callback_manager or CallbackManager([])

    @classmethod
    def from_new(
        cls,
        name: str,
        instructions: str,
        tools: Optional[List[BaseTool]] = None,
        openai_tools: Optional[List[Dict]] = None,
        thread_id: Optional[str] = None,
        model: str = "gpt-4-1106-preview",
        instructions_prefix: Optional[str] = None,
        run_retrieve_sleep_time: float = 0.1,
        files: Optional[List[str]] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        file_ids: Optional[List[str]] = None,
        api_key: Optional[str] = None,
    ) -> "OpenAIAssistantAgent":
        """Create a brand-new assistant on the OpenAI side and wrap it.

        Args:
            name: name of assistant
            instructions: instructions for assistant
            tools: list of tools
            openai_tools: list of openai tools
            thread_id: thread id
            model: model
            run_retrieve_sleep_time: run retrieve sleep time
            files: files
            instructions_prefix: instructions prefix
            callback_manager: callback manager
            verbose: verbose
            file_ids: list of file ids
            api_key: OpenAI API key

        """
        from openai import OpenAI

        # this is the set of openai tools
        # not to be confused with the tools we pass in for function calling
        openai_tools = openai_tools or []
        tools = tools or []
        tool_fns = [t.metadata.to_openai_tool() for t in tools]
        all_openai_tools = openai_tools + tool_fns

        # initialize client
        client = OpenAI(api_key=api_key)

        # process files: upload local files, then merge with pre-uploaded ids
        files = files or []
        file_ids = file_ids or []

        file_dict = _process_files(client, files)
        all_file_ids = list(file_dict.keys()) + file_ids

        # TODO: openai's typing is a bit sus
        # (the original cast all_openai_tools twice; once is enough)
        assistant = client.beta.assistants.create(
            name=name,
            instructions=instructions,
            tools=cast(List[Any], all_openai_tools),
            model=model,
            file_ids=all_file_ids,
        )
        return cls(
            client,
            assistant,
            tools,
            callback_manager=callback_manager,
            thread_id=thread_id,
            instructions_prefix=instructions_prefix,
            file_dict=file_dict,
            run_retrieve_sleep_time=run_retrieve_sleep_time,
            verbose=verbose,
        )

    @classmethod
    def from_existing(
        cls,
        assistant_id: str,
        tools: Optional[List[BaseTool]] = None,
        thread_id: Optional[str] = None,
        instructions_prefix: Optional[str] = None,
        run_retrieve_sleep_time: float = 0.1,
        callback_manager: Optional[CallbackManager] = None,
        api_key: Optional[str] = None,
        verbose: bool = False,
    ) -> "OpenAIAssistantAgent":
        """Wrap an assistant that already exists on the OpenAI side.

        Args:
            assistant_id: id of assistant
            tools: list of BaseTools Assistant can use
            thread_id: thread id
            run_retrieve_sleep_time: run retrieve sleep time
            instructions_prefix: instructions prefix
            callback_manager: callback manager
            api_key: OpenAI API key
            verbose: verbose

        """
        from openai import OpenAI

        # initialize client
        client = OpenAI(api_key=api_key)

        # get assistant
        assistant = client.beta.assistants.retrieve(assistant_id)
        # assistant.tools is incompatible with BaseTools so have to pass from params

        return cls(
            client,
            assistant,
            tools=tools,
            callback_manager=callback_manager,
            thread_id=thread_id,
            instructions_prefix=instructions_prefix,
            run_retrieve_sleep_time=run_retrieve_sleep_time,
            verbose=verbose,
        )

    @property
    def assistant(self) -> Any:
        """Get assistant."""
        return self._assistant

    @property
    def client(self) -> Any:
        """Get client."""
        return self._client

    @property
    def thread_id(self) -> str:
        """Get thread id."""
        return self._thread_id

    @property
    def files_dict(self) -> Dict[str, str]:
        """Get files dict (uploaded file id -> local path)."""
        return self.file_dict

    @property
    def chat_history(self) -> List[ChatMessage]:
        """Fetch the full thread history (oldest first) as ChatMessages."""
        raw_messages = self._client.beta.threads.messages.list(
            thread_id=self._thread_id, order="asc"
        )
        return from_openai_thread_messages(list(raw_messages))

    def reset(self) -> None:
        """Delete and create a new thread."""
        self._client.beta.threads.delete(self._thread_id)
        thread = self._client.beta.threads.create()
        thread_id = thread.id
        self._thread_id = thread_id

    def get_tools(self, message: str) -> List[BaseTool]:
        """Get tools (the message is currently ignored)."""
        return self._tools

    def upload_files(self, files: List[str]) -> Dict[str, Any]:
        """Upload local files and return file id -> path mapping."""
        return _process_files(self._client, files)

    def add_message(self, message: str, file_ids: Optional[List[str]] = None) -> Any:
        """Add a user message (optionally with file ids) to the thread."""
        file_ids = file_ids or []
        return self._client.beta.threads.messages.create(
            thread_id=self._thread_id,
            role="user",
            content=message,
            file_ids=file_ids,
        )

    def _run_function_calling(self, run: Any) -> List[ToolOutput]:
        """Service a run's requires_action step: call tools, submit outputs."""
        tool_calls = run.required_action.submit_tool_outputs.tool_calls
        tool_output_dicts = []
        tool_output_objs: List[ToolOutput] = []
        for tool_call in tool_calls:
            fn_obj = tool_call.function
            _, tool_output = call_function(self._tools, fn_obj, verbose=self._verbose)
            tool_output_dicts.append(
                {"tool_call_id": tool_call.id, "output": str(tool_output)}
            )
            tool_output_objs.append(tool_output)

        # submit tool outputs
        # TODO: openai's typing is a bit sus
        self._client.beta.threads.runs.submit_tool_outputs(
            thread_id=self._thread_id,
            run_id=run.id,
            tool_outputs=cast(List[Any], tool_output_dicts),
        )
        return tool_output_objs

    async def _arun_function_calling(self, run: Any) -> List[ToolOutput]:
        """Async variant of ``_run_function_calling`` (tools awaited)."""
        tool_calls = run.required_action.submit_tool_outputs.tool_calls
        tool_output_dicts = []
        tool_output_objs: List[ToolOutput] = []
        for tool_call in tool_calls:
            fn_obj = tool_call.function
            _, tool_output = await acall_function(
                self._tools, fn_obj, verbose=self._verbose
            )
            tool_output_dicts.append(
                {"tool_call_id": tool_call.id, "output": str(tool_output)}
            )
            tool_output_objs.append(tool_output)

        # submit tool outputs
        # NOTE(review): this submit call is synchronous even in the async
        # path — confirm whether the async OpenAI client should be used here.
        self._client.beta.threads.runs.submit_tool_outputs(
            thread_id=self._thread_id,
            run_id=run.id,
            tool_outputs=cast(List[Any], tool_output_dicts),
        )
        return tool_output_objs

    def run_assistant(
        self, instructions_prefix: Optional[str] = None
    ) -> Tuple[Any, Dict]:
        """Start a run on the current thread and poll it to completion.

        Returns the finished run plus metadata containing the tool outputs
        produced during any function-calling steps.

        Raises:
            ValueError: if the run ends with status ``"failed"``.
        """
        instructions_prefix = instructions_prefix or self._instructions_prefix
        run = self._client.beta.threads.runs.create(
            thread_id=self._thread_id,
            assistant_id=self._assistant.id,
            instructions=instructions_prefix,
        )
        from openai.types.beta.threads import Run

        run = cast(Run, run)

        sources = []

        # Poll until the run leaves all active states, servicing any
        # function-calling requests along the way.
        while run.status in ["queued", "in_progress", "requires_action"]:
            run = self._client.beta.threads.runs.retrieve(
                thread_id=self._thread_id, run_id=run.id
            )
            if run.status == "requires_action":
                cur_tool_outputs = self._run_function_calling(run)
                sources.extend(cur_tool_outputs)

            time.sleep(self._run_retrieve_sleep_time)
        # NOTE(review): terminal statuses other than "failed" (e.g.
        # "cancelled", "expired") fall through silently — confirm intended.
        if run.status == "failed":
            raise ValueError(
                f"Run failed with status {run.status}.\n" f"Error: {run.last_error}"
            )
        return run, {"sources": sources}

    async def arun_assistant(
        self, instructions_prefix: Optional[str] = None
    ) -> Tuple[Any, Dict]:
        """Async variant of ``run_assistant`` (polls with asyncio.sleep).

        Raises:
            ValueError: if the run ends with status ``"failed"``.
        """
        instructions_prefix = instructions_prefix or self._instructions_prefix
        run = self._client.beta.threads.runs.create(
            thread_id=self._thread_id,
            assistant_id=self._assistant.id,
            instructions=instructions_prefix,
        )
        from openai.types.beta.threads import Run

        run = cast(Run, run)

        sources = []

        while run.status in ["queued", "in_progress", "requires_action"]:
            run = self._client.beta.threads.runs.retrieve(
                thread_id=self._thread_id, run_id=run.id
            )
            if run.status == "requires_action":
                cur_tool_outputs = await self._arun_function_calling(run)
                sources.extend(cur_tool_outputs)

            await asyncio.sleep(self._run_retrieve_sleep_time)
        if run.status == "failed":
            raise ValueError(
                f"Run failed with status {run.status}.\n" f"Error: {run.last_error}"
            )
        return run, {"sources": sources}

    @property
    def latest_message(self) -> ChatMessage:
        """Get the most recent message in the thread."""
        raw_messages = self._client.beta.threads.messages.list(
            thread_id=self._thread_id, order="desc"
        )
        messages = from_openai_thread_messages(list(raw_messages))
        return messages[0]

    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Main chat interface."""
        # TODO: since chat interface doesn't expose additional kwargs
        # we can't pass in file_ids per message
        # (the return value of add_message was bound but never used)
        self.add_message(message)
        run, metadata = self.run_assistant(
            instructions_prefix=self._instructions_prefix,
        )
        latest_message = self.latest_message
        # get most recent message content
        return AgentChatResponse(
            response=str(latest_message.content),
            sources=metadata["sources"],
        )

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Asynchronous main chat interface."""
        self.add_message(message)
        run, metadata = await self.arun_assistant(
            instructions_prefix=self._instructions_prefix,
        )
        latest_message = self.latest_message
        # get most recent message content
        return AgentChatResponse(
            response=str(latest_message.content),
            sources=metadata["sources"],
        )

    @trace_method("chat")
    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Public chat entry point; wraps ``_chat`` in a callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, function_call, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Public async chat entry point; wraps ``_achat`` in a callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, function_call, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    def stream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Streaming is not supported by this agent."""
        raise NotImplementedError("stream_chat not implemented")

    @trace_method("chat")
    async def astream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Async streaming is not supported by this agent."""
        raise NotImplementedError("astream_chat not implemented")
556

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.