# Langchain-Chatchat — base classes for online-API model workers (248 lines, 8.2 KB)
1from fastchat.conversation import Conversation
2from configs import LOG_PATH, TEMPERATURE
3import fastchat.constants
4fastchat.constants.LOGDIR = LOG_PATH
5from fastchat.serve.base_model_worker import BaseModelWorker
6import uuid
7import json
8import sys
9from pydantic import BaseModel, root_validator
10import fastchat
11import asyncio
12from server.utils import get_model_worker_config
13from typing import Dict, List, Optional
14
15
# Public API of this module.
__all__ = [
    "ApiModelWorker",
    "ApiChatParams",
    "ApiCompletionParams",
    "ApiEmbeddingsParams",
]
18
class ApiConfigParams(BaseModel):
    '''
    Online-API configuration parameters. Values that are not supplied
    explicitly are filled in automatically from
    model_config.ONLINE_LLM_MODEL (via get_model_worker_config).
    '''
    api_base_url: Optional[str] = None
    api_proxy: Optional[str] = None
    api_key: Optional[str] = None
    secret_key: Optional[str] = None
    group_id: Optional[str] = None  # minimax only
    is_pro: bool = False  # minimax only

    APPID: Optional[str] = None  # xinghuo only
    APISecret: Optional[str] = None  # xinghuo only
    is_v2: bool = False  # xinghuo only

    worker_name: Optional[str] = None

    class Config:
        extra = "allow"

    @root_validator(pre=True)
    def validate_config(cls, v: Dict) -> Dict:
        # Fill missing fields from the worker's registered configuration.
        config = get_model_worker_config(v.get("worker_name"))
        if config:
            for field_name in cls.__fields__:
                if field_name in config:
                    v[field_name] = config[field_name]
        return v

    def load_config(self, worker_name: str):
        """Re-read this worker's configuration in place; returns self."""
        self.worker_name = worker_name
        config = get_model_worker_config(worker_name) or {}
        for field_name in self.__fields__:
            if field_name in config:
                setattr(self, field_name, config[field_name])
        return self
54
55
class ApiModelParams(ApiConfigParams):
    '''
    Model-level configuration parameters.
    '''
    version: Optional[str] = None
    version_url: Optional[str] = None
    # Azure-specific settings.
    api_version: Optional[str] = None
    deployment_name: Optional[str] = None
    resource_name: Optional[str] = None

    # Sampling parameters.
    temperature: float = TEMPERATURE
    max_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
69
70
class ApiChatParams(ApiModelParams):
    '''
    Parameters of a chat request.
    '''
    messages: List[Dict[str, str]]
    # NOTE: pydantic deep-copies field defaults, so the mutable {} is safe.
    system_message: Optional[str] = None  # minimax only
    role_meta: Dict = {}  # minimax only
78
79
class ApiCompletionParams(ApiModelParams):
    '''
    Parameters of a text-completion request.
    '''
    prompt: str
82
83
class ApiEmbeddingsParams(ApiConfigParams):
    '''
    Parameters of an embeddings request.
    '''
    texts: List[str]
    embed_model: Optional[str] = None
    to_query: bool = False  # minimax only
88
89
class ApiModelWorker(BaseModelWorker):
    """Base fastchat worker that proxies requests to an online LLM API.

    Concrete workers override ``do_chat`` (and optionally ``do_embeddings``
    and ``validate_messages``) for a specific provider.
    """

    # None means this worker does not support embeddings.
    DEFAULT_EMBED_MODEL: Optional[str] = None

    def __init__(
        self,
        model_names: List[str],
        controller_addr: str = None,
        worker_addr: str = None,
        context_len: int = 2048,
        no_register: bool = False,
        **kwargs,
    ):
        """
        Args:
            model_names: names this worker registers with the controller.
            controller_addr: fastchat controller address.
            worker_addr: address this worker can be reached at.
            context_len: advertised context window length.
            no_register: when True, skip controller registration/heartbeat.
            **kwargs: forwarded to BaseModelWorker; worker_id, model_path and
                limit_worker_concurrency receive sensible defaults.
        """
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(model_names=model_names,
                         controller_addr=controller_addr,
                         worker_addr=worker_addr,
                         **kwargs)
        import fastchat.serve.base_model_worker
        self.logger = fastchat.serve.base_model_worker.logger
        # Restore stdout/stderr that fastchat redirected into its log files.
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

        # Give this worker its own event loop so the semaphore below binds
        # to a valid loop even off the main thread.
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)

        self.context_len = context_len
        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
        self.version = None

        if not no_register and self.controller_addr:
            self.init_heart_beat()

    def count_token(self, params):
        # Rough token count: character length of the prompt string.
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

    def generate_stream_gate(self, params: Dict):
        """Yield NUL-terminated JSON chunks produced by ``do_chat``."""
        self.call_ct += 1

        try:
            prompt = params["prompt"]
            if self._is_chat(prompt):
                messages = self.prompt_to_messages(prompt)
                messages = self.validate_messages(messages)
            else:
                # Emulate text completion through chat; history not supported.
                messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]

            p = ApiChatParams(
                messages=messages,
                temperature=params.get("temperature"),
                top_p=params.get("top_p"),
                max_tokens=params.get("max_new_tokens"),
                version=self.version,
            )
            for resp in self.do_chat(p):
                yield self._jsonify(resp)
        except Exception as e:
            yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误:{e}"})

    def generate_gate(self, params):
        """Non-streaming variant: drain the stream, return the final chunk."""
        try:
            last = None
            for chunk in self.generate_stream_gate(params):
                last = chunk
            if last is None:
                # Previously this path raised NameError on an unbound loop
                # variable; report the actual cause instead.
                return {"error_code": 500, "text": "no response chunk was generated"}
            return json.loads(last[:-1].decode())
        except Exception as e:
            return {"error_code": 500, "text": str(e)}

    # Methods to be overridden by concrete API workers.

    def do_chat(self, params: ApiChatParams) -> Dict:
        '''
        Perform the chat call. Overrides must be generators yielding dicts of
        the form {"error_code": int, "text": str}.
        '''
        # yield (not return): generate_stream_gate iterates this method, so
        # returning a dict would stream the dict's *keys* instead of the
        # error payload.
        yield {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}

    # def do_completion(self, p: ApiCompletionParams) -> Dict:
    #     '''
    #     Perform the completion call.
    #     Must return a dict of the form {"error_code": int, "text": str}.
    #     '''
    #     return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"}

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        '''
        Compute embeddings. Overrides must return a dict of the form
        {"code": int, "data": List[List[float]], "msg": str}.
        '''
        return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}

    def get_embeddings(self, params):
        # fastchat restricts LLM-worker embeddings (OpenAI-style only);
        # requests issued via OpenAIEmbeddings from the frontend fail before
        # reaching here, so this remains a debug stub.
        print("get_embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        raise NotImplementedError

    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        '''
        Hook for APIs that require a special message format; default is a
        passthrough. Kept separate from prompt_to_messages because the two
        serve different stages with different inputs.
        '''
        return messages

    # help methods
    @property
    def user_role(self):
        # First role of the conversation template (the human side).
        return self.conv.roles[0]

    @property
    def ai_role(self):
        # Second role of the conversation template (the assistant side).
        return self.conv.roles[1]

    def _jsonify(self, data: Dict) -> bytes:
        '''
        Encode a chat result in the fastchat openai-api-server stream format:
        UTF-8 JSON terminated by a NUL byte.
        '''
        return json.dumps(data, ensure_ascii=False).encode() + b"\0"

    def _is_chat(self, prompt: str) -> bool:
        '''
        Heuristic check that the prompt was produced by joining chat messages:
        it contains the template separator followed by the user-role marker.
        TODO: false positives are possible; passing raw messages straight from
        fastchat would be more robust.
        '''
        key = f"{self.conv.sep}{self.user_role}:"
        return key in prompt

    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        '''
        Split a conversation-template prompt string back into messages.
        '''
        result = []
        user_role = self.user_role
        ai_role = self.ai_role
        user_start = user_role + ":"
        ai_start = ai_role + ":"
        # Drop the system preamble (first segment) and the trailing segment
        # left by the final separator.
        for msg in prompt.split(self.conv.sep)[1:-1]:
            if msg.startswith(user_start):
                if content := msg[len(user_start):].strip():
                    result.append({"role": user_role, "content": content})
            elif msg.startswith(ai_start):
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result

    @classmethod
    def can_embedding(cls):
        # Embedding support is signalled by a non-None DEFAULT_EMBED_MODEL.
        return cls.DEFAULT_EMBED_MODEL is not None
249