MetaGPT

Форк
0
/
spark_api.py 
175 строк · 6.8 Кб
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
"""
4
@File    : spark_api.py
5
"""
6
import _thread as thread
7
import base64
8
import datetime
9
import hashlib
10
import hmac
11
import json
12
import ssl
13
from time import mktime
14
from urllib.parse import urlencode, urlparse
15
from wsgiref.handlers import format_date_time
16

17
import websocket  # 使用websocket_client
18

19
from metagpt.configs.llm_config import LLMConfig, LLMType
20
from metagpt.const import USE_CONFIG_TIMEOUT
21
from metagpt.logs import logger
22
from metagpt.provider.base_llm import BaseLLM
23
from metagpt.provider.llm_provider_registry import register_provider
24

25

26
@register_provider(LLMType.SPARK)
27
class SparkLLM(BaseLLM):
28
    def __init__(self, config: LLMConfig):
29
        self.config = config
30
        logger.warning("SparkLLM:当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
31

32
    def get_choice_text(self, rsp: dict) -> str:
33
        return rsp["payload"]["choices"]["text"][-1]["content"]
34

35
    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
36
        pass
37

38
    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
39
        # 不支持
40
        # logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
41
        w = GetMessageFromWeb(messages, self.config)
42
        return w.run()
43

44
    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
45
        pass
46

47
    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
48
        # 不支持异步
49
        w = GetMessageFromWeb(messages, self.config)
50
        return w.run()
51

52

53
class GetMessageFromWeb:
54
    class WsParam:
55
        """
56
        该类适合讯飞星火大部分接口的调用。
57
        输入 app_id, api_key, api_secret, spark_url以初始化,
58
        create_url方法返回接口url
59
        """
60

61
        # 初始化
62
        def __init__(self, app_id, api_key, api_secret, spark_url, message=None):
63
            self.app_id = app_id
64
            self.api_key = api_key
65
            self.api_secret = api_secret
66
            self.host = urlparse(spark_url).netloc
67
            self.path = urlparse(spark_url).path
68
            self.spark_url = spark_url
69
            self.message = message
70

71
        # 生成url
72
        def create_url(self):
73
            # 生成RFC1123格式的时间戳
74
            now = datetime.datetime.now()
75
            date = format_date_time(mktime(now.timetuple()))
76

77
            # 拼接字符串
78
            signature_origin = "host: " + self.host + "\n"
79
            signature_origin += "date: " + date + "\n"
80
            signature_origin += "GET " + self.path + " HTTP/1.1"
81

82
            # 进行hmac-sha256进行加密
83
            signature_sha = hmac.new(
84
                self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
85
            ).digest()
86

87
            signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
88

89
            authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
90

91
            authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
92

93
            # 将请求的鉴权参数组合为字典
94
            v = {"authorization": authorization, "date": date, "host": self.host}
95
            # 拼接鉴权参数,生成url
96
            url = self.spark_url + "?" + urlencode(v)
97
            # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
98
            return url
99

100
    def __init__(self, text, config: LLMConfig):
101
        self.text = text
102
        self.ret = ""
103
        self.spark_appid = config.app_id
104
        self.spark_api_secret = config.api_secret
105
        self.spark_api_key = config.api_key
106
        self.domain = config.domain
107
        self.spark_url = config.base_url
108

109
    def on_message(self, ws, message):
110
        data = json.loads(message)
111
        code = data["header"]["code"]
112

113
        if code != 0:
114
            ws.close()  # 请求错误,则关闭socket
115
            logger.critical(f"回答获取失败,响应信息反序列化之后为: {data}")
116
            return
117
        else:
118
            choices = data["payload"]["choices"]
119
            # seq = choices["seq"]  # 服务端是流式返回,seq为返回的数据序号
120
            status = choices["status"]  # 服务端是流式返回,status用于判断信息是否传送完毕
121
            content = choices["text"][0]["content"]  # 本次接收到的回答文本
122
            self.ret += content
123
            if status == 2:
124
                ws.close()
125

126
    # 收到websocket错误的处理
127
    def on_error(self, ws, error):
128
        # on_message方法处理接收到的信息,出现任何错误,都会调用这个方法
129
        logger.critical(f"通讯连接出错,【错误提示: {error}】")
130

131
    # 收到websocket关闭的处理
132
    def on_close(self, ws, one, two):
133
        pass
134

135
    # 处理请求数据
136
    def gen_params(self):
137
        data = {
138
            "header": {"app_id": self.spark_appid, "uid": "1234"},
139
            "parameter": {
140
                "chat": {
141
                    # domain为必传参数
142
                    "domain": self.domain,
143
                    # 以下为可微调,非必传参数
144
                    # 注意:官方建议,temperature和top_k修改一个即可
145
                    "max_tokens": 2048,  # 默认2048,模型回答的tokens的最大长度,即允许它输出文本的最长字数
146
                    "temperature": 0.5,  # 取值为[0,1],默认为0.5。取值越高随机性越强、发散性越高,即相同的问题得到的不同答案的可能性越高
147
                    "top_k": 4,  # 取值为[1,6],默认为4。从k个候选中随机选择一个(非等概率)
148
                }
149
            },
150
            "payload": {"message": {"text": self.text}},
151
        }
152
        return data
153

154
    def send(self, ws, *args):
155
        data = json.dumps(self.gen_params())
156
        ws.send(data)
157

158
    # 收到websocket连接建立的处理
159
    def on_open(self, ws):
160
        thread.start_new_thread(self.send, (ws,))
161

162
    # 处理收到的 websocket消息,出现任何错误,调用on_error方法
163
    def run(self):
164
        return self._run(self.text)
165

166
    def _run(self, text_list):
167
        ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list)
168
        ws_url = ws_param.create_url()
169

170
        websocket.enableTrace(False)  # 默认禁用 WebSocket 的跟踪功能
171
        ws = websocket.WebSocketApp(
172
            ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open
173
        )
174
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
175
        return self.ret
176

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.