Langchain-Chatchat

Форк
0
111 строк · 3.9 Кб
1
from fastchat.conversation import Conversation
2
from server.model_workers.base import *
3
from fastchat import conversation as conv
4
import sys
5
from typing import List, Literal, Dict
6
from configs import logger, log_verbose
7

8

9
class FangZhouWorker(ApiModelWorker):
10
    """
11
    火山方舟
12
    """
13

14
    def __init__(
15
            self,
16
            *,
17
            model_names: List[str] = ["fangzhou-api"],
18
            controller_addr: str = None,
19
            worker_addr: str = None,
20
            version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
21
            **kwargs,
22
    ):
23
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
24
        kwargs.setdefault("context_len", 16384)
25
        super().__init__(**kwargs)
26
        self.version = version
27
    def do_chat(self, params: ApiChatParams) -> Dict:
28
        from volcengine.maas.v2 import MaasService
29

30
        params.load_config(self.model_names[0])
31
        maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
32
        maas.set_ak(params.api_key)
33
        maas.set_sk(params.secret_key)
34

35
        # document: "https://www.volcengine.com/docs/82379/1099475"
36
        req = {
37
            "parameters": {
38
                # 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明
39
                "max_new_tokens": params.max_tokens,
40
                "temperature": params.temperature,
41
            },
42
            "messages": params.messages,
43
        }
44

45
        text = ""
46
        if log_verbose:
47
            self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
48
        for resp in maas.stream_chat(params.version, unicode_escape_data(req)):
49
            error = resp.error
50
            if error and error.code_n > 0:
51
                data = {
52
                    "error_code": error.code_n,
53
                    "text": error.message,
54
                    "error": {
55
                        "message": error.message,
56
                        "type": "invalid_request_error",
57
                        "param": None,
58
                        "code": None,
59
                    }
60
                }
61
                self.logger.error(f"请求方舟 API 时发生错误:{data}")
62
                yield data
63
            elif chunk := resp.choices and resp.choices[0].message.content:
64
                text += chunk
65
                yield {"error_code": 0, "text": text}
66
            else:
67
                data = {
68
                    "error_code": 500,
69
                    "text": f"请求方舟 API 时发生未知的错误: {resp}"
70
                }
71
                self.logger.error(data)
72
                yield data
73
                break
74

75
    def get_embeddings(self, params):
76
        print("embedding")
77
        print(params)
78

79
    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
80
        return conv.Conversation(
81
            name=self.model_names[0],
82
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
83
            messages=[],
84
            roles=["user", "assistant", "system"],
85
            sep="\n### ",
86
            stop_str="###",
87
        )
88

89

90
def unicode_escape_data(data):
91
    if isinstance(data, str):
92
        return data.encode('unicode_escape').decode('ascii')
93
    elif isinstance(data, dict):
94
        return {key: unicode_escape_data(value) for key, value in data.items()}
95
    elif isinstance(data, list):
96
        return [unicode_escape_data(item) for item in data]
97
    else:
98
        return data
99

100
if __name__ == "__main__":
101
    import uvicorn
102
    from server.utils import MakeFastAPIOffline
103
    from fastchat.serve.model_worker import app
104

105
    worker = FangZhouWorker(
106
        controller_addr="http://127.0.0.1:20001",
107
        worker_addr="http://127.0.0.1:21005",
108
    )
109
    sys.modules["fastchat.serve.model_worker"].worker = worker
110
    MakeFastAPIOffline(app)
111
    uvicorn.run(app, port=21005)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.