from typing import Any, Callable, Dict, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.llms.vertex_gemini_utils import is_gemini_model
from llama_index.legacy.llms.vertex_utils import (
    CHAT_MODELS,
    CODE_CHAT_MODELS,
    CODE_MODELS,
    TEXT_MODELS,
    _parse_chat_history,
    _parse_examples,
    _parse_message,
    acompletion_with_retry,
    completion_with_retry,
    init_vertexai,
)
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode


class Vertex(LLM):
    """LLM wrapper around Google Vertex AI text, chat, code, and Gemini models."""

    model: str = Field(description="The Vertex model to use.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_tokens: int = Field(description="The maximum number of tokens to generate.")
    examples: Optional[Sequence[ChatMessage]] = Field(
        description="Example messages for the chat model."
    )
    max_retries: int = Field(default=10, description="The maximum number of retries.")

    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Vertex client."
    )
    iscode: bool = Field(
        default=False, description="Whether the current model is a code model."
    )
    _is_gemini: bool = PrivateAttr()
    _is_chat_model: bool = PrivateAttr()
    _client: Any = PrivateAttr()
    _chat_client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "text-bison",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[Any] = None,
        examples: Optional[Sequence[ChatMessage]] = None,
        temperature: float = 0.1,
        max_tokens: int = 512,
        max_retries: int = 10,
        iscode: bool = False,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        init_vertexai(project=project, location=location, credentials=credentials)

        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])

        # Select the appropriate Vertex AI client class based on the model name.
        self._is_gemini = False
        self._is_chat_model = False
        if model in CHAT_MODELS:
            from vertexai.language_models import ChatModel

            self._chat_client = ChatModel.from_pretrained(model)
            self._is_chat_model = True
        elif model in CODE_CHAT_MODELS:
            from vertexai.language_models import CodeChatModel

            self._chat_client = CodeChatModel.from_pretrained(model)
            iscode = True
            self._is_chat_model = True
        elif model in CODE_MODELS:
            from vertexai.language_models import CodeGenerationModel

            self._client = CodeGenerationModel.from_pretrained(model)
            iscode = True
        elif model in TEXT_MODELS:
            from vertexai.language_models import TextGenerationModel

            self._client = TextGenerationModel.from_pretrained(model)
        elif is_gemini_model(model):
            from llama_index.legacy.llms.vertex_gemini_utils import create_gemini_client

            self._client = create_gemini_client(model)
            self._chat_client = self._client
            self._is_gemini = True
            self._is_chat_model = True
        else:
            raise ValueError(f"Model {model} not found; please verify the model name.")

        super().__init__(
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            max_retries=max_retries,
            model=model,
            examples=examples,
            iscode=iscode,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "Vertex"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            is_chat_model=self._is_chat_model,
            model_name=self.model,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        base_kwargs = {
            "temperature": self.temperature,
            "max_output_tokens": self.max_tokens,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        return {
            **self._model_kwargs,
            **kwargs,
        }

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        # The last message is the question; everything before it is chat history.
        question = _parse_message(messages[-1], self._is_gemini)
        chat_history = _parse_chat_history(messages[:-1], self._is_gemini)
        chat_params = {**chat_history}

        kwargs = kwargs if kwargs else {}

        params = {**self._model_kwargs, **kwargs}

        if self.iscode and "candidate_count" in params:
            raise ValueError("candidate_count is not supported by the codey models")
        if self.examples and "examples" not in params:
            chat_params["examples"] = _parse_examples(self.examples)
        elif "examples" in params:
            raise ValueError(
                "examples are not supported in chat generation; pass them as a constructor parameter"
            )

        generation = completion_with_retry(
            client=self._chat_client,
            prompt=question,
            chat=True,
            stream=False,
            is_gemini=self._is_gemini,
            params=chat_params,
            max_retries=self.max_retries,
            **params,
        )

        return ChatResponse(
            message=ChatMessage(role=MessageRole.ASSISTANT, content=generation.text),
            raw=generation.__dict__,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}
        if self.iscode and "candidate_count" in params:
            raise ValueError("candidate_count is not supported by the codey models")

        completion = completion_with_retry(
            self._client,
            prompt,
            max_retries=self.max_retries,
            is_gemini=self._is_gemini,
            **params,
        )
        return CompletionResponse(text=completion.text, raw=completion.__dict__)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        question = _parse_message(messages[-1], self._is_gemini)
        chat_history = _parse_chat_history(messages[:-1], self._is_gemini)
        chat_params = {**chat_history}
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}
        if self.iscode and "candidate_count" in params:
            raise ValueError("candidate_count is not supported by the codey models")
        if self.examples and "examples" not in params:
            chat_params["examples"] = _parse_examples(self.examples)
        elif "examples" in params:
            raise ValueError(
                "examples are not supported in chat generation; pass them as a constructor parameter"
            )

        response = completion_with_retry(
            client=self._chat_client,
            prompt=question,
            chat=True,
            stream=True,
            is_gemini=self._is_gemini,
            params=chat_params,
            max_retries=self.max_retries,
            **params,
        )

        def gen() -> ChatResponseGen:
            # Accumulate streamed deltas into the running message content.
            content = ""
            role = MessageRole.ASSISTANT
            for r in response:
                content_delta = r.text
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=r.__dict__,
                )

        return gen()

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}
        if "candidate_count" in params:
            raise ValueError("candidate_count is not supported by streaming")

        completion = completion_with_retry(
            client=self._client,
            prompt=prompt,
            stream=True,
            is_gemini=self._is_gemini,
            max_retries=self.max_retries,
            **params,
        )

        def gen() -> CompletionResponseGen:
            content = ""
            for r in completion:
                content_delta = r.text
                content += content_delta
                yield CompletionResponse(
                    text=content, delta=content_delta, raw=r.__dict__
                )

        return gen()

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        question = _parse_message(messages[-1], self._is_gemini)
        chat_history = _parse_chat_history(messages[:-1], self._is_gemini)
        chat_params = {**chat_history}
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}
        if self.iscode and "candidate_count" in params:
            raise ValueError("candidate_count is not supported by the codey models")
        if self.examples and "examples" not in params:
            chat_params["examples"] = _parse_examples(self.examples)
        elif "examples" in params:
            raise ValueError(
                "examples are not supported in chat generation; pass them as a constructor parameter"
            )
        generation = await acompletion_with_retry(
            client=self._chat_client,
            prompt=question,
            chat=True,
            is_gemini=self._is_gemini,
            params=chat_params,
            max_retries=self.max_retries,
            **params,
        )
        # Due to a bug in Vertex AI, code models require awaiting twice.
        if self.iscode:
            generation = await generation
        return ChatResponse(
            message=ChatMessage(role=MessageRole.ASSISTANT, content=generation.text),
            raw=generation.__dict__,
        )

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}
        if self.iscode and "candidate_count" in params:
            raise ValueError("candidate_count is not supported by the codey models")
        completion = await acompletion_with_retry(
            client=self._client,
            prompt=prompt,
            max_retries=self.max_retries,
            is_gemini=self._is_gemini,
            **params,
        )
        return CompletionResponse(text=completion.text, raw=completion.__dict__)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError("astream_chat is not implemented")

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError("astream_complete is not implemented")
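
# --- Usage sketch (illustrative; not part of the upstream module) ---
# A minimal example of driving the wrapper above, assuming Vertex AI
# credentials are already configured (e.g. via Application Default
# Credentials) and that "chat-bison" is enabled for the target project.
# "my-gcp-project" is a placeholder, not a real project ID.
if __name__ == "__main__":
    llm = Vertex(model="chat-bison", project="my-gcp-project", temperature=0.2)

    # Multi-turn chat: earlier messages become history, the last one is the question.
    messages = [
        ChatMessage(role=MessageRole.USER, content="What is Vertex AI?"),
    ]
    print(llm.chat(messages).message.content)

    # Streaming variant: deltas arrive incrementally from a generator.
    for chunk in llm.stream_chat(messages):
        print(chunk.delta, end="")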
