from enum import Enum
from typing import Optional

from zhipuai.types.chat.chat_completion import Completion

from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
from metagpt.utils.cost_manager import CostManager


class ZhiPuEvent(Enum):
    INTERRUPTED = "interrupted"


@register_provider(LLMType.ZHIPUAI)
class ZhiPuAILLM(BaseLLM):
    """
    Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
    Supports glm-3-turbo and glm-4, as well as system_prompt.
    """

    def __init__(self, config: LLMConfig):
        self.config = config
        self.__init_zhipuai()
        self.cost_manager: Optional[CostManager] = None

    def __init_zhipuai(self):
        assert self.config.api_key
        self.api_key = self.config.api_key
        self.model = self.config.model
        self.pricing_plan = self.config.pricing_plan or self.model
        self.llm = ZhiPuModelAPI(api_key=self.api_key)

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        # Fall back to safe defaults when the config leaves these unset.
        max_tokens = self.config.max_token if self.config.max_token > 0 else 1024
        temperature = self.config.temperature if self.config.temperature > 0.0 else 0.3
        kwargs = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": stream,
        }
        return kwargs

    def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
        resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
        usage = resp.usage.model_dump()
        self._update_costs(usage)
        return resp.model_dump()

    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
        resp = await self.llm.acreate(**self._const_kwargs(messages))
        usage = resp.get("usage", {})
        self._update_costs(usage)
        return resp

    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
        return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
        response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
        collected_content = []
        usage = {}
        async for chunk in response.stream():
            finish_reason = chunk.get("choices")[0].get("finish_reason")
            if finish_reason == "stop":
                # The final chunk carries the token usage for cost accounting.
                usage = chunk.get("usage", {})
            else:
                content = self.get_choice_delta_text(chunk)
                collected_content.append(content)
                log_llm_stream(content)
        log_llm_stream("\n")
        self._update_costs(usage)
        full_content = "".join(collected_content)
        return full_content