Langchain-Chatchat
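A streaming callback handler for agent runs: it subclasses LangChain's AsyncIteratorCallbackHandler and pushes every stage of an agent run (LLM tokens, tool calls, errors, and the final answer) onto an asyncio queue as JSON-encoded status events.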
from __future__ import annotations

import asyncio
import json
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.schema import AgentFinish, AgentAction
from langchain.schema.output import LLMResult

def dumps(obj: Dict) -> str:
    return json.dumps(obj, ensure_ascii=False)

# Status codes attached to every event pushed onto the handler's queue.
class Status:
    start: int = 1
    running: int = 2
    complete: int = 3
    agent_action: int = 4
    agent_finish: int = 5
    error: int = 6
    tool_finish: int = 7

class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    def __init__(self):
        super().__init__()
        # Fresh queue/event per handler instance; each queue item is a JSON
        # string describing the current tool call or LLM token.
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()
        self.cur_tool = {}
        self.out = True

    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
        # For models that cannot stop generating the tool input on their own,
        # truncate it here at the first stop word found (in list order).
        stop_words = ["Observation:", "Thought", "\"", "(", "\n", "\t"]
        for stop_word in stop_words:
            index = input_str.find(stop_word)
            if index != -1:
                input_str = input_str[:index]
                break

        self.cur_tool = {
            "tool_name": serialized["name"],
            "input_str": input_str,
            "output_str": "",
            "status": Status.agent_action,
            "run_id": run_id.hex,
            "llm_token": "",
            "final_answer": "",
            "error": "",
        }
        self.queue.put_nowait(dumps(self.cur_tool))

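    # Worked example of the truncation in on_tool_start above (a hypothetical
    # input, not from the source): given
    #     input_str = 'Beijing"\nObservation: leaked tool output'
    # "Observation:" is checked first and found at index 9, leaving 'Beijing"\n'.
    # The loop cuts at the first stop word in *list* order, not at the
    # earliest match position, so the quote and newline survive.
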
    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
                          tags: List[str] | None = None, **kwargs: Any) -> None:
        self.out = True  # resume streaming tokens to the client
        self.cur_tool.update(
            status=Status.tool_finish,
            output_str=output.replace("Answer:", ""),  # drop the "Answer:" prefix some models emit
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
                            **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.error,
            error=str(error),
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Once the model starts emitting ReAct scaffolding (the "Action"
        # keyword or a model-specific marker such as "<|observation|>"),
        # stream everything before it and mute further tokens until the
        # tool call finishes (on_tool_end resets self.out).
        special_tokens = ["Action", "<|observation|>"]
        for stoken in special_tokens:
            if stoken in token:
                before_action = token.split(stoken)[0]
                self.cur_tool.update(
                    status=Status.running,
                    llm_token=before_action + "\n",
                )
                self.queue.put_nowait(dumps(self.cur_tool))
                self.out = False
                break

        if token and self.out:
            self.cur_tool.update(
                status=Status.running,
                llm_token=token,
            )
            self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.start,
            llm_token="",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self.cur_tool.update(
            status=Status.start,
            llm_token="",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.complete,
            llm_token="\n",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.error,
            error=str(error),
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_agent_finish(
        self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        # Emit the final answer, then reset the per-tool state.
        self.cur_tool.update(
            status=Status.agent_finish,
            final_answer=finish.return_values["output"],
        )
        self.queue.put_nowait(dumps(self.cur_tool))
        self.cur_tool = {}
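
A minimal sketch of how a caller might drive this handler and consume its queue. The run_agent coroutine below is hypothetical, a stand-in for a real agent run such as AgentExecutor.acall(inputs, callbacks=[handler]). Note that because the overridden on_llm_end never sets self.done, the inherited aiter() would not terminate on its own, so the consumer reads self.queue directly and stops at the agent_finish event.

from uuid import uuid4


async def run_agent(callbacks: List[CustomAsyncIteratorCallbackHandler]) -> None:
    # Hypothetical driver: a real deployment would run something like
    # AgentExecutor.acall(inputs, callbacks=callbacks) here.
    handler = callbacks[0]
    await handler.on_llm_start({}, ["What is 6 * 7?"])
    await handler.on_llm_new_token("The answer is 42.")
    await handler.on_llm_end(LLMResult(generations=[]))
    await handler.on_agent_finish(
        AgentFinish(return_values={"output": "42"}, log=""), run_id=uuid4()
    )


async def consume(handler: CustomAsyncIteratorCallbackHandler) -> str:
    # Forward events until the agent reports its final answer; a server
    # would relay each JSON chunk to the client (e.g. over SSE) instead
    # of printing it.
    while True:
        event = json.loads(await handler.queue.get())
        if event["status"] == Status.agent_finish:
            return event["final_answer"]
        if event["status"] == Status.error:
            raise RuntimeError(event.get("error", ""))
        print(event.get("llm_token", ""), end="", flush=True)


async def main() -> None:
    handler = CustomAsyncIteratorCallbackHandler()
    agent_task = asyncio.create_task(run_agent([handler]))
    answer = await consume(handler)
    await agent_task
    print(f"\nfinal answer: {answer}")


if __name__ == "__main__":
    asyncio.run(main())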