"""Palm API."""

import os
from typing import Any, Callable, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import llm_completion_callback
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode

DEFAULT_PALM_MODEL = "models/text-bison-001"


class PaLM(CustomLLM):
    """PaLM LLM."""
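    # Usage sketch (illustrative; assumes `google-generativeai` is installed
    # and an API key is available, e.g. via the PALM_API_KEY environment
    # variable):
    #
    #   llm = PaLM(model_name="models/text-bison-001")
    #   response = llm.complete("Write a haiku about the sea.")
    #   print(response.text)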

    model_name: str = Field(
        default=DEFAULT_PALM_MODEL, description="The PaLM model to use."
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    generate_kwargs: dict = Field(
        default_factory=dict, description="Kwargs for generation."
    )

    _model: Any = PrivateAttr()

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = DEFAULT_PALM_MODEL,
        num_output: Optional[int] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **generate_kwargs: Any,
    ) -> None:
        """Initialize params."""
        try:
            import google.generativeai as palm
        except ImportError:
            raise ImportError(
                "PaLM is not installed. "
                "Please install it with `pip install google-generativeai`."
            )
        api_key = api_key or os.environ.get("PALM_API_KEY")
        palm.configure(api_key=api_key)
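
        # Validate the requested model against the models the PaLM service
        # reports; names are fully qualified (e.g. "models/text-bison-001").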
        models = palm.list_models()
        models_dict = {m.name: m for m in models}
        if model_name not in models_dict:
            raise ValueError(
                f"Model name {model_name} not found in {models_dict.keys()}"
            )

        self._model = models_dict[model_name]

        # Fall back to the model's reported output token limit when
        # num_output is not provided.
        num_output = num_output or self._model.output_token_limit

        generate_kwargs = generate_kwargs or {}
        super().__init__(
            model_name=model_name,
            num_output=num_output,
            generate_kwargs=generate_kwargs,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "PaLM_llm"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        # TODO: google palm actually separates input and output token limits
        total_tokens = self._model.input_token_limit + self.num_output
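        # Illustrative arithmetic (example figures, not guaranteed by the
        # API): input_token_limit=8196 and num_output=1024 would yield a
        # context_window of 8196 + 1024 = 9220.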
        return LLMMetadata(
            context_window=total_tokens,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Predict the answer to a query.

        Args:
            prompt (str): Prompt to use for prediction.

        Returns:
            CompletionResponse: The completion text, with the first raw
            candidate attached.

        """
        import google.generativeai as palm

        # Merge the stored generate_kwargs with per-call overrides so that
        # kwargs passed here take precedence.
        completion = palm.generate_text(
            model=self.model_name,
            prompt=prompt,
            **{**self.generate_kwargs, **kwargs},
        )
        return CompletionResponse(text=completion.result, raw=completion.candidates[0])
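
    # Example call (illustrative; `temperature` and `max_output_tokens` are
    # sampling kwargs accepted by PaLM text generation):
    #
    #   response = llm.complete("Summarize the plot of Hamlet.", temperature=0.2)
    #   print(response.text)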

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream the answer to a query.

        NOTE: streaming is not currently supported for PaLM in LlamaIndex;
        this method always raises ``NotImplementedError``.

        Args:
            prompt (str): Prompt to use for prediction.

        Returns:
            CompletionResponseGen: A generator of completion responses
            (never produced; see the NOTE above).

        """
        raise NotImplementedError(
            "PaLM does not support streaming completion in LlamaIndex currently."
        )
