# Langchain-Chatchat — Qwen (Alibaba DashScope) API model worker.
1import json
2import sys
3
4from fastchat.conversation import Conversation
5from configs import TEMPERATURE
6from http import HTTPStatus
7from typing import List, Literal, Dict
8
9from fastchat import conversation as conv
10from server.model_workers.base import *
11from server.model_workers.base import ApiEmbeddingsParams
12from configs import logger, log_verbose
13
14
class QwenWorker(ApiModelWorker):
    """API model worker for Alibaba DashScope Qwen models.

    Wraps the ``dashscope`` SDK to provide streaming chat completions
    (:meth:`do_chat`) and batched text embeddings (:meth:`do_embeddings`)
    behind the common ``ApiModelWorker`` interface.
    """

    # Embedding model used when the request does not specify one.
    DEFAULT_EMBED_MODEL = "text-embedding-v1"
    # DashScope's TextEmbedding endpoint accepts at most 25 texts per call,
    # so embedding requests are chunked to this size.
    EMBED_BATCH_SIZE = 25

    def __init__(
            self,
            *,
            version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
            model_names: List[str] = ["qwen-api"],
            controller_addr: str = None,
            worker_addr: str = None,
            **kwargs,
    ):
        # Forward worker identity/addresses to the base class; default the
        # context window to 16k tokens unless the caller overrides it.
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
        """Stream a chat completion from the Qwen API.

        Yields dicts of the form ``{"error_code": 0, "text": ...}`` on
        success, or an error payload mirroring the API's status/message
        on failure.
        """
        import dashscope
        params.load_config(self.model_names[0])
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:params: {params}')

        gen = dashscope.Generation()
        responses = gen.call(
            model=params.version,
            temperature=params.temperature,
            api_key=params.api_key,
            messages=params.messages,
            result_format='message',  # set the result is message format.
            stream=True,
        )

        for resp in responses:
            if resp["status_code"] == 200:
                # Each streamed chunk carries the cumulative message so far.
                if choices := resp["output"]["choices"]:
                    yield {
                        "error_code": 0,
                        "text": choices[0]["message"]["content"],
                    }
            else:
                data = {
                    "error_code": resp["status_code"],
                    "text": resp["message"],
                    "error": {
                        "message": resp["message"],
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None,
                    }
                }
                self.logger.error(f"请求千问 API 时发生错误:{data}")
                yield data

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        """Embed ``params.texts`` via the Qwen embedding API.

        Texts are sent in batches of :attr:`EMBED_BATCH_SIZE` (the API's
        per-call limit). Returns ``{"code": 200, "data": [...]}`` on
        success, or an error payload on the first failed batch.
        """
        import dashscope
        params.load_config(self.model_names[0])
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:params: {params}')
        result = []
        for start in range(0, len(params.texts), self.EMBED_BATCH_SIZE):
            batch = params.texts[start:start + self.EMBED_BATCH_SIZE]
            resp = dashscope.TextEmbedding.call(
                model=params.embed_model or self.DEFAULT_EMBED_MODEL,
                input=batch,  # at most 25 texts per call
                api_key=params.api_key,
            )
            if resp["status_code"] != 200:
                # BUGFIX: was `resp.message` — use dict-style access for
                # consistency with every other field access on `resp`.
                data = {
                    "code": resp["status_code"],
                    "msg": resp["message"],
                    "error": {
                        "message": resp["message"],
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None,
                    }
                }
                self.logger.error(f"请求千问 API 时发生错误:{data}")
                return data
            embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
            result += embeddings
        return {"code": 200, "data": result}

    def get_embeddings(self, params):
        # NOTE(review): debug stub only — prints and returns nothing;
        # real embedding work happens in do_embeddings.
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        """Build the fastchat conversation template for this worker."""
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n### ",
            stop_str="###",
        )
115
116
117if __name__ == "__main__":
118import uvicorn
119from server.utils import MakeFastAPIOffline
120from fastchat.serve.model_worker import app
121
122worker = QwenWorker(
123controller_addr="http://127.0.0.1:20001",
124worker_addr="http://127.0.0.1:20007",
125)
126sys.modules["fastchat.serve.model_worker"].worker = worker
127MakeFastAPIOffline(app)
128uvicorn.run(app, port=20007)
129