llama-index

"""DashScope LLM API."""

from http import HTTPStatus
from typing import Any, Dict, List, Optional, Sequence, Tuple

from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.llms.dashscope_utils import (
    chat_message_to_dashscope_messages,
    dashscope_response_to_chat_response,
    dashscope_response_to_completion_response,
)


class DashScopeGenerationModels:
    """DashScope Qwen series models."""

    QWEN_TURBO = "qwen-turbo"
    QWEN_PLUS = "qwen-plus"
    QWEN_MAX = "qwen-max"
    QWEN_MAX_1201 = "qwen-max-1201"
    QWEN_MAX_LONGCONTEXT = "qwen-max-longcontext"


DASHSCOPE_MODEL_META = {
    DashScopeGenerationModels.QWEN_TURBO: {
        "context_window": 1024 * 8,
        "num_output": 1024 * 8,
        "is_chat_model": True,
    },
    DashScopeGenerationModels.QWEN_PLUS: {
        "context_window": 1024 * 32,
        "num_output": 1024 * 32,
        "is_chat_model": True,
    },
    DashScopeGenerationModels.QWEN_MAX: {
        "context_window": 1024 * 8,
        "num_output": 1024 * 8,
        "is_chat_model": True,
    },
    DashScopeGenerationModels.QWEN_MAX_1201: {
        "context_window": 1024 * 8,
        "num_output": 1024 * 8,
        "is_chat_model": True,
    },
    DashScopeGenerationModels.QWEN_MAX_LONGCONTEXT: {
        "context_window": 1024 * 30,
        "num_output": 1024 * 30,
        "is_chat_model": True,
    },
}
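# Example lookup against the table above; the `metadata` property below copies
# these entries into LLMMetadata and overrides "num_output" with `max_tokens`
# when one is set:
#     DASHSCOPE_MODEL_META[DashScopeGenerationModels.QWEN_PLUS]["context_window"]
#     == 1024 * 32  # 32768 tokens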


def call_with_messages(
    model: str,
    messages: List[Dict],
    parameters: Optional[Dict] = None,
    api_key: Optional[str] = None,
    **kwargs: Any,
) -> Dict:
    try:
        from dashscope import Generation
    except ImportError:
        raise ValueError(
            "DashScope is not installed. Please install it with "
            "`pip install dashscope`."
        )
    return Generation.call(
        model=model, messages=messages, api_key=api_key, **(parameters or {})
    )
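# Sketch of the expected inputs (illustrative values only):
#     messages = [{"role": "user", "content": "Hello"}]
#     parameters = {"temperature": 0.7, "result_format": "message"}
# Both are produced by the DashScope class below (via
# chat_message_to_dashscope_messages and _get_default_parameters) and are
# forwarded unchanged to dashscope.Generation.call.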


class DashScope(CustomLLM):
    """DashScope LLM."""

    model_name: str = Field(
        default=DashScopeGenerationModels.QWEN_MAX,
        description="The DashScope model to use.",
    )
    max_tokens: Optional[int] = Field(
        description="The maximum number of tokens to generate.",
        default=DEFAULT_NUM_OUTPUTS,
        gt=0,
    )
    incremental_output: Optional[bool] = Field(
        description="Controls streaming output. If False, each subsequent chunk "
        "also includes the content that has already been output.",
        default=True,
    )
    enable_search: Optional[bool] = Field(
        description="The model has a built-in Internet search service. This "
        "parameter controls whether the model refers to Internet search results "
        "when generating text.",
        default=False,
    )
    stop: Optional[Any] = Field(
        description="A str, list of str, token id, or list of token ids. "
        "Generation stops automatically when the output is about to contain the "
        "specified strings or token ids, so the returned content does not "
        "include them.",
        default=None,
    )
    temperature: Optional[float] = Field(
        description="The temperature to use during generation.",
        default=DEFAULT_TEMPERATURE,
        ge=0.0,
        le=2.0,
    )
    top_k: Optional[int] = Field(
        description="Top-k sampling size during generation.", default=None
    )
    top_p: Optional[float] = Field(
        description="Top-p (nucleus) sampling probability threshold during generation.",
        default=None,
    )
    seed: Optional[int] = Field(
        description="Random seed used during generation.", default=1234, ge=0
    )
    repetition_penalty: Optional[float] = Field(
        description="Penalty for repeated words in generated text; 1.0 means no "
        "penalty, and values greater than 1 discourage repetition.",
        default=None,
    )
    api_key: str = Field(
        default=None, description="The DashScope API key.", exclude=True
    )

    def __init__(
        self,
        model_name: Optional[str] = DashScopeGenerationModels.QWEN_MAX,
        max_tokens: Optional[int] = DEFAULT_NUM_OUTPUTS,
        incremental_output: Optional[bool] = True,
        enable_search: Optional[bool] = False,
        stop: Optional[Any] = None,
        temperature: Optional[float] = DEFAULT_TEMPERATURE,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = 1234,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        super().__init__(
            model_name=model_name,
            max_tokens=max_tokens,
            incremental_output=incremental_output,
            enable_search=enable_search,
            stop=stop,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            seed=seed,
            api_key=api_key,
            callback_manager=callback_manager,
            kwargs=kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "DashScope_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        # Prefer max_tokens when set; note this updates the module-level
        # DASHSCOPE_MODEL_META entry in place.
        DASHSCOPE_MODEL_META[self.model_name]["num_output"] = (
            self.max_tokens or DASHSCOPE_MODEL_META[self.model_name]["num_output"]
        )
        return LLMMetadata(
            model_name=self.model_name, **DASHSCOPE_MODEL_META[self.model_name]
        )

    def _get_default_parameters(self) -> Dict:
        params: Dict[Any, Any] = {}
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        params["incremental_output"] = self.incremental_output
        params["enable_search"] = self.enable_search
        if self.stop is not None:
            params["stop"] = self.stop
        if self.temperature is not None:
            params["temperature"] = self.temperature

        if self.top_k is not None:
            params["top_k"] = self.top_k

        if self.top_p is not None:
            params["top_p"] = self.top_p
        if self.seed is not None:
            params["seed"] = self.seed

        return params

    def _get_input_parameters(
        self, prompt: str, **kwargs: Any
    ) -> Tuple[ChatMessage, Dict]:
        parameters = self._get_default_parameters()
        parameters.update(kwargs)
        parameters["stream"] = False
        # we only use message response
        parameters["result_format"] = "message"
        message = ChatMessage(
            role=MessageRole.USER.value,
            content=prompt,
        )
        return message, parameters
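    # For a plain prompt this returns (ChatMessage(role="user", ...), parameters),
    # where parameters holds the settings from _get_default_parameters() plus
    # stream=False and result_format="message".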

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        message, parameters = self._get_input_parameters(prompt=prompt, **kwargs)
        parameters.pop("incremental_output", None)
        parameters.pop("stream", None)
        messages = chat_message_to_dashscope_messages([message])
        response = call_with_messages(
            model=self.model_name,
            messages=messages,
            api_key=self.api_key,
            parameters=parameters,
        )
        return dashscope_response_to_completion_response(response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        message, parameters = self._get_input_parameters(prompt=prompt, **kwargs)
        parameters["incremental_output"] = True
        parameters["stream"] = True
        responses = call_with_messages(
            model=self.model_name,
            messages=chat_message_to_dashscope_messages([message]),
            api_key=self.api_key,
            parameters=parameters,
        )

        def gen() -> CompletionResponseGen:
            content = ""
            for response in responses:
                if response.status_code == HTTPStatus.OK:
                    top_choice = response.output.choices[0]
                    incremental_output = top_choice["message"]["content"]
                    if not incremental_output:
                        incremental_output = ""

                    content += incremental_output
                    yield CompletionResponse(
                        text=content, delta=incremental_output, raw=response
                    )
                else:
                    yield CompletionResponse(text="", raw=response)
                    return

        return gen()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        parameters = self._get_default_parameters()
        parameters.update({**kwargs})
        parameters.pop("stream", None)
        parameters.pop("incremental_output", None)
        parameters["result_format"] = "message"  # only use message format.
        response = call_with_messages(
            model=self.model_name,
            messages=chat_message_to_dashscope_messages(messages),
            api_key=self.api_key,
            parameters=parameters,
        )
        return dashscope_response_to_chat_response(response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        parameters = self._get_default_parameters()
        parameters.update({**kwargs})
        parameters["stream"] = True
        parameters["incremental_output"] = True
        parameters["result_format"] = "message"  # only use message format.
        response = call_with_messages(
            model=self.model_name,
            messages=chat_message_to_dashscope_messages(messages),
            api_key=self.api_key,
            parameters=parameters,
        )

        def gen() -> ChatResponseGen:
            content = ""
            for r in response:
                if r.status_code == HTTPStatus.OK:
                    top_choice = r.output.choices[0]
                    incremental_output = top_choice["message"]["content"]
                    role = top_choice["message"]["role"]
                    content += incremental_output
                    yield ChatResponse(
                        message=ChatMessage(role=role, content=content),
                        delta=incremental_output,
                        raw=r,
                    )
                else:
                    yield ChatResponse(message=ChatMessage(), raw=r)
                    return

        return gen()
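

# Illustrative usage sketch. It assumes `pip install dashscope` and a valid
# DashScope API key (passed below or otherwise configured for the dashscope SDK);
# the prompt strings and temperature are placeholder values.
if __name__ == "__main__":
    llm = DashScope(
        model_name=DashScopeGenerationModels.QWEN_MAX,
        api_key="sk-...",  # placeholder; supply a real key
        temperature=0.7,
    )

    # One-shot completion.
    print(llm.complete("Write a haiku about the sea.").text)

    # Streaming completion: each chunk carries the accumulated text plus the delta.
    for chunk in llm.stream_complete("Count from 1 to 5."):
        print(chunk.delta, end="", flush=True)

    # Chat with explicit message roles.
    chat_response = llm.chat(
        [ChatMessage(role=MessageRole.USER, content="Who are you?")]
    )
    print(chat_response.message.content)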