llama-index

Форк
0
188 строк · 6.7 Кб
1
from typing import Any, Callable, Dict, Optional, Sequence
2

3
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
4
from llama_index.legacy.callbacks import CallbackManager
5
from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
6
from llama_index.legacy.core.llms.types import (
7
    ChatMessage,
8
    ChatResponse,
9
    CompletionResponse,
10
    CompletionResponseGen,
11
    LLMMetadata,
12
)
13
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
14
from llama_index.legacy.llms.custom import CustomLLM
15
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
16

17
# Sampling temperature MonsterLLM uses when the caller does not supply one.
DEFAULT_MONSTER_TEMP: float = 0.75
18

19

20
class MonsterLLM(CustomLLM):
    """LLM backed by the MonsterAPI ``monsterapi`` client.

    Completions are sent through the client's ``generate`` call. Chat is
    implemented by flattening the messages into one prompt with
    ``messages_to_prompt`` and running a completion on it. Streaming is only
    allowed for models whose name contains ``"deploy"``.
    """

    model: str = Field(description="The MonsterAPI model to use.")
    monster_api_key: Optional[str] = Field(description="The MonsterAPI key to use.")
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    temperature: float = Field(
        default=DEFAULT_MONSTER_TEMP,
        description="The temperature to use for sampling.",
        # pydantic's inclusive-bound keywords are ``ge``/``le``; the former
        # ``gte``/``lte`` were unknown extras and the range was never enforced.
        ge=0.0,
        le=1.0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The number of context tokens available to the LLM.",
        gt=0,
    )

    # Instance of ``monsterapi.client`` created in ``__init__``.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str,
        base_url: str = "https://api.monsterapi.ai/v1",
        monster_api_key: Optional[str] = None,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        temperature: float = DEFAULT_MONSTER_TEMP,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Create the MonsterAPI client and validate ``model`` against it.

        Args:
            model: MonsterAPI model name; must be one the installed client
                lists as an ``"LLM"`` model.
            base_url: API base URL handed to the MonsterAPI client.
            monster_api_key: API key handed to the MonsterAPI client. When
                None, presumably the client falls back to its own
                configuration — TODO confirm against monsterapi docs.
            max_new_tokens: Generation length, sent to the API as ``max_length``.
            temperature: Sampling temperature in [0, 1].
            context_window: Context size reported through ``metadata``.
            callback_manager, system_prompt, messages_to_prompt,
            completion_to_prompt, pydantic_program_mode, output_parser:
                Forwarded unchanged to the ``CustomLLM`` base class.

        Raises:
            ImportError: If the ``monsterapi`` package is not installed.
            RuntimeError: If ``model`` is not among the client's LLM models.
        """
        # The client is built first so that an unsupported model fails fast,
        # before the base-class (pydantic) initialisation runs.
        self._client, available_llms = self.initialize_client(monster_api_key, base_url)

        # Check if provided model is supported
        if model not in available_llms:
            error_message = (
                f"Model: {model} is not supported. "
                f"Supported models are {available_llms}. "
                "Please update monsterapiclient to see if any models are added. "
                "pip install --upgrade monsterapi"
            )
            raise RuntimeError(error_message)

        super().__init__(
            model=model,
            monster_api_key=monster_api_key,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            context_window=context_window,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    def initialize_client(
        self, monster_api_key: Optional[str], base_url: Optional[str]
    ) -> Any:
        """Build a MonsterAPI client and list the models it tags as ``"LLM"``.

        Returns:
            A ``(client, llm_model_names)`` tuple, where ``llm_model_names`` is
            the list of ``MODEL_TYPES`` keys whose type is ``"LLM"``.

        Raises:
            ImportError: If the ``monsterapi`` package cannot be imported.
        """
        try:
            from monsterapi import client as MonsterClient
            from monsterapi.InputDataModels import MODEL_TYPES
        except ImportError as err:
            raise ImportError(
                "Could not import Monster API client library. "
                "Please install it with `pip install monsterapi`"
            ) from err

        llm_models_enabled = [
            name for name, model_type in MODEL_TYPES.items() if model_type == "LLM"
        ]

        return MonsterClient(monster_api_key, base_url), llm_models_enabled

    @classmethod
    def class_name(cls) -> str:
        return "MonsterLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model,
        )

    def _get_input_dict(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Build the request payload; caller ``kwargs`` override the defaults."""
        return {
            "prompt": prompt,
            "temperature": self.temperature,
            "max_length": self.max_new_tokens,
            **kwargs,
        }

    @staticmethod
    def _raise_for_error(result: Any) -> None:
        """Raise if the client returned an exception object or an error payload.

        Raises:
            Exception: ``result`` itself, when it is an exception instance.
            RuntimeError: When ``result`` is a dict with an ``"error"`` key.
        """
        if isinstance(result, Exception):
            raise result
        if isinstance(result, dict) and "error" in result:
            raise RuntimeError(result["error"])

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Run a chat turn as a single completion over the flattened messages.

        Returns:
            A ``ChatResponse`` whose assistant message contains the completion
            text (previously a bare ``CompletionResponse`` was returned,
            contradicting the annotation and breaking ``response.message``).
        """
        from llama_index.legacy.core.llms.types import MessageRole

        prompt = self.messages_to_prompt(messages)
        completion = self.complete(prompt, formatted=True, **kwargs)
        return ChatResponse(
            message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text),
            raw=completion.raw,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, timeout: int = 100, **kwargs: Any
    ) -> CompletionResponse:
        """Generate a completion for ``prompt`` (non-streaming).

        Args:
            prompt: Prompt text.
            formatted: If False, ``completion_to_prompt`` is applied first.
            timeout: Passed to the client's ``generate`` call.
            **kwargs: Extra request fields; they override the defaults built
                by ``_get_input_dict``.

        Raises:
            NotImplementedError: If ``stream=True`` is passed.
            RuntimeError: If the API reports an error or the response shape is
                not recognised.
        """
        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        stream = kwargs.pop("stream", False)
        if stream is True:
            raise NotImplementedError(
                "complete method cannot be used with stream=True, please use stream_complete method"
            )

        input_dict = self._get_input_dict(prompt, **kwargs)

        result = self._client.generate(
            model=self.model, data=input_dict, timeout=timeout
        )
        self._raise_for_error(result)

        # The client may return {"text": str}, {"text": [str, ...]} or
        # [{"text": str}, ...]; empty containers fall through to the error.
        if isinstance(result, dict) and "text" in result:
            text = result["text"]
            if isinstance(text, list) and text:
                return CompletionResponse(text=text[0])
            if isinstance(text, str):
                return CompletionResponse(text=text)

        if isinstance(result, list) and result:
            return CompletionResponse(text=result[0]["text"])

        raise RuntimeError("Unexpected Return please contact monsterapi support!")

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream a completion for ``prompt`` (``deploy`` models only).

        This is a generator, so every check below (including the
        NotImplementedError) runs on first iteration, not at call time.

        Args:
            prompt: Prompt text.
            formatted: If False, ``completion_to_prompt`` is applied first.
                Accepting it here also keeps a ``formatted=True`` passed by
                chat helpers out of the request payload.
            **kwargs: Extra request fields merged into the payload.

        Raises:
            NotImplementedError: If the model name does not contain ``"deploy"``.
            RuntimeError: If the API reports an error.
        """
        if "deploy" not in self.model:
            raise NotImplementedError(
                "stream_complete method can only be used with deploy models for now. Support for other models will be added soon."
            )

        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        input_dict = self._get_input_dict(prompt, **kwargs)
        input_dict["stream"] = True

        # Starting the stream
        result_stream = self._client.generate(model=self.model, data=input_dict)
        self._raise_for_error(result_stream)

        for result in result_stream:
            # NOTE(review): each streamed item is indexed with [0]; whether that
            # chunk is an incremental delta or cumulative text depends on the
            # monsterapi client — confirm, and populate ``delta`` accordingly.
            yield CompletionResponse(text=result[0])
189

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.