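"""Gradient LLM integrations for LlamaIndex (legacy).

Wraps the `gradientai` client so that Gradient base models and fine-tuned
model adapters can be used as LlamaIndex `CustomLLM` implementations.
"""
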
from typing import Any, Callable, Optional, Sequence

from typing_extensions import override

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import llm_completion_callback
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode


class _BaseGradientLLM(CustomLLM):
    """Shared plumbing for Gradient-backed LLMs: holds the Gradient client
    and the resolved model, and exposes the common configuration fields."""

    _gradient = PrivateAttr()
    _model = PrivateAttr()

    # Config
    max_tokens: Optional[int] = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
        lt=512,
    )

    # Gradient client config
    access_token: Optional[str] = Field(
        description="The Gradient access token to use.",
    )
    host: Optional[str] = Field(
        description="The url of the Gradient service to access."
    )
    workspace_id: Optional[str] = Field(
        description="The Gradient workspace id to use.",
    )
    is_chat_model: bool = Field(
        default=False, description="Whether the model is a chat model."
    )

    def __init__(
        self,
        *,
        access_token: Optional[str] = None,
        host: Optional[str] = None,
        max_tokens: Optional[int] = None,
        workspace_id: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        is_chat_model: bool = False,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            max_tokens=max_tokens,
            access_token=access_token,
            host=host,
            workspace_id=workspace_id,
            callback_manager=callback_manager,
            is_chat_model=is_chat_model,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            **kwargs,
        )
        try:
            from gradientai import Gradient

            self._gradient = Gradient(
                access_token=access_token, host=host, workspace_id=workspace_id
            )
        except ImportError as e:
            raise ImportError(
                "Could not import Gradient Python package. "
                "Please install it with `pip install gradientai`."
            ) from e

    def close(self) -> None:
        self._gradient.close()

    @llm_completion_callback()
    @override
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return CompletionResponse(
            text=self._model.complete(
                query=prompt,
                max_generated_token_count=self.max_tokens,
                **kwargs,
            ).generated_output
        )

    @llm_completion_callback()
    @override
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        grdt_response = await self._model.acomplete(
            query=prompt,
            max_generated_token_count=self.max_tokens,
            **kwargs,
        )

        return CompletionResponse(text=grdt_response.generated_output)

    @override
    def stream_complete(
        self,
        prompt: str,
        formatted: bool = False,
        **kwargs: Any,
    ) -> CompletionResponseGen:
        raise NotImplementedError

    @property
    @override
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=1024,
            num_output=self.max_tokens or 20,
            is_chat_model=self.is_chat_model,
            is_function_calling_model=False,
            model_name=self._model.id,
        )


class GradientBaseModelLLM(_BaseGradientLLM):
    base_model_slug: str = Field(
        description="The slug of the base model to use.",
    )

    def __init__(
        self,
        *,
        access_token: Optional[str] = None,
        base_model_slug: str,
        host: Optional[str] = None,
        max_tokens: Optional[int] = None,
        workspace_id: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        is_chat_model: bool = False,
    ) -> None:
        super().__init__(
            access_token=access_token,
            base_model_slug=base_model_slug,
            host=host,
            max_tokens=max_tokens,
            workspace_id=workspace_id,
            callback_manager=callback_manager,
            is_chat_model=is_chat_model,
        )

        self._model = self._gradient.get_base_model(
            base_model_slug=base_model_slug,
        )

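# A minimal usage sketch (assumptions: credentials are picked up from the
# GRADIENT_ACCESS_TOKEN / GRADIENT_WORKSPACE_ID environment variables when not
# passed explicitly, and "llama2-7b-chat" is an illustrative slug, not a
# guaranteed catalog entry):
#
#     llm = GradientBaseModelLLM(base_model_slug="llama2-7b-chat", max_tokens=400)
#     print(llm.complete("What is the capital of France?").text)
#     llm.close()
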
class GradientModelAdapterLLM(_BaseGradientLLM):
    model_adapter_id: str = Field(
        description="The id of the model adapter to use.",
    )

    def __init__(
        self,
        *,
        access_token: Optional[str] = None,
        host: Optional[str] = None,
        max_tokens: Optional[int] = None,
        model_adapter_id: str,
        workspace_id: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        is_chat_model: bool = False,
    ) -> None:
        super().__init__(
            access_token=access_token,
            host=host,
            max_tokens=max_tokens,
            model_adapter_id=model_adapter_id,
            workspace_id=workspace_id,
            callback_manager=callback_manager,
            is_chat_model=is_chat_model,
        )
        self._model = self._gradient.get_model_adapter(
            model_adapter_id=model_adapter_id
        )

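# Likewise for a fine-tuned model adapter (the adapter id below is a
# placeholder, not a real id):
#
#     llm = GradientModelAdapterLLM(model_adapter_id="<your-adapter-id>")
#     print(llm.complete("Summarize the meeting notes.").text)
#     llm.close()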