MetaGPT

Форк
0
/
qianfan_api.py 
132 строки · 6.1 Кб
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
# @Desc   : LLM API of Qianfan from Baidu; supports ERNIE (Wenxin Yiyan) and open-source models
4
import copy
5
import os
6

7
import qianfan
8
from qianfan import ChatCompletion
9
from qianfan.resources.typing import JsonBody
10

11
from metagpt.configs.llm_config import LLMConfig, LLMType
12
from metagpt.const import USE_CONFIG_TIMEOUT
13
from metagpt.logs import log_llm_stream
14
from metagpt.provider.base_llm import BaseLLM
15
from metagpt.provider.llm_provider_registry import register_provider
16
from metagpt.utils.cost_manager import CostManager
17
from metagpt.utils.token_counter import (
18
    QIANFAN_ENDPOINT_TOKEN_COSTS,
19
    QIANFAN_MODEL_TOKEN_COSTS,
20
)
21

22

23
@register_provider(LLMType.QIANFAN)
class QianFanLLM(BaseLLM):
    """LLM provider backed by Baidu Qianfan's ``ChatCompletion`` client.

    The provider is configured with either a ``model`` name or an ``endpoint``
    (exactly one of them), and authenticates through environment variables that
    are populated from the config.

    Refs
        Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
        Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
        Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
                https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
    """

    def __init__(self, config: LLMConfig):
        """Store ``config``, initialise the Qianfan client and the cost manager."""
        self.config = config
        self.use_system_prompt = False  # only some ERNIE-x related models support system_prompt
        self.__init_qianfan()
        # token_costs is populated by __init_qianfan above
        self.cost_manager = CostManager(token_costs=self.token_costs)

    def __init_qianfan(self):
        """Export credentials, validate the config and create the chat client.

        Sets ``use_system_prompt``, ``token_costs``, ``calc_usage`` and ``aclient``.

        Raises:
            ValueError: if neither `access_key`&`secret_key` nor `api_key`&`secret_key`
                is present in the config.
            AssertionError: if both, or neither, of `model` and `endpoint` are set.
        """
        if self.config.access_key and self.config.secret_key:
            # for system level auth, use access_key and secret_key, recommended by official
            # set environment variable due to official recommendation
            # (setdefault: a value already present in the environment is kept)
            os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
            os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
        elif self.config.api_key and self.config.secret_key:
            # for application level auth, use api_key and secret_key
            # set environment variable due to official recommendation
            os.environ.setdefault("QIANFAN_AK", self.config.api_key)
            os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
        else:
            raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")

        support_system_pairs = [
            ("ERNIE-Bot-4", "completions_pro"),  # (model, corresponding-endpoint)
            ("ERNIE-Bot-8k", "ernie_bot_8k"),
            ("ERNIE-Bot", "completions"),
            ("ERNIE-Bot-turbo", "eb-instant"),
            ("ERNIE-Speed", "ernie_speed"),
            ("EB-turbo-AppBuilder", "ai_apaas"),
        ]
        system_models = {model for model, _ in support_system_pairs}
        system_endpoints = {endpoint for _, endpoint in support_system_pairs}
        # only some ERNIE models support a dedicated system prompt
        if self.config.model in system_models or self.config.endpoint in system_endpoints:
            self.use_system_prompt = True

        # Explicit raises instead of `assert` so the validation still runs under `python -O`;
        # AssertionError is kept so existing callers observe the same exception type.
        if self.config.model and self.config.endpoint:
            raise AssertionError("Only set `model` or `endpoint` in the config")
        if not (self.config.model or self.config.endpoint):
            raise AssertionError("Should set one of `model` or `endpoint` in the config")

        self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS)
        self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS)

        # self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
        self.calc_usage = self.config.calc_usage and self.config.endpoint is None
        self.aclient: ChatCompletion = qianfan.ChatCompletion()

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        """Build the keyword arguments for a Qianfan chat request.

        Args:
            messages: chat messages, each a dict with at least a ``"role"`` key.
            stream: whether a streaming response is requested.

        Returns:
            A dict with ``messages`` and ``stream``, plus ``temperature`` (only when
            the configured value is > 0), ``endpoint`` or ``model``, and ``system``
            when the leading system message is extracted.
        """
        kwargs = {
            "messages": messages,
            "stream": stream,
        }
        if self.config.temperature > 0:
            # different model has default temperature. only set when it's specified.
            kwargs["temperature"] = self.config.temperature
        if self.config.endpoint:
            kwargs["endpoint"] = self.config.endpoint
        elif self.config.model:
            kwargs["model"] = self.config.model

        # if the model supports a system prompt, move a leading system message into
        # the dedicated `system` argument; guard against an empty message list
        if self.use_system_prompt and messages and messages[0]["role"] == "system":
            kwargs["messages"] = messages[1:]
            kwargs["system"] = messages[0]["content"]  # set system prompt here
        return kwargs

    def _update_costs(self, usage: dict):
        """update each request's token cost"""
        model_or_endpoint = self.config.model or self.config.endpoint
        # only price the request locally when a cost entry exists for this model/endpoint
        local_calc_usage = model_or_endpoint in self.token_costs
        super()._update_costs(usage, model_or_endpoint, local_calc_usage)

    def get_choice_text(self, resp: JsonBody) -> str:
        """Return the ``"result"`` field of a response body, or ``""`` if absent."""
        return resp.get("result", "")

    def completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
        """Run a blocking, non-streaming chat request and return the response body.

        Args:
            messages: chat messages to send.
            timeout: request timeout; resolved through ``get_timeout`` before being
                passed to the client as ``request_timeout``.
        """
        resp = self.aclient.do(
            **self._const_kwargs(messages=messages, stream=False),
            request_timeout=self.get_timeout(timeout),
        )
        self._update_costs(resp.body.get("usage", {}))
        return resp.body

    async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
        """Run an async, non-streaming chat request and return the response body.

        The timeout is resolved through ``get_timeout`` and forwarded to the client
        as ``request_timeout`` rather than being dropped.
        """
        resp = await self.aclient.ado(
            **self._const_kwargs(messages=messages, stream=False),
            request_timeout=self.get_timeout(timeout),
        )
        self._update_costs(resp.body.get("usage", {}))
        return resp.body

    async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
        """Async completion entry point; delegates to ``_achat_completion``."""
        return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
        """Run an async streaming chat request and return the concatenated text.

        Each chunk's ``"result"`` is logged via ``log_llm_stream`` as it arrives.
        Costs are updated once, after the stream ends, from the last non-empty
        ``"usage"`` seen in the chunks.
        """
        resp = await self.aclient.ado(
            **self._const_kwargs(messages=messages, stream=True),
            request_timeout=self.get_timeout(timeout),
        )
        collected_content = []
        usage = {}
        async for chunk in resp:
            content = chunk.body.get("result", "")
            # keep the latest usage that was actually reported, so a trailing chunk
            # without usage does not wipe out an earlier value
            chunk_usage = chunk.body.get("usage")
            if chunk_usage:
                usage = chunk_usage
            log_llm_stream(content)
            collected_content.append(content)
        log_llm_stream("\n")

        self._update_costs(usage)
        return "".join(collected_content)
133

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.