# llama-index (legacy) — DashScope LLM integration
# (source listing metadata: 315 lines · 11.0 KB)
1"""DashScope llm api."""
2
3from http import HTTPStatus4from typing import Any, Dict, List, Optional, Sequence, Tuple5
6from llama_index.legacy.bridge.pydantic import Field7from llama_index.legacy.callbacks import CallbackManager8from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE9from llama_index.legacy.core.llms.types import (10ChatMessage,11ChatResponse,12ChatResponseGen,13CompletionResponse,14CompletionResponseGen,15LLMMetadata,16MessageRole,17)
18from llama_index.legacy.llms.base import (19llm_chat_callback,20llm_completion_callback,21)
22from llama_index.legacy.llms.custom import CustomLLM23from llama_index.legacy.llms.dashscope_utils import (24chat_message_to_dashscope_messages,25dashscope_response_to_chat_response,26dashscope_response_to_completion_response,27)
28
29
class DashScopeGenerationModels:
    """DashScope Qwen serial models."""

    # Model-name constants accepted by the DashScope Generation API.
    # Context-window / output limits for each are listed in
    # DASHSCOPE_MODEL_META in this module.
    QWEN_TURBO = "qwen-turbo"
    QWEN_PLUS = "qwen-plus"
    QWEN_MAX = "qwen-max"
    QWEN_MAX_1201 = "qwen-max-1201"
    QWEN_MAX_LONGCONTEXT = "qwen-max-longcontext"
39
# Per-model capability table: context window and max output size (in tokens),
# plus the chat-model flag. All Qwen generation models are chat models.
DASHSCOPE_MODEL_META = {
    model: {
        "context_window": 1024 * window_kib,
        "num_output": 1024 * window_kib,
        "is_chat_model": True,
    }
    for model, window_kib in [
        (DashScopeGenerationModels.QWEN_TURBO, 8),
        (DashScopeGenerationModels.QWEN_PLUS, 32),
        (DashScopeGenerationModels.QWEN_MAX, 8),
        (DashScopeGenerationModels.QWEN_MAX_1201, 8),
        (DashScopeGenerationModels.QWEN_MAX_LONGCONTEXT, 30),
    ]
}
67
68
def call_with_messages(
    model: str,
    messages: List[Dict],
    parameters: Optional[Dict] = None,
    api_key: Optional[str] = None,
    **kwargs: Any,
) -> Dict:
    """Invoke the DashScope ``Generation.call`` API.

    Args:
        model: DashScope model name (e.g. ``"qwen-max"``).
        messages: Chat messages in DashScope message format.
        parameters: Generation parameters forwarded to the API.
        api_key: DashScope API key; the SDK falls back to its own
            configuration (environment) when None.
        **kwargs: Extra keyword arguments forwarded to ``Generation.call``.

    Returns:
        The raw DashScope response object.

    Raises:
        ValueError: If the ``dashscope`` package is not installed.
    """
    try:
        from dashscope import Generation
    except ImportError:
        raise ValueError(
            "DashScope is not installed. Please install it with "
            "`pip install dashscope`."
        )
    # Guard against parameters=None: the original unpacked it directly with
    # `**parameters`, raising TypeError whenever no parameters were supplied.
    parameters = parameters or {}
    # Forward **kwargs too; the original accepted them but silently dropped them.
    return Generation.call(
        model=model, messages=messages, api_key=api_key, **kwargs, **parameters
    )
87
class DashScope(CustomLLM):
    """DashScope LLM.

    llama-index wrapper around the DashScope ``Generation`` API (Qwen-series
    models), supporting synchronous and streaming completion and chat.
    """

    model_name: str = Field(
        default=DashScopeGenerationModels.QWEN_MAX,
        description="The DashScope model to use.",
    )
    max_tokens: Optional[int] = Field(
        description="The maximum number of tokens to generate.",
        default=DEFAULT_NUM_OUTPUTS,
        gt=0,
    )
    incremental_output: Optional[bool] = Field(
        description="Control stream output, If False, the subsequent "
        "output will include the content that has been "
        "output previously.",
        default=True,
    )
    enable_search: Optional[bool] = Field(
        description="The model has a built-in Internet search service. "
        "This parameter controls whether the model refers to "
        "the Internet search results when generating text.",
        default=False,
    )
    stop: Optional[Any] = Field(
        description="str, list of str or token_id, list of token id. It will "
        "automatically stop when the generated content is about to contain the "
        "specified string or token_ids, and the generated content does not "
        "contain the specified content.",
        default=None,
    )
    temperature: Optional[float] = Field(
        description="The temperature to use during generation.",
        default=DEFAULT_TEMPERATURE,
        # pydantic's constraint keywords are `ge`/`le`; the original used
        # `gte`/`lte`, which pydantic ignores, so the 0-2 range was unenforced.
        ge=0.0,
        le=2.0,
    )
    top_k: Optional[int] = Field(
        description="Sample counter when generate.", default=None
    )
    top_p: Optional[float] = Field(
        description="Sample probability threshold when generate.", default=None
    )
    seed: Optional[int] = Field(
        description="Random seed when generate.", default=1234, ge=0
    )
    repetition_penalty: Optional[float] = Field(
        description="Penalty for repeated words in generated text; "
        "1.0 is no penalty, values greater than 1 discourage "
        "repetition.",
        default=None,
    )
    api_key: str = Field(
        default=None, description="The DashScope API key.", exclude=True
    )

    def __init__(
        self,
        model_name: Optional[str] = DashScopeGenerationModels.QWEN_MAX,
        max_tokens: Optional[int] = DEFAULT_NUM_OUTPUTS,
        incremental_output: Optional[bool] = True,
        enable_search: Optional[bool] = False,
        stop: Optional[Any] = None,
        temperature: Optional[float] = DEFAULT_TEMPERATURE,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = 1234,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        super().__init__(
            model_name=model_name,
            max_tokens=max_tokens,
            incremental_output=incremental_output,
            enable_search=enable_search,
            stop=stop,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            seed=seed,
            api_key=api_key,
            callback_manager=callback_manager,
            # Forward extras directly; the original passed `kwargs=kwargs`,
            # which nested them under a bogus `kwargs` field.
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Class identifier used by llama-index serialization."""
        return "DashScope_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata for the configured model.

        ``num_output`` reflects ``max_tokens`` when set, otherwise the
        model's default from ``DASHSCOPE_MODEL_META``.
        """
        # Work on a copy: the original mutated the shared module-level
        # DASHSCOPE_MODEL_META dict, leaking per-instance settings globally.
        meta = dict(DASHSCOPE_MODEL_META[self.model_name])
        if self.max_tokens is not None:
            meta["num_output"] = self.max_tokens
        return LLMMetadata(model_name=self.model_name, **meta)

    def _get_default_parameters(self) -> Dict:
        """Collect the generation parameters configured on this instance,
        omitting the ones left unset (None)."""
        params: Dict[Any, Any] = {}
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        params["incremental_output"] = self.incremental_output
        params["enable_search"] = self.enable_search
        if self.stop is not None:
            params["stop"] = self.stop
        if self.temperature is not None:
            params["temperature"] = self.temperature
        if self.top_k is not None:
            params["top_k"] = self.top_k
        if self.top_p is not None:
            params["top_p"] = self.top_p
        if self.seed is not None:
            params["seed"] = self.seed
        return params

    def _get_input_parameters(
        self, prompt: str, **kwargs: Any
    ) -> Tuple[ChatMessage, Dict]:
        """Build the user message and call parameters for a completion prompt."""
        parameters = self._get_default_parameters()
        parameters.update(kwargs)
        parameters["stream"] = False
        # We only consume the "message" response format.
        parameters["result_format"] = "message"
        message = ChatMessage(
            role=MessageRole.USER.value,
            content=prompt,
        )
        return message, parameters

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Run a single (non-streaming) completion for ``prompt``."""
        message, parameters = self._get_input_parameters(prompt=prompt, **kwargs)
        # Streaming-only options must not be sent on the blocking call.
        parameters.pop("incremental_output", None)
        parameters.pop("stream", None)
        messages = chat_message_to_dashscope_messages([message])
        response = call_with_messages(
            model=self.model_name,
            messages=messages,
            api_key=self.api_key,
            parameters=parameters,
        )
        return dashscope_response_to_completion_response(response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        """Stream a completion for ``prompt``, yielding cumulative text."""
        # Fixed from `_get_input_parameters(prompt=prompt, kwargs=kwargs)`,
        # which smuggled a literal "kwargs" key into the API parameters.
        message, parameters = self._get_input_parameters(prompt=prompt, **kwargs)
        parameters["incremental_output"] = True
        parameters["stream"] = True
        responses = call_with_messages(
            model=self.model_name,
            messages=chat_message_to_dashscope_messages([message]),
            api_key=self.api_key,
            parameters=parameters,
        )

        def gen() -> CompletionResponseGen:
            content = ""
            for response in responses:
                if response.status_code == HTTPStatus.OK:
                    top_choice = response.output.choices[0]
                    # DashScope may send a null/empty delta chunk.
                    incremental_output = top_choice["message"]["content"] or ""
                    content += incremental_output
                    yield CompletionResponse(
                        text=content, delta=incremental_output, raw=response
                    )
                else:
                    # Surface the error payload to the caller and stop.
                    yield CompletionResponse(text="", raw=response)
                    return

        return gen()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Run a single (non-streaming) chat turn over ``messages``."""
        parameters = self._get_default_parameters()
        parameters.update(kwargs)
        parameters.pop("stream", None)
        parameters.pop("incremental_output", None)
        parameters["result_format"] = "message"  # only use message format.
        response = call_with_messages(
            model=self.model_name,
            messages=chat_message_to_dashscope_messages(messages),
            api_key=self.api_key,
            parameters=parameters,
        )
        return dashscope_response_to_chat_response(response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat turn over ``messages``, yielding cumulative content."""
        parameters = self._get_default_parameters()
        parameters.update(kwargs)
        parameters["stream"] = True
        parameters["incremental_output"] = True
        parameters["result_format"] = "message"  # only use message format.
        response = call_with_messages(
            model=self.model_name,
            messages=chat_message_to_dashscope_messages(messages),
            api_key=self.api_key,
            parameters=parameters,
        )

        def gen() -> ChatResponseGen:
            content = ""
            for r in response:
                if r.status_code == HTTPStatus.OK:
                    top_choice = r.output.choices[0]
                    incremental_output = top_choice["message"]["content"]
                    role = top_choice["message"]["role"]
                    content += incremental_output
                    yield ChatResponse(
                        message=ChatMessage(role=role, content=content),
                        delta=incremental_output,
                        raw=r,
                    )
                else:
                    # Was `raw=response` (the whole generator); yield the
                    # failing chunk itself so callers can inspect the error.
                    yield ChatResponse(message=ChatMessage(), raw=r)
                    return

        return gen()