llama-index
555 lines · 18.4 KB
1"""OpenAI Assistant Agent."""
2
3import asyncio4import json5import logging6import time7from typing import Any, Dict, List, Optional, Tuple, Union, cast8
9from llama_index.legacy.agent.openai.utils import get_function_by_name10from llama_index.legacy.agent.types import BaseAgent11from llama_index.legacy.callbacks import (12CallbackManager,13CBEventType,14EventPayload,15trace_method,16)
17from llama_index.legacy.chat_engine.types import (18AGENT_CHAT_RESPONSE_TYPE,19AgentChatResponse,20ChatResponseMode,21StreamingAgentChatResponse,22)
23from llama_index.legacy.core.llms.types import ChatMessage, MessageRole24from llama_index.legacy.tools import BaseTool, ToolOutput, adapt_to_async_tool25
# Module-level logger; pinned to WARNING so routine agent activity stays quiet
# unless the embedding application lowers the level explicitly.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
def from_openai_thread_message(thread_message: Any) -> ChatMessage:
    """Convert a single OpenAI thread message into a ChatMessage.

    Only textual content parts are carried over; the raw thread message and
    its identifiers are preserved in ``additional_kwargs``.
    """
    from openai.types.beta.threads import MessageContentText, ThreadMessage

    msg = cast(ThreadMessage, thread_message)

    # We have no way of rendering images, so keep just the text parts for now.
    text_value = " ".join(
        part.text.value
        for part in msg.content
        if isinstance(part, MessageContentText)
    )

    return ChatMessage(
        role=msg.role,
        content=text_value,
        additional_kwargs={
            "thread_message": msg,
            "thread_id": msg.thread_id,
            "assistant_id": msg.assistant_id,
            "id": msg.id,
            "metadata": msg.metadata,
        },
    )
def from_openai_thread_messages(thread_messages: List[Any]) -> List[ChatMessage]:
    """Convert a sequence of OpenAI thread messages into ChatMessages."""
    return list(map(from_openai_thread_message, thread_messages))
def call_function(
    tools: List[BaseTool], fn_obj: Any, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
    """Call a function and return the output as a string."""
    from openai.types.beta.threads.required_action_function_tool_call import Function

    fn_obj = cast(Function, fn_obj)
    # TMP: consolidate with other abstractions
    fn_name = fn_obj.name
    fn_args = fn_obj.arguments
    if verbose:
        print("=== Calling Function ===")
        print(f"Calling function: {fn_name} with args: {fn_args}")
    # Resolve the tool by name, decode the JSON-encoded arguments, and invoke.
    selected_tool = get_function_by_name(tools, fn_name)
    tool_output = selected_tool(**json.loads(fn_args))
    if verbose:
        print(f"Got output: {tool_output!s}")
        print("========================")
    fn_message = ChatMessage(
        content=str(tool_output),
        role=MessageRole.FUNCTION,
        additional_kwargs={
            "name": fn_obj.name,
        },
    )
    return fn_message, tool_output
async def acall_function(
    tools: List[BaseTool], fn_obj: Any, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
    """Call an async function and return the output as a string."""
    from openai.types.beta.threads.required_action_function_tool_call import Function

    fn_obj = cast(Function, fn_obj)
    # TMP: consolidate with other abstractions
    fn_name = fn_obj.name
    fn_args = fn_obj.arguments
    if verbose:
        print("=== Calling Function ===")
        print(f"Calling function: {fn_name} with args: {fn_args}")
    # Resolve the tool by name, wrap it so it exposes an async interface,
    # then await the call with the JSON-decoded arguments.
    async_tool = adapt_to_async_tool(get_function_by_name(tools, fn_name))
    tool_output = await async_tool.acall(**json.loads(fn_args))
    if verbose:
        print(f"Got output: {tool_output!s}")
        print("========================")
    fn_message = ChatMessage(
        content=str(tool_output),
        role=MessageRole.FUNCTION,
        additional_kwargs={
            "name": fn_obj.name,
        },
    )
    return fn_message, tool_output
def _process_files(client: Any, files: List[str]) -> Dict[str, str]:
    """Upload local files to OpenAI and map uploaded file id -> local path.

    Args:
        client: OpenAI client instance.
        files: local file paths to upload with ``purpose="assistants"``.

    Returns:
        Dict mapping each uploaded OpenAI file id to its local path.
    """
    from openai import OpenAI

    client = cast(OpenAI, client)

    file_dict = {}
    for file in files:
        # BUGFIX: use a context manager so the file handle is closed promptly;
        # the original left the handle open for the GC to collect.
        with open(file, "rb") as f:
            file_obj = client.files.create(file=f, purpose="assistants")
        file_dict[file_obj.id] = file
    return file_dict
class OpenAIAssistantAgent(BaseAgent):
    """OpenAIAssistant agent.

    Wrapper around OpenAI assistant API: https://platform.openai.com/docs/assistants/overview

    """

    def __init__(
        self,
        client: Any,
        assistant: Any,
        tools: Optional[List[BaseTool]],
        callback_manager: Optional[CallbackManager] = None,
        thread_id: Optional[str] = None,
        instructions_prefix: Optional[str] = None,
        run_retrieve_sleep_time: float = 0.1,
        file_dict: Optional[Dict[str, str]] = None,
        verbose: bool = False,
    ) -> None:
        """Init params.

        Args:
            client: OpenAI client.
            assistant: OpenAI Assistant object.
            tools: tools the assistant may invoke via function calling.
            callback_manager: callback manager.
            thread_id: existing thread id; a new thread is created when None.
            instructions_prefix: instructions prefix passed to each run.
            run_retrieve_sleep_time: seconds to sleep between run-status polls.
            file_dict: mapping of uploaded OpenAI file id -> local file path.
            verbose: whether to print function-calling traces.
        """
        from openai import OpenAI
        from openai.types.beta.assistant import Assistant

        self._client = cast(OpenAI, client)
        self._assistant = cast(Assistant, assistant)
        self._tools = tools or []
        if thread_id is None:
            thread = self._client.beta.threads.create()
            thread_id = thread.id
        self._thread_id = thread_id
        self._instructions_prefix = instructions_prefix
        self._run_retrieve_sleep_time = run_retrieve_sleep_time
        self._verbose = verbose
        # BUGFIX: `file_dict` previously defaulted to a shared mutable `{}`,
        # so every instance constructed without the argument aliased the same
        # module-level dict; build a fresh dict per instance instead.
        self.file_dict = file_dict if file_dict is not None else {}

        self.callback_manager = callback_manager or CallbackManager([])

    @classmethod
    def from_new(
        cls,
        name: str,
        instructions: str,
        tools: Optional[List[BaseTool]] = None,
        openai_tools: Optional[List[Dict]] = None,
        thread_id: Optional[str] = None,
        model: str = "gpt-4-1106-preview",
        instructions_prefix: Optional[str] = None,
        run_retrieve_sleep_time: float = 0.1,
        files: Optional[List[str]] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        file_ids: Optional[List[str]] = None,
        api_key: Optional[str] = None,
    ) -> "OpenAIAssistantAgent":
        """From new assistant.

        Args:
            name: name of assistant
            instructions: instructions for assistant
            tools: list of tools
            openai_tools: list of openai tools
            thread_id: thread id
            model: model
            run_retrieve_sleep_time: run retrieve sleep time
            files: files
            instructions_prefix: instructions prefix
            callback_manager: callback manager
            verbose: verbose
            file_ids: list of file ids
            api_key: OpenAI API key

        """
        from openai import OpenAI

        # this is the set of openai tools
        # not to be confused with the tools we pass in for function calling
        openai_tools = openai_tools or []
        tools = tools or []
        tool_fns = [t.metadata.to_openai_tool() for t in tools]
        all_openai_tools = openai_tools + tool_fns

        # initialize client
        client = OpenAI(api_key=api_key)

        # process files
        files = files or []
        file_ids = file_ids or []

        file_dict = _process_files(client, files)
        all_file_ids = list(file_dict.keys()) + file_ids

        # TODO: openai's typing is a bit sus
        all_openai_tools = cast(List[Any], all_openai_tools)
        assistant = client.beta.assistants.create(
            name=name,
            instructions=instructions,
            tools=cast(List[Any], all_openai_tools),
            model=model,
            file_ids=all_file_ids,
        )
        return cls(
            client,
            assistant,
            tools,
            callback_manager=callback_manager,
            thread_id=thread_id,
            instructions_prefix=instructions_prefix,
            file_dict=file_dict,
            run_retrieve_sleep_time=run_retrieve_sleep_time,
            verbose=verbose,
        )

    @classmethod
    def from_existing(
        cls,
        assistant_id: str,
        tools: Optional[List[BaseTool]] = None,
        thread_id: Optional[str] = None,
        instructions_prefix: Optional[str] = None,
        run_retrieve_sleep_time: float = 0.1,
        callback_manager: Optional[CallbackManager] = None,
        api_key: Optional[str] = None,
        verbose: bool = False,
    ) -> "OpenAIAssistantAgent":
        """From existing assistant id.

        Args:
            assistant_id: id of assistant
            tools: list of BaseTools Assistant can use
            thread_id: thread id
            run_retrieve_sleep_time: run retrieve sleep time
            instructions_prefix: instructions prefix
            callback_manager: callback manager
            api_key: OpenAI API key
            verbose: verbose

        """
        from openai import OpenAI

        # initialize client
        client = OpenAI(api_key=api_key)

        # get assistant
        assistant = client.beta.assistants.retrieve(assistant_id)
        # assistant.tools is incompatible with BaseTools so have to pass from params

        return cls(
            client,
            assistant,
            tools=tools,
            callback_manager=callback_manager,
            thread_id=thread_id,
            instructions_prefix=instructions_prefix,
            run_retrieve_sleep_time=run_retrieve_sleep_time,
            verbose=verbose,
        )

    @property
    def assistant(self) -> Any:
        """Get assistant."""
        return self._assistant

    @property
    def client(self) -> Any:
        """Get client."""
        return self._client

    @property
    def thread_id(self) -> str:
        """Get thread id."""
        return self._thread_id

    @property
    def files_dict(self) -> Dict[str, str]:
        """Get files dict."""
        return self.file_dict

    @property
    def chat_history(self) -> List[ChatMessage]:
        """Full message history of the current thread, oldest first."""
        raw_messages = self._client.beta.threads.messages.list(
            thread_id=self._thread_id, order="asc"
        )
        return from_openai_thread_messages(list(raw_messages))

    def reset(self) -> None:
        """Delete and create a new thread."""
        self._client.beta.threads.delete(self._thread_id)
        thread = self._client.beta.threads.create()
        thread_id = thread.id
        self._thread_id = thread_id

    def get_tools(self, message: str) -> List[BaseTool]:
        """Get tools (the message is ignored; the full tool list is returned)."""
        return self._tools

    def upload_files(self, files: List[str]) -> Dict[str, Any]:
        """Upload files and return a dict of uploaded file id -> local path."""
        return _process_files(self._client, files)

    def add_message(self, message: str, file_ids: Optional[List[str]] = None) -> Any:
        """Add a user message (optionally with attached file ids) to the thread."""
        file_ids = file_ids or []
        return self._client.beta.threads.messages.create(
            thread_id=self._thread_id,
            role="user",
            content=message,
            file_ids=file_ids,
        )

    def _run_function_calling(self, run: Any) -> List[ToolOutput]:
        """Execute all tool calls requested by the run and submit their outputs."""
        tool_calls = run.required_action.submit_tool_outputs.tool_calls
        tool_output_dicts = []
        tool_output_objs: List[ToolOutput] = []
        for tool_call in tool_calls:
            fn_obj = tool_call.function
            _, tool_output = call_function(self._tools, fn_obj, verbose=self._verbose)
            tool_output_dicts.append(
                {"tool_call_id": tool_call.id, "output": str(tool_output)}
            )
            tool_output_objs.append(tool_output)

        # submit tool outputs
        # TODO: openai's typing is a bit sus
        self._client.beta.threads.runs.submit_tool_outputs(
            thread_id=self._thread_id,
            run_id=run.id,
            tool_outputs=cast(List[Any], tool_output_dicts),
        )
        return tool_output_objs

    async def _arun_function_calling(self, run: Any) -> List[ToolOutput]:
        """Async variant of `_run_function_calling`."""
        tool_calls = run.required_action.submit_tool_outputs.tool_calls
        tool_output_dicts = []
        tool_output_objs: List[ToolOutput] = []
        for tool_call in tool_calls:
            fn_obj = tool_call.function
            _, tool_output = await acall_function(
                self._tools, fn_obj, verbose=self._verbose
            )
            tool_output_dicts.append(
                {"tool_call_id": tool_call.id, "output": str(tool_output)}
            )
            tool_output_objs.append(tool_output)

        # submit tool outputs
        self._client.beta.threads.runs.submit_tool_outputs(
            thread_id=self._thread_id,
            run_id=run.id,
            tool_outputs=cast(List[Any], tool_output_dicts),
        )
        return tool_output_objs

    def run_assistant(
        self, instructions_prefix: Optional[str] = None
    ) -> Tuple[Any, Dict]:
        """Create a run on the thread and poll it to completion.

        Returns the final run object and a metadata dict whose "sources" key
        holds the ToolOutputs produced by any function calls during the run.
        """
        instructions_prefix = instructions_prefix or self._instructions_prefix
        run = self._client.beta.threads.runs.create(
            thread_id=self._thread_id,
            assistant_id=self._assistant.id,
            instructions=instructions_prefix,
        )
        from openai.types.beta.threads import Run

        run = cast(Run, run)

        sources = []

        while run.status in ["queued", "in_progress", "requires_action"]:
            run = self._client.beta.threads.runs.retrieve(
                thread_id=self._thread_id, run_id=run.id
            )
            if run.status == "requires_action":
                cur_tool_outputs = self._run_function_calling(run)
                sources.extend(cur_tool_outputs)

            time.sleep(self._run_retrieve_sleep_time)
        # NOTE(review): only "failed" is surfaced as an error; other terminal
        # statuses (e.g. "cancelled", "expired") return silently — confirm
        # whether that is intended.
        if run.status == "failed":
            raise ValueError(
                f"Run failed with status {run.status}.\n" f"Error: {run.last_error}"
            )
        return run, {"sources": sources}

    async def arun_assistant(
        self, instructions_prefix: Optional[str] = None
    ) -> Tuple[Any, Dict]:
        """Async variant of `run_assistant` (polls with `asyncio.sleep`)."""
        instructions_prefix = instructions_prefix or self._instructions_prefix
        run = self._client.beta.threads.runs.create(
            thread_id=self._thread_id,
            assistant_id=self._assistant.id,
            instructions=instructions_prefix,
        )
        from openai.types.beta.threads import Run

        run = cast(Run, run)

        sources = []

        while run.status in ["queued", "in_progress", "requires_action"]:
            run = self._client.beta.threads.runs.retrieve(
                thread_id=self._thread_id, run_id=run.id
            )
            if run.status == "requires_action":
                cur_tool_outputs = await self._arun_function_calling(run)
                sources.extend(cur_tool_outputs)

            await asyncio.sleep(self._run_retrieve_sleep_time)
        if run.status == "failed":
            raise ValueError(
                f"Run failed with status {run.status}.\n" f"Error: {run.last_error}"
            )
        return run, {"sources": sources}

    @property
    def latest_message(self) -> ChatMessage:
        """Get latest message."""
        raw_messages = self._client.beta.threads.messages.list(
            thread_id=self._thread_id, order="desc"
        )
        messages = from_openai_thread_messages(list(raw_messages))
        return messages[0]

    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Main chat interface."""
        # TODO: since chat interface doesn't expose additional kwargs
        # we can't pass in file_ids per message
        self.add_message(message)
        run, metadata = self.run_assistant(
            instructions_prefix=self._instructions_prefix,
        )
        latest_message = self.latest_message
        # get most recent message content
        return AgentChatResponse(
            response=str(latest_message.content),
            sources=metadata["sources"],
        )

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Asynchronous main chat interface."""
        self.add_message(message)
        run, metadata = await self.arun_assistant(
            instructions_prefix=self._instructions_prefix,
        )
        latest_message = self.latest_message
        # get most recent message content
        return AgentChatResponse(
            response=str(latest_message.content),
            sources=metadata["sources"],
        )

    @trace_method("chat")
    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat with the assistant, wrapped in an AGENT_STEP callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, function_call, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Async chat with the assistant, wrapped in an AGENT_STEP callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, function_call, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    def stream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Streaming is not supported by this agent."""
        raise NotImplementedError("stream_chat not implemented")

    @trace_method("chat")
    async def astream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        function_call: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Async streaming is not supported by this agent."""
        raise NotImplementedError("astream_chat not implemented")