import os
import asyncio
from typing import Any

import uvicorn
from fastapi import FastAPI, Body
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from langchain.agents import AgentType, initialize_agent
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferWindowMemory
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.schema import LLMResult

app = FastAPI()

# initialize the agent (we need to do this for the callbacks)
llm = ChatOpenAI(
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    temperature=0.0,
    model_name="gpt-3.5-turbo",
    streaming=True,  # ! important
    callbacks=[]  # ! important (but we will add them later)
)
memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    k=5,
    return_messages=True,
    output_key="output"
)
agent = initialize_agent(
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    tools=[],
    llm=llm,
    verbose=True,
    max_iterations=3,
    early_stopping_method="generate",
    memory=memory,
    return_intermediate_steps=False
)


class AsyncCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler that streams only the tokens of the agent's final answer."""
    content: str = ""
    final_answer: bool = False

    def __init__(self) -> None:
        super().__init__()

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.content += token
        # once we have reached the final answer, put its tokens in the queue
        if self.final_answer:
            if '"action_input": "' in self.content:
                if token not in ['"', "}"]:
                    self.queue.put_nowait(token)
        elif "Final Answer" in self.content:
            self.final_answer = True
            self.content = ""

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if self.final_answer:
            self.content = ""
            self.final_answer = False
            self.done.set()
        else:
            self.content = ""


async def run_call(query: str, stream_it: AsyncCallbackHandler):
    # assign callback handler
    agent.agent.llm_chain.llm.callbacks = [stream_it]
    # now query
    await agent.acall(inputs={"input": query})


# request input format
class Query(BaseModel):
    text: str


async def create_gen(query: str, stream_it: AsyncCallbackHandler):
    task = asyncio.create_task(run_call(query, stream_it))
    async for token in stream_it.aiter():
        yield token
    await task


@app.post("/chat")
async def chat(
    query: Query = Body(...),
):
    stream_it = AsyncCallbackHandler()
    gen = create_gen(query.text, stream_it)
    return StreamingResponse(gen, media_type="text/event-stream")


@app.get("/health")
async def health():
    """Check the API is running"""
    return {"status": "🤙"}


if __name__ == "__main__":
    uvicorn.run(
        "app:app",  # assumes this file is saved as app.py
        host="localhost",
        port=8000,
        reload=True
    )
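
# A minimal client sketch for consuming the /chat stream, assuming the server
# above is running on localhost:8000 and that `httpx` is installed (httpx is
# not used anywhere else in this file, so the sketch is left commented out):
#
# import asyncio
# import httpx
#
# async def stream_chat(text: str):
#     async with httpx.AsyncClient() as client:
#         async with client.stream(
#             "POST", "http://localhost:8000/chat",
#             json={"text": text}, timeout=None,
#         ) as response:
#             # print tokens as they arrive instead of waiting for the full answer
#             async for chunk in response.aiter_text():
#                 print(chunk, end="", flush=True)
#
# asyncio.run(stream_chat("hello there"))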