Langchain-Chatchat
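
A custom AsyncIteratorCallbackHandler that streams agent lifecycle events
(LLM tokens, tool calls, errors, and the final answer) to a client as JSON
objects through an asyncio queue.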

from __future__ import annotations

import asyncio
import json
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.schema import AgentAction, AgentFinish
from langchain.schema.output import LLMResult


def dumps(obj: Dict) -> str:
    # ensure_ascii=False keeps non-ASCII text (e.g. Chinese) readable in the
    # streamed JSON instead of being escaped as \uXXXX sequences.
    return json.dumps(obj, ensure_ascii=False)


class Status:
    # Lifecycle states attached to every streamed event.
    start: int = 1          # an LLM / chat-model call has started
    running: int = 2        # a new token has been generated
    complete: int = 3       # the LLM call has finished
    agent_action: int = 4   # the agent is invoking a tool
    agent_finish: int = 5   # the agent has produced its final answer
    error: int = 6          # an LLM or tool error occurred
    tool_finish: int = 7    # a tool call has returned
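
# Illustrative note (not in the original file): a typical ReAct-style run
# emits events in roughly this order, repeated once per reasoning step:
#   start -> running ... -> complete      (one LLM call)
#   agent_action -> tool_finish           (one tool call)
# followed by a single agent_finish event carrying the final answer.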


class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    def __init__(self):
        super().__init__()
        # The parent constructor also creates these; they are re-created here
        # so every handler instance starts from a clean state.
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()
        self.cur_tool = {}  # payload of the event currently being streamed
        self.out = True     # whether LLM tokens are forwarded to the queue

    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
        # Some models cannot truncate their own output, so truncate the tool
        # input at the first stop word found. Note that stop words are tried
        # in list order (by priority), not by their position in the string.
        stop_words = ["Observation:", "Thought", "\"", "(", "\n", "\t"]
        for stop_word in stop_words:
            index = input_str.find(stop_word)
            if index != -1:
                input_str = input_str[:index]
                break

        self.cur_tool = {
            "tool_name": serialized["name"],
            "input_str": input_str,
            "output_str": "",
            "status": Status.agent_action,
            "run_id": run_id.hex,
            "llm_token": "",
            "final_answer": "",
            "error": "",
        }
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
                          tags: List[str] | None = None, **kwargs: Any) -> None:
        self.out = True  # reset the output flag so new tokens stream again
        self.cur_tool.update(
            status=Status.tool_finish,
            output_str=output.replace("Answer:", ""),
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.error,
            error=str(error),
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Stop forwarding tokens once the model starts emitting its tool-call
        # section; only the text before the special token is streamed.
        special_tokens = ["Action", "<|observation|>"]
        for stoken in special_tokens:
            if stoken in token:
                before_action = token.split(stoken)[0]
                self.cur_tool.update(
                    status=Status.running,
                    llm_token=before_action + "\n",
                )
                self.queue.put_nowait(dumps(self.cur_tool))
                self.out = False
                break

        if token and self.out:
            self.cur_tool.update(
                status=Status.running,
                llm_token=token,
            )
            self.queue.put_nowait(dumps(self.cur_tool))
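    # Illustrative example (not in the original file): for a chunk like
    # token = "I should search.Action", on_llm_new_token streams only
    # "I should search.\n" and self.out stays False until on_tool_end resets it.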

    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.start,
            llm_token="",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self.cur_tool.update(
            status=Status.start,
            llm_token="",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.complete,
            llm_token="\n",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.error,
            error=str(error),
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_agent_finish(
            self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            **kwargs: Any,
    ) -> None:
        # Return the final answer and reset the event payload.
        self.cur_tool.update(
            status=Status.agent_finish,
            final_answer=finish.return_values["output"],
        )
        self.queue.put_nowait(dumps(self.cur_tool))
        self.cur_tool = {}
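

# --- Usage sketch (illustrative; not part of the original file) ---
# `stream_agent_events` and `wrap_done` are hypothetical helpers. The handler
# never sets `done` itself, so the caller must set it when the agent run ends,
# otherwise aiter() (inherited from AsyncIteratorCallbackHandler) never stops.
async def stream_agent_events(agent_coro, handler: CustomAsyncIteratorCallbackHandler) -> None:
    # `agent_coro` is assumed to be an agent call created with callbacks=[handler].
    async def wrap_done(coro, event: asyncio.Event):
        try:
            await coro
        finally:
            event.set()  # unblock aiter() once the run is over

    task = asyncio.create_task(wrap_done(agent_coro, handler.done))
    # Each item from aiter() is one JSON event produced by dumps(cur_tool).
    async for chunk in handler.aiter():
        event = json.loads(chunk)
        print(event["status"], event.get("llm_token", ""))
    await task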