from typing import Any, Callable, Dict, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.llms.generic_utils import chat_to_completion_decorator
from llama_index.legacy.llms.openai_utils import (
    from_openai_message_dict,
    to_openai_message_dicts,
)
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode


class LlamaAPI(CustomLLM):
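    """Llama API LLM adapter built on the legacy ``CustomLLM`` interface.

    Chat messages are serialized to OpenAI-style dicts and sent through the
    ``llamaapi`` client; completions reuse the chat endpoint, and streaming
    is not supported.
    """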
    model: str = Field(description="The llama-api model to use.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_tokens: int = Field(description="The maximum number of tokens to generate.")
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the llama-api API."
    )

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "llama-13b-chat",
        temperature: float = 0.1,
        max_tokens: int = DEFAULT_NUM_OUTPUTS,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        try:
            # Deferred import: `llamaapi` is only required when this LLM is used.
            from llamaapi import LlamaAPI as Client
        except ImportError as e:
            raise ImportError(
                "llamaapi not installed. "
                "Please install it with `pip install llamaapi`."
            ) from e

        self._client = Client(api_key)

        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs or {},
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "llama_api_llm"

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        base_kwargs = {
            "model": self.model,
            "temperature": self.temperature,
            # The Llama API payload uses "max_length" rather than "max_tokens".
            "max_length": self.max_tokens,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=4096,
            num_output=DEFAULT_NUM_OUTPUTS,
            is_chat_model=True,
            is_function_calling_model=True,
            model_name="llama-api",
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        # Serialize the chat history into OpenAI-style message dicts, then
        # merge in the model kwargs and any per-call overrides.
        message_dicts = to_openai_message_dicts(messages)
        json_dict = {
            "messages": message_dicts,
            **self._model_kwargs,
            **kwargs,
        }
        response = self._client.run(json_dict).json()
        message_dict = response["choices"][0]["message"]
        message = from_openai_message_dict(message_dict)

        return ChatResponse(message=message, raw=response)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        # Completions are routed through the chat endpoint: the decorator wraps
        # the prompt as a user message and unwraps the reply as text.
        complete_fn = chat_to_completion_decorator(self.chat)
        return complete_fn(prompt, **kwargs)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise NotImplementedError("stream_complete is not supported for LlamaAPI")

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise NotImplementedError("stream_chat is not supported for LlamaAPI")

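# Usage sketch: assumes a valid Llama API key; the key and prompt values
# below are placeholders.
#
#   from llama_index.legacy.core.llms.types import ChatMessage
#
#   llm = LlamaAPI(api_key="<llama-api-key>", model="llama-13b-chat")
#   chat_response = llm.chat([ChatMessage(role="user", content="Hello!")])
#   print(chat_response.message.content)
#
#   # complete() is serviced by the same chat endpoint under the hood:
#   print(llm.complete("Paris is the capital of").text)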