Langchain-Chatchat
100 строк · 3.7 Кб
1from fastchat.conversation import Conversation2from server.model_workers.base import *3from fastchat import conversation as conv4import sys5import json6from server.model_workers import SparkApi7import websockets8from server.utils import iter_over_async, asyncio9from typing import List, Dict10
11
async def request(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_token):
    """Stream reply text from the iFlytek Spark (Xinghuo) WebSocket chat API.

    Builds the signed WebSocket URL and the request payload with the
    ``SparkApi`` helpers, sends the payload once, then reads JSON frames until
    a frame whose ``header.status`` is 2 has been processed.

    Args:
        appid: application id, passed to ``SparkApi.Ws_Param`` and
            ``SparkApi.gen_params``.
        api_key: passed to ``SparkApi.Ws_Param``.
        api_secret: passed to ``SparkApi.Ws_Param``.
        Spark_url: WebSocket endpoint to connect to.
        domain: model domain, passed to ``SparkApi.gen_params``.
        question: chat messages, passed to ``SparkApi.gen_params``.
        temperature: sampling temperature, passed to ``SparkApi.gen_params``.
        max_token: token limit, passed to ``SparkApi.gen_params``.

    Yields:
        str: ``payload.choices.text[0]["content"]`` of every frame whose
        ``text`` list is non-empty, i.e. the reply one increment at a time.

    Raises:
        RuntimeError: when a frame has a non-zero ``header.code``. Previously
            such frames were ignored, so an API error could end the stream
            with nothing yielded and look like an empty, successful reply.
            NOTE(review): relies on the Spark API convention that ``code`` 0
            means success - confirm against the current API documentation.
    """
    ws_param = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
    ws_url = ws_param.create_url()
    data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
    async with websockets.connect(ws_url) as ws:
        await ws.send(json.dumps(data, ensure_ascii=False))
        while True:
            response = json.loads(await ws.recv())
            header = response.get("header", {})
            code = header.get("code", 0)
            if code != 0:
                # Surface API errors instead of silently producing no output.
                raise RuntimeError(f"Spark API error {code}: {header.get('message', '')}")
            if text := response.get("payload", {}).get("choices", {}).get("text"):
                yield text[0]["content"]
            # A frame with status 2 is the last one; its text (if any) has
            # already been yielded above.
            if header.get("status") == 2:
                break
27
class XingHuoWorker(ApiModelWorker):
    """API model worker for the iFlytek Xinghuo (Spark) chat service.

    Chat requests are streamed from the Spark WebSocket endpoint selected by
    the model version (see ``_VERSION_MAPPING``).
    """

    # Per supported version key: the ``domain`` sent in the request payload,
    # the WebSocket ``url`` to connect to, and the upper bound this worker
    # applies to ``max_tokens`` for that version.
    _VERSION_MAPPING = {
        "v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat", "max_tokens": 4000},
        "v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat", "max_tokens": 8000},
        "v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat", "max_tokens": 8000},
        "v3.5": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.5/chat", "max_tokens": 16000},
    }

    def __init__(
        self,
        *,
        model_names: List[str] = None,
        controller_addr: str = None,
        worker_addr: str = None,
        version: str = None,
        **kwargs,
    ):
        """Create the worker.

        Args:
            model_names: names passed to ``ApiModelWorker``; defaults to
                ``["xinghuo-api"]`` (a fresh list for every instance rather
                than a shared mutable default).
            controller_addr: forwarded to ``ApiModelWorker``.
            worker_addr: forwarded to ``ApiModelWorker``.
            version: fallback version key (e.g. ``"v3.5"``) used by
                ``do_chat`` when the request parameters carry no version.
            **kwargs: forwarded to ``ApiModelWorker``; ``context_len``
                defaults to 8000.
        """
        if model_names is None:
            model_names = ["xinghuo-api"]
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 8000)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
        """Stream a chat completion from the Spark API.

        This is a generator: every yielded dict has the form
        ``{"error_code": 0, "text": <reply accumulated so far>}``.

        Args:
            params: chat parameters. ``load_config`` is called on it with this
                worker's first model name. Its ``version`` selects the
                endpoint (falling back to ``self.version`` when empty).
                ``max_tokens`` is clamped to that version's limit, or set to
                the limit when unset/0.

        Raises:
            ValueError: if the resolved version is not in ``_VERSION_MAPPING``.
        """
        params.load_config(self.model_names[0])

        version = params.version or self.version
        details = self._VERSION_MAPPING.get(version)
        if details is None:
            # Previously this fell through to details["max_tokens"] on a
            # default dict without that key and died with KeyError('max_tokens').
            raise ValueError(
                f"Unsupported xinghuo version {version!r}; "
                f"expected one of {sorted(self._VERSION_MAPPING)}"
            )
        domain = details["domain"]
        spark_url = details["url"]

        limit = details["max_tokens"]
        # `params.max_tokens or 0` used to turn an unset value into
        # min(limit, 0) == 0, i.e. request an empty completion.
        params.max_tokens = min(limit, params.max_tokens) if params.max_tokens else limit

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current event loop (e.g. in a non-main thread): make one.
            loop = asyncio.new_event_loop()

        text = ""
        for chunk in iter_over_async(
            request(params.APPID, params.api_key, params.APISecret, spark_url, domain,
                    params.messages, params.temperature, params.max_tokens),
            loop=loop,
        ):
            if chunk:
                text += chunk
                yield {"error_code": 0, "text": text}

    def get_embeddings(self, params):
        """Not implemented for this worker: prints the request and returns None."""
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        """Return the conversation template for this worker.

        ``conv_template`` and ``model_path`` are accepted for interface
        compatibility but are not used.
        """
        return conv.Conversation(
            name=self.model_names[0],
            # "You are a smart assistant; please complete the task according
            # to the user's prompt."
            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )
88
if __name__ == "__main__":
    # Standalone launch: run this worker behind fastchat's model-worker app.
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    xinghuo_worker = XingHuoWorker(
        worker_addr="http://127.0.0.1:21003",
        controller_addr="http://127.0.0.1:20001",
    )
    # Install the instance as the `worker` attribute of fastchat's
    # model_worker module (presumably what its route handlers read - confirm).
    model_worker_module = sys.modules["fastchat.serve.model_worker"]
    model_worker_module.worker = xinghuo_worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21003)