from fastchat.conversation import Conversation
from configs import LOG_PATH, TEMPERATURE
import fastchat.constants
# Patch fastchat's log directory before importing fastchat.serve modules,
# which read LOGDIR at import time, so worker logs land under LOG_PATH.
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel, root_validator
import fastchat
import asyncio
from server.utils import get_model_worker_config
from typing import Dict, List, Optional


__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]


class ApiConfigParams(BaseModel):
    '''
    Configuration parameters for an online API. Values not provided here are
    read automatically from model_config.ONLINE_LLM_MODEL.
    '''
    api_base_url: Optional[str] = None
    api_proxy: Optional[str] = None
    api_key: Optional[str] = None
    secret_key: Optional[str] = None
    group_id: Optional[str] = None  # for minimax
    is_pro: bool = False  # for minimax

    APPID: Optional[str] = None  # for xinghuo
    APISecret: Optional[str] = None  # for xinghuo
    is_v2: bool = False  # for xinghuo

    worker_name: Optional[str] = None

    class Config:
        extra = "allow"

    @root_validator(pre=True)
    def validate_config(cls, v: Dict) -> Dict:
        # Merge values from the worker's configuration into the raw input.
        if config := get_model_worker_config(v.get("worker_name")):
            for n in cls.__fields__:
                if n in config:
                    v[n] = config[n]
        return v

    def load_config(self, worker_name: str):
        self.worker_name = worker_name
        if config := get_model_worker_config(worker_name):
            for n in self.__fields__:
                if n in config:
                    setattr(self, n, config[n])
        return self
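
    # A minimal usage sketch (the worker name "zhipu-api" is illustrative;
    # any key defined in model_config.ONLINE_LLM_MODEL works the same way):
    #
    #     params = ApiConfigParams(worker_name="zhipu-api")  # validator fills fields
    #     params.load_config("zhipu-api")                    # or reload explicitly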


class ApiModelParams(ApiConfigParams):
    '''
    Model configuration parameters.
    '''
    version: Optional[str] = None
    version_url: Optional[str] = None
    api_version: Optional[str] = None  # for azure
    deployment_name: Optional[str] = None  # for azure
    resource_name: Optional[str] = None  # for azure

    temperature: float = TEMPERATURE
    max_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0


class ApiChatParams(ApiModelParams):
    '''
    Parameters for a chat request.
    '''
    messages: List[Dict[str, str]]
    system_message: Optional[str] = None  # for minimax
    role_meta: Dict = {}  # for minimax
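
    # For example (all values illustrative):
    #
    #     p = ApiChatParams(
    #         worker_name="zhipu-api",
    #         messages=[{"role": "user", "content": "Hello"}],
    #         temperature=0.7,
    #     )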


class ApiCompletionParams(ApiModelParams):
    prompt: str


class ApiEmbeddingsParams(ApiConfigParams):
    texts: List[str]
    embed_model: Optional[str] = None
    to_query: bool = False  # for minimax


class ApiModelWorker(BaseModelWorker):
    # None means embeddings are not supported.
    DEFAULT_EMBED_MODEL: Optional[str] = None

    def __init__(
        self,
        model_names: List[str],
        controller_addr: Optional[str] = None,
        worker_addr: Optional[str] = None,
        context_len: int = 2048,
        no_register: bool = False,
        **kwargs,
    ):
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(model_names=model_names,
                         controller_addr=controller_addr,
                         worker_addr=worker_addr,
                         **kwargs)
        import fastchat.serve.base_model_worker
        self.logger = fastchat.serve.base_model_worker.logger
        # Restore the standard streams that fastchat redirects to its logger.
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)

        self.context_len = context_len
        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
        self.version = None

        if not no_register and self.controller_addr:
            self.init_heart_beat()
    def count_token(self, params):
        # A rough count: the character length of the prompt string.
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

    def generate_stream_gate(self, params: Dict):
        self.call_ct += 1

        try:
            prompt = params["prompt"]
            if self._is_chat(prompt):
                messages = self.prompt_to_messages(prompt)
                messages = self.validate_messages(messages)
            else:  # emulate completion via chat; history messages are not supported
                messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]

            p = ApiChatParams(
                messages=messages,
                temperature=params.get("temperature"),
                top_p=params.get("top_p"),
                max_tokens=params.get("max_new_tokens"),
                version=self.version,
            )
            for resp in self.do_chat(p):
                yield self._jsonify(resp)
        except Exception as e:
            yield self._jsonify({"error_code": 500, "text": f"An error occurred while {self.model_names[0]} was calling the API: {e}"})
    def generate_gate(self, params):
        try:
            x = None
            # Drain the stream, keeping only the last chunk.
            for x in self.generate_stream_gate(params):
                pass
            # Strip the trailing b"\0" delimiter before decoding.
            return json.loads(x[:-1].decode())
        except Exception as e:
            return {"error_code": 500, "text": str(e)}


    # Methods meant to be overridden by subclasses.

    def do_chat(self, params: ApiChatParams) -> Dict:
        '''
        Perform chat. By default this delegates to the module's chat function.
        Required return format: {"error_code": int, "text": str}
        '''
        return {"error_code": 500, "text": f"{self.model_names[0]} does not implement chat"}

    # def do_completion(self, p: ApiCompletionParams) -> Dict:
    #     '''
    #     Perform completion. By default this delegates to the module's completion function.
    #     Required return format: {"error_code": int, "text": str}
    #     '''
    #     return {"error_code": 500, "text": f"{self.model_names[0]} does not implement completion"}

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        '''
        Compute embeddings. By default this delegates to the module's embed_documents function.
        Required return format: {"code": int, "data": List[List[float]], "msg": str}
        '''
        return {"code": 500, "msg": f"{self.model_names[0]} does not implement embeddings"}
    def get_embeddings(self, params):
        # fastchat places heavy restrictions on LLM embeddings; it appears to
        # work only with OpenAI models. Requests issued from the frontend via
        # OpenAIEmbeddings fail outright and never reach this worker.
        print("get_embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        raise NotImplementedError
    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        '''
        Some APIs require a special message format; override this function to
        replace the default messages. It is kept separate from
        prompt_to_messages because the two serve different scenarios and take
        different parameters.
        '''
        return messages


    # helper methods
    @property
    def user_role(self):
        return self.conv.roles[0]

    @property
    def ai_role(self):
        return self.conv.roles[1]

    def _jsonify(self, data: Dict) -> str:
        '''
        Serialize a chat result in the format expected by fastchat's
        openai-api-server: UTF-8 JSON terminated by a NUL byte.
        '''
        return json.dumps(data, ensure_ascii=False).encode() + b"\0"
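
    # For example: self._jsonify({"error_code": 0, "text": "hi"}) returns
    # b'{"error_code": 0, "text": "hi"}\x00'.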
    def _is_chat(self, prompt: str) -> bool:
        '''
        Check whether the prompt was assembled from chat messages.
        TODO: false positives are possible; passing the raw messages through
        from fastchat directly might be a better approach.
        '''
        key = f"{self.conv.sep}{self.user_role}:"
        return key in prompt
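
    # For example, with sep "\n### " and user role "user" (illustrative
    # values), key is "\n### user:", which occurs in prompts assembled from
    # chat messages but not in a plain completion prompt.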
    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        '''
        Split a prompt string into messages.
        '''
        result = []
        user_role = self.user_role
        ai_role = self.ai_role
        user_start = user_role + ":"
        ai_start = ai_role + ":"
        # Drop the leading system segment and the trailing empty AI turn,
        # then parse each "<role>: <content>" segment in between.
        for msg in prompt.split(self.conv.sep)[1:-1]:
            if msg.startswith(user_start):
                if content := msg[len(user_start):].strip():
                    result.append({"role": user_role, "content": content})
            elif msg.startswith(ai_start):
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result
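
    # For example, with sep "\n### " and roles ("user", "assistant")
    # (illustrative values), the prompt
    #     "SYS\n### user: Hi\n### assistant: Hello\n### user: Bye\n### assistant:"
    # yields
    #     [{"role": "user", "content": "Hi"},
    #      {"role": "assistant", "content": "Hello"},
    #      {"role": "user", "content": "Bye"}]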
    @classmethod
    def can_embedding(cls):
        return cls.DEFAULT_EMBED_MODEL is not None
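

# A minimal sketch of a concrete worker, for illustration only: the class name
# and conversation settings below are invented, and real workers (see
# server/model_workers) implement do_chat against an actual online API.
#
# class EchoApiWorker(ApiModelWorker):
#     def do_chat(self, params: ApiChatParams) -> Dict:
#         # Stream a single chunk in the required {"error_code", "text"} form.
#         text = params.messages[-1]["content"] if params.messages else ""
#         yield {"error_code": 0, "text": text}
#
#     def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
#         return Conversation(
#             name=self.model_names[0],
#             system_message="",
#             messages=[],
#             roles=["user", "assistant"],
#             sep="\n### ",
#             stop_str="###",
#         )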