6
import _thread as thread
13
from time import mktime
14
from urllib.parse import urlencode, urlparse
15
from wsgiref.handlers import format_date_time
19
from metagpt.configs.llm_config import LLMConfig, LLMType
20
from metagpt.const import USE_CONFIG_TIMEOUT
21
from metagpt.logs import logger
22
from metagpt.provider.base_llm import BaseLLM
23
from metagpt.provider.llm_provider_registry import register_provider
26
# SparkLLM: BaseLLM provider registered for LLMType.SPARK via register_provider.
# NOTE(review): this view omits lines of the class; only the visible statements
# are documented here.
@register_provider(LLMType.SPARK)
27
class SparkLLM(BaseLLM):
28
# Construct the provider from an LLMConfig.
# NOTE(review): the statement that stores `config` is not visible here, but later
# methods read `self.config`, so it is presumably assigned — confirm.
def __init__(self, config: LLMConfig):
30
# Warning text (Chinese) reads: "SparkLLM: the current method cannot run
# asynchronously. When you use acompletion, requests are not made in parallel."
logger.warning("SparkLLM:当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
32
def get_choice_text(self, rsp: dict) -> str:
    """Extract the reply text from a Spark response dict.

    Looks up ``rsp["payload"]["choices"]["text"]`` and returns the
    ``content`` field of its last entry. Missing keys or an empty text
    list propagate as ``KeyError`` / ``IndexError``.
    """
    text_entries = rsp["payload"]["choices"]["text"]
    last_entry = text_entries[-1]
    return last_entry["content"]
35
# Streaming chat-completion coroutine: takes a list of message dicts and a timeout
# (default USE_CONFIG_TIMEOUT) and is annotated to return str.
# NOTE(review): its body is not visible in this view.
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
38
# Return the completion text for `messages` (annotated -> str).
# Visible statement: builds a GetMessageFromWeb request object from the messages
# and self.config.
# NOTE(review): the rest of the body, including how `stream` and `timeout` are
# used, is not visible in this view.
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
41
w = GetMessageFromWeb(messages, self.config)
44
# Non-streaming chat-completion coroutine.
# NOTE(review): its body is not visible in this view.
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
47
# Completion entry point.
# Visible statement: builds a GetMessageFromWeb request object from the messages
# and self.config.
# NOTE(review): what is returned, and how `timeout` is used, is not visible here.
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
49
w = GetMessageFromWeb(messages, self.config)
53
# Websocket client for the Spark endpoint. It signs a URL, sends the request
# when the socket opens, and handles replies in its on_message/on_error/on_close
# callbacks. _run blocks in ws.run_forever.
class GetMessageFromWeb:
57
# The next line is a fragment of a docstring, presumably that of the nested
# WsParam class (its header is not visible here). It reads:
# "Initialize with app_id, api_key, api_secret, spark_url".
输入 app_id, api_key, api_secret, spark_url以初始化,
62
# Constructor of the websocket-parameter helper (presumably the nested WsParam,
# which _run instantiates; its class header is not visible here).
# Stores the credentials and splits spark_url into host (netloc) and path, which
# the URL-signing code uses.
# NOTE(review): no line storing app_id is visible in this view.
def __init__(self, app_id, api_key, api_secret, spark_url, message=None):
64
self.api_key = api_key
65
self.api_secret = api_secret
66
self.host = urlparse(spark_url).netloc
67
self.path = urlparse(spark_url).path
68
self.spark_url = spark_url
69
self.message = message
74
# Body of the URL-signing method (presumably create_url, which _run calls; its def
# line is not visible here). It builds an HMAC-SHA256-signed websocket URL.
# HTTP-date for "now": mktime turns the naive local time into a POSIX timestamp,
# and format_date_time renders it as an RFC 1123 GMT string.
now = datetime.datetime.now()
75
date = format_date_time(mktime(now.timetuple()))
78
# String to sign: host, date and request line, separated by newlines.
signature_origin = "host: " + self.host + "\n"
79
signature_origin += "date: " + date + "\n"
80
signature_origin += "GET " + self.path + " HTTP/1.1"
83
# HMAC-SHA256 of the string to sign, keyed by api_secret.
# NOTE(review): the closing of this call is not visible here. The result is
# base64-encoded below, so it is presumably `.digest()` bytes — confirm.
signature_sha = hmac.new(
84
self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
87
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
89
# Authorization descriptor naming the key, algorithm, signed headers and signature.
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
91
# The descriptor is itself base64-encoded before it goes into the query string.
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
94
# Signed query parameters appended to spark_url.
v = {"authorization": authorization, "date": date, "host": self.host}
96
# NOTE(review): the `return url` is not visible in this view.
url = self.spark_url + "?" + urlencode(v)
100
# Capture the request messages and the Spark connection settings from an LLMConfig.
# NOTE(review): the line storing `text` is not visible here, but gen_params and the
# run entry point read self.text, so it is presumably assigned — confirm.
def __init__(self, text, config: LLMConfig):
103
self.spark_appid = config.app_id
104
self.spark_api_secret = config.api_secret
105
self.spark_api_key = config.api_key
106
self.domain = config.domain
107
# base_url is used as the websocket endpoint that gets signed in _run.
self.spark_url = config.base_url
109
# Websocket message callback: decode the JSON frame and read its header code
# and payload choices.
def on_message(self, ws, message):
110
data = json.loads(message)
111
code = data["header"]["code"]
115
# Failure log. The message (Chinese) reads: "Failed to get the answer; the
# deserialized response is: {data}".
# NOTE(review): the condition guarding this log (presumably a non-zero `code`) and
# any follow-up such as closing `ws` are not visible here.
logger.critical(f"回答获取失败,响应信息反序列化之后为: {data}")
118
choices = data["payload"]["choices"]
120
# NOTE(review): status presumably marks the final frame of the reply — confirm
# against the Spark API docs. How it is used is not visible here.
status = choices["status"]
121
# Reads the FIRST text entry, whereas SparkLLM.get_choice_text reads the last
# ([-1]). These agree only if "text" holds a single entry — confirm.
content = choices["text"][0]["content"]
127
# Websocket error callback: log the error at CRITICAL level.
# The log text (Chinese) reads: "Communication connection error, [error: {error}]".
def on_error(self, ws, error):
129
logger.critical(f"通讯连接出错,【错误提示: {error}】")
132
# Websocket close callback. `one` and `two` receive the extra positional arguments
# the websocket library passes on close (in websocket-client, presumably the close
# status code and close message — confirm for the pinned version).
# NOTE(review): its body is not visible in this view.
def on_close(self, ws, one, two):
136
# Build the request dict that send() serializes to JSON.
# Visible entries: header.app_id/uid, a "domain" entry (its enclosing key is not
# visible here; presumably the model parameters section) and payload.message.text
# set to self.text.
def gen_params(self):
138
# NOTE(review): uid is hard-coded to "1234" for every request.
"header": {"app_id": self.spark_appid, "uid": "1234"},
142
"domain": self.domain,
150
"payload": {"message": {"text": self.text}},
154
# Serialize gen_params() to JSON for transmission on `ws`.
# NOTE(review): the statement that actually sends `data` (presumably ws.send) is not
# visible here. *args is unused in the visible lines.
def send(self, ws, *args):
155
data = json.dumps(self.gen_params())
159
# Websocket open callback: start self.send(ws) on a new low-level thread
# (_thread.start_new_thread), so the callback does not wait for the send.
def on_open(self, ws):
160
thread.start_new_thread(self.send, (ws,))
164
# Tail of the public entry point (its def line is not visible here; presumably
# `run`): delegate to _run with the stored messages.
return self._run(self.text)
166
# Build a signed URL, open a websocket with this object's callbacks, and block in
# ws.run_forever until the connection ends.
# `text_list` is passed as the fifth positional argument of WsParam, which maps to
# its `message` parameter (assuming the constructor above belongs to WsParam).
def _run(self, text_list):
167
ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list)
168
ws_url = ws_param.create_url()
170
websocket.enableTrace(False)
171
ws = websocket.WebSocketApp(
172
ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open
174
# SECURITY NOTE(review): cert_reqs=ssl.CERT_NONE disables TLS certificate
# verification, so the server's identity is never checked (man-in-the-middle
# risk for the signed credentials). Prefer ssl.CERT_REQUIRED unless there is a
# documented reason for this.
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})