llama-index
128 lines · 4.5 KB
1from typing import Any, Callable, Dict, Optional, Sequence2
3from llama_index.legacy.bridge.pydantic import Field, PrivateAttr4from llama_index.legacy.callbacks import CallbackManager5from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS6from llama_index.legacy.core.llms.types import (7ChatMessage,8ChatResponse,9ChatResponseGen,10CompletionResponse,11CompletionResponseGen,12LLMMetadata,13)
14from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback15from llama_index.legacy.llms.custom import CustomLLM16from llama_index.legacy.llms.generic_utils import chat_to_completion_decorator17from llama_index.legacy.llms.openai_utils import (18from_openai_message_dict,19to_openai_message_dicts,20)
21from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode22
23
class LlamaAPI(CustomLLM):
    """LLM integration for the hosted llama-api service (https://www.llama-api.com/).

    Wraps the `llamaapi` client behind the llama-index `CustomLLM` interface.
    Chat is the native mode (`is_chat_model=True`); completions are adapted
    onto chat via `chat_to_completion_decorator`. Streaming is not supported
    by this backend and both streaming endpoints raise `NotImplementedError`.
    """

    model: str = Field(description="The llama-api model to use.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_tokens: int = Field(description="The maximum number of tokens to generate.")
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the llama-api API."
    )

    # Lazily-imported `llamaapi.LlamaAPI` client instance (typed `Any` so the
    # optional dependency is not required at import time).
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "llama-13b-chat",
        temperature: float = 0.1,
        max_tokens: int = DEFAULT_NUM_OUTPUTS,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Create a LlamaAPI LLM.

        Raises:
            ImportError: if the optional `llamaapi` package is not installed.
        """
        # Import lazily so the package is only required when this LLM is used.
        try:
            from llamaapi import LlamaAPI as Client
        except ImportError as e:
            # NOTE: fixed missing space between the two concatenated literals
            # (previously rendered as "...installed.Please install...").
            raise ImportError(
                "llama_api not installed. "
                "Please install it with `pip install llamaapi`."
            ) from e

        # Private attr must be set before pydantic's __init__ validates fields.
        self._client = Client(api_key)

        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs or {},
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the registry name for this LLM class."""
        return "llama_api_llm"

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Model parameters sent with every request; `additional_kwargs` wins on conflict."""
        base_kwargs = {
            "model": self.model,
            "temperature": self.temperature,
            # llama-api uses "max_length" for what llama-index calls max_tokens.
            "max_length": self.max_tokens,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    @property
    def metadata(self) -> LLMMetadata:
        """Static capability metadata for the llama-api backend."""
        return LLMMetadata(
            context_window=4096,
            num_output=DEFAULT_NUM_OUTPUTS,
            is_chat_model=True,
            is_function_calling_model=True,
            model_name="llama-api",
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Send a chat request; per-call kwargs override the model kwargs."""
        message_dicts = to_openai_message_dicts(messages)
        json_dict = {
            "messages": message_dicts,
            **self._model_kwargs,
            **kwargs,
        }
        response = self._client.run(json_dict).json()
        # Response follows the OpenAI chat schema; take the first choice.
        message_dict = response["choices"][0]["message"]
        message = from_openai_message_dict(message_dict)

        return ChatResponse(message=message, raw=response)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Complete a prompt by adapting it onto the chat endpoint.

        NOTE(review): `formatted` is accepted for interface compatibility but
        is not forwarded — the chat adapter always wraps the raw prompt.
        """
        complete_fn = chat_to_completion_decorator(self.chat)
        return complete_fn(prompt, **kwargs)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming completions are not supported by llama-api."""
        raise NotImplementedError("stream_complete is not supported for LlamaAPI")

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Streaming chat is not supported by llama-api."""
        raise NotImplementedError("stream_chat is not supported for LlamaAPI")