8
from qianfan import ChatCompletion
9
from qianfan.resources.typing import JsonBody
11
from metagpt.configs.llm_config import LLMConfig, LLMType
12
from metagpt.const import USE_CONFIG_TIMEOUT
13
from metagpt.logs import log_llm_stream
14
from metagpt.provider.base_llm import BaseLLM
15
from metagpt.provider.llm_provider_registry import register_provider
16
from metagpt.utils.cost_manager import CostManager
17
from metagpt.utils.token_counter import (
18
QIANFAN_ENDPOINT_TOKEN_COSTS,
19
QIANFAN_MODEL_TOKEN_COSTS,
23
@register_provider(LLMType.QIANFAN)
24
class QianFanLLM(BaseLLM):
27
Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
28
Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
29
Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
30
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
33
def __init__(self, config: LLMConfig):
# Set up the Qianfan provider from `config`.
# NOTE(review): original lines 34 and 36 are missing from this copy (see the
# numeric residue). No visible line here assigns `self.config`, which every
# other method reads. In the visible code, `self.token_costs` is only assigned
# in `__init_qianfan`, yet it is read below. Presumably the missing lines
# store `config` and call `self.__init_qianfan()` — confirm against the real file.
35
# Off by default; `__init_qianfan` sets it to True when the configured model
# or endpoint appears in its `support_system_pairs` list.
self.use_system_prompt = False
37
# Cost tracking backed by the merged model/endpoint token-cost table.
self.cost_manager = CostManager(token_costs=self.token_costs)
39
def __init_qianfan(self):
# Export Qianfan credentials to the environment and decide whether a separate
# system prompt is sent. Then validate the model/endpoint choice, build the
# token-cost table and create the ChatCompletion client.
# NOTE(review): this copy has lost its indentation and several original lines
# (the numeric lines are residue of the original numbering). The notes below
# flag the gaps that affect control flow.
40
# Preferred credentials: access_key + secret_key, exported as
# QIANFAN_ACCESS_KEY / QIANFAN_SECRET_KEY.
if self.config.access_key and self.config.secret_key:
43
# setdefault: a variable already present in the environment takes
# precedence over the config value.
os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
44
os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
45
# Fallback credentials: api_key + secret_key, exported as QIANFAN_AK / QIANFAN_SK.
elif self.config.api_key and self.config.secret_key:
48
os.environ.setdefault("QIANFAN_AK", self.config.api_key)
49
os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
51
# NOTE(review): original line 50 is missing here; presumably it is `else:`.
# Without it, this raise would run unconditionally — confirm.
raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
53
# Each tuple's first element is compared with config.model and its second
# with config.endpoint; any match turns use_system_prompt on.
support_system_pairs = [
54
("ERNIE-Bot-4", "completions_pro"),
55
("ERNIE-Bot-8k", "ernie_bot_8k"),
56
("ERNIE-Bot", "completions"),
57
("ERNIE-Bot-turbo", "eb-instant"),
58
("ERNIE-Speed", "ernie_speed"),
59
("EB-turbo-AppBuilder", "ai_apaas"),
# NOTE(review): original line 60 (presumably the closing `]`) is missing.
61
if self.config.model in [pair[0] for pair in support_system_pairs]:
63
self.use_system_prompt = True
64
if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
65
self.use_system_prompt = True
67
# Exactly one of `model` / `endpoint` must be configured.
# NOTE(review): `assert` is stripped under `python -O`; raise ValueError
# instead if this validation must always run.
assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
68
assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
70
# Deep-copy so that merging in the endpoint prices below does not mutate the
# imported QIANFAN_MODEL_TOKEN_COSTS table.
self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS)
71
self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS)
74
# Usage calculation is enabled only when config.calc_usage is truthy and no
# custom endpoint is configured.
self.calc_usage = self.config.calc_usage and self.config.endpoint is None
75
# One client instance, used for both sync (`do`) and async (`ado`) requests.
self.aclient: ChatCompletion = qianfan.ChatCompletion()
77
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
# Build the keyword arguments passed (via `**`) to `self.aclient.do` / `ado`.
# NOTE(review): original lines 78-81 are missing from this copy, and no
# visible line initialises `kwargs`. Presumably they create the dict,
# likely carrying `messages` and `stream` — confirm against the real file.
82
# Temperature is only forwarded when positive; otherwise it is omitted,
# presumably leaving the service-side default in effect.
if self.config.temperature > 0:
84
kwargs["temperature"] = self.config.temperature
85
# Route by endpoint when one is configured, otherwise by model name.
if self.config.endpoint:
86
kwargs["endpoint"] = self.config.endpoint
87
elif self.config.model:
88
kwargs["model"] = self.config.model
90
# When use_system_prompt is set, a leading system message is moved out of
# `messages` into the separate `system` argument.
# NOTE(review): messages[0] raises IndexError for an empty list — confirm
# callers never pass one.
if self.use_system_prompt:
92
if messages[0]["role"] == "system":
93
kwargs["messages"] = messages[1:]
94
kwargs["system"] = messages[0]["content"]
# NOTE(review): original line 95 is not in this copy. The callers unpack the
# result with `**` and the signature returns `dict`, so it is presumably
# `return kwargs` — confirm.
97
def _update_costs(self, usage: dict):
    """Account for the token usage of a single request.

    The pricing key is ``config.model`` when set, otherwise ``config.endpoint``.
    The base implementation receives the usage, that key, and a flag telling
    whether the key has an entry in ``self.token_costs``.
    """
    cfg = self.config
    pricing_key = cfg.model or cfg.endpoint
    has_local_price = pricing_key in self.token_costs
    super()._update_costs(usage, pricing_key, has_local_price)
103
def get_choice_text(self, resp: JsonBody) -> str:
    """Return the text stored under ``"result"`` in a Qianfan response body.

    Falls back to an empty string when the key is absent.
    """
    answer = resp.get("result", "")
    return answer
106
def completion(self, messages: list[dict]) -> JsonBody:
# Synchronous, non-streaming chat completion through `self.aclient.do`.
# The response body's "usage" entry (or {} when absent) is passed to `_update_costs`.
107
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
108
self._update_costs(resp.body.get("usage", {}))
# NOTE(review): original line 109 is missing from this copy. The `JsonBody`
# return annotation suggests it returns `resp.body` — confirm.
111
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
# Async, non-streaming chat completion through `self.aclient.ado`.
# The response body's "usage" entry (or {} when absent) is passed to `_update_costs`.
# NOTE(review): `timeout` is not passed to `ado` in the visible code, so it
# appears to have no effect here — confirm whether it should be forwarded.
112
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
113
self._update_costs(resp.body.get("usage", {}))
# NOTE(review): original line 114 is missing from this copy; presumably it
# returns `resp.body` — confirm.
116
async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
    """Asynchronously fetch a non-streaming chat completion.

    ``timeout`` is first resolved through ``self.get_timeout``. The resolved
    value is handed on to ``_achat_completion`` along with ``messages``.
    """
    resolved_timeout = self.get_timeout(timeout)
    return await self._achat_completion(messages, timeout=resolved_timeout)
119
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
120
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
121
collected_content = []
123
async for chunk in resp:
124
content = chunk.body.get("result", "")
125
usage = chunk.body.get("usage", {})
126
log_llm_stream(content)
127
collected_content.append(content)
130
self._update_costs(usage)
131
full_content = "".join(collected_content)