llama-index

import os
from typing import Any, Callable, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import (
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_NUM_OUTPUTS,
    DEFAULT_TEMPERATURE,
)
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import llm_completion_callback
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode


class PredibaseLLM(CustomLLM):
    """Predibase LLM."""

    model_name: str = Field(description="The Predibase model to use.")
    predibase_api_key: str = Field(description="The Predibase API key to use.")
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use for sampling.",
        ge=0.0,
        le=1.0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The number of context tokens available to the LLM.",
        gt=0,
    )

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str,
        predibase_api_key: Optional[str] = None,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        temperature: float = DEFAULT_TEMPERATURE,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        # Fall back to the PREDIBASE_API_TOKEN environment variable when no
        # key is passed explicitly.
        predibase_api_key = predibase_api_key or os.environ.get(
            "PREDIBASE_API_TOKEN"
        )
        if predibase_api_key is None:
            raise ValueError(
                "A Predibase API key is required. Pass `predibase_api_key` "
                "or set the `PREDIBASE_API_TOKEN` environment variable."
            )

        self._client = self.initialize_client(predibase_api_key)

        super().__init__(
            model_name=model_name,
            predibase_api_key=predibase_api_key,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            context_window=context_window,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @staticmethod
    def initialize_client(predibase_api_key: str) -> Any:
        try:
            from predibase import PredibaseClient

            return PredibaseClient(token=predibase_api_key)
        except ImportError as e:
            raise ImportError(
                "Could not import the Predibase Python package. "
                "Please install it with `pip install predibase`."
            ) from e
        except ValueError as e:
            raise ValueError(
                "Your API key is not correct. Please try again."
            ) from e

    @classmethod
    def class_name(cls) -> str:
        return "PredibaseLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponse":
        # Look up the deployed model by its Predibase deployments URI and
        # run a single, non-streaming generation.
        llm = self._client.LLM(f"pb://deployments/{self.model_name}")
        results = llm.prompt(
            prompt, max_new_tokens=self.max_new_tokens, temperature=self.temperature
        )
        return CompletionResponse(text=results.response)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponseGen":
        # Streaming is not supported by this integration.
        raise NotImplementedError
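
A minimal usage sketch (an assumption, not part of the file above): instantiate the LLM against one of your deployed Predibase models and call complete(). The model name "llama-2-13b" and the prompt are illustrative placeholders, and PREDIBASE_API_TOKEN is assumed to be set in the environment.

    from llama_index.legacy.llms.predibase import PredibaseLLM

    # Hypothetical deployment name; substitute one of your own
    # Predibase deployments.
    llm = PredibaseLLM(
        model_name="llama-2-13b",
        temperature=0.3,
        max_new_tokens=512,
    )

    # complete() resolves pb://deployments/llama-2-13b and returns a
    # CompletionResponse whose .text holds the generation.
    response = llm.complete("Can you recommend me a nice dry wine?")
    print(response.text)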