"""Google's hosted Gemini API."""

import os
import typing
from typing import Any, Dict, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.llms.gemini_utils import (
    ROLES_FROM_GEMINI,
    chat_from_gemini_response,
    chat_message_to_gemini,
    completion_from_gemini_response,
    merge_neighboring_same_role_messages,
)

if typing.TYPE_CHECKING:
    import google.generativeai as genai


GEMINI_MODELS = (
    "models/gemini-pro",
    "models/gemini-ultra",
)
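
# A small discovery sketch, not part of the original module: assuming
# `google-generativeai` is installed and credentials are configured, the
# hosted model names above can be cross-checked against the live catalogue.
#
#     import google.generativeai as genai
#
#     for m in genai.list_models():
#         if "generateContent" in m.supported_generation_methods:
#             print(m.name)  # e.g. "models/gemini-pro"

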
class Gemini(CustomLLM):
    """Gemini."""

    model_name: str = Field(
        default=GEMINI_MODELS[0], description="The Gemini model to use."
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use during generation.",
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    generate_kwargs: dict = Field(
        default_factory=dict, description="Kwargs for generation."
    )

    _model: "genai.GenerativeModel" = PrivateAttr()
    _model_meta: "genai.types.Model" = PrivateAttr()

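    # Note: the Field bounds above are enforced by pydantic, so (illustrative)
    # Gemini(temperature=2.0) raises a validation error because of le=1.0.
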
    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = GEMINI_MODELS[0],
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = None,
        generation_config: Optional["genai.types.GenerationConfigDict"] = None,
        safety_settings: Optional["genai.types.SafetySettingOptions"] = None,
        callback_manager: Optional[CallbackManager] = None,
        api_base: Optional[str] = None,
        transport: Optional[str] = None,
        **generate_kwargs: Any,
    ):
        """Creates a new Gemini model interface."""
79
        try:
80
            import google.generativeai as genai
81
        except ImportError:
82
            raise ValueError(
83
                "Gemini is not installed. Please install it with "
84
                "`pip install 'google-generativeai>=0.3.0'`."
85
            )
86

87
        # API keys are optional. The API can be authorised via OAuth (detected
88
        # environmentally) or by the GOOGLE_API_KEY environment variable.
89
        config_params: Dict[str, Any] = {
90
            "api_key": api_key or os.getenv("GOOGLE_API_KEY"),
91
        }
92
        if api_base:
93
            config_params["client_options"] = {"api_endpoint": api_base}
94
        if transport:
95
            config_params["transport"] = transport
96
        # transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
97
        genai.configure(**config_params)
98

99
        base_gen_config = generation_config if generation_config else {}
100
        # Explicitly passed args take precedence over the generation_config.
101
        final_gen_config = {"temperature": temperature, **base_gen_config}
102

103
        self._model = genai.GenerativeModel(
104
            model_name=model_name,
105
            generation_config=final_gen_config,
106
            safety_settings=safety_settings,
107
        )
108

109
        self._model_meta = genai.get_model(model_name)
110

111
        supported_methods = self._model_meta.supported_generation_methods
112
        if "generateContent" not in supported_methods:
113
            raise ValueError(
114
                f"Model {model_name} does not support content generation, only "
115
                f"{supported_methods}."
116
            )
117

118
        if not max_tokens:
119
            max_tokens = self._model_meta.output_token_limit
120
        else:
121
            max_tokens = min(max_tokens, self._model_meta.output_token_limit)
122

123
        super().__init__(
124
            model_name=model_name,
125
            temperature=temperature,
126
            max_tokens=max_tokens,
127
            generate_kwargs=generate_kwargs,
128
            callback_manager=callback_manager,
129
        )
130
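
    # Worked example of the config merge above (illustrative, not executed):
    # Gemini(temperature=0.2, generation_config={"temperature": 0.9, "top_p": 0.8})
    # yields final_gen_config == {"temperature": 0.2, "top_p": 0.8}; the
    # explicitly passed argument wins.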

    @classmethod
    def class_name(cls) -> str:
        return "Gemini_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        # The API reports separate input and output limits; llama-index expects
        # a single combined context window, so the two are summed.
        total_tokens = self._model_meta.input_token_limit + self.max_tokens
        return LLMMetadata(
            context_window=total_tokens,
            num_output=self.max_tokens,
            model_name=self.model_name,
            is_chat_model=True,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        result = self._model.generate_content(prompt, **kwargs)
        return completion_from_gemini_response(result)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        it = self._model.generate_content(prompt, stream=True, **kwargs)
        yield from map(completion_from_gemini_response, it)
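
    # Streaming sketch (illustrative; assumes an `llm` instance exists and
    # credentials are configured; the prompt is a placeholder):
    #
    #     for chunk in llm.stream_complete("List three prime numbers."):
    #         print(chunk.delta or "", end="")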

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        # Gemini expects alternating user/model turns, so neighboring messages
        # with the same role are merged before the call.
        merged_messages = merge_neighboring_same_role_messages(messages)
        *history, next_msg = map(chat_message_to_gemini, merged_messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(next_msg)
        return chat_from_gemini_response(response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        merged_messages = merge_neighboring_same_role_messages(messages)
        *history, next_msg = map(chat_message_to_gemini, merged_messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(next_msg, stream=True)

        def gen() -> ChatResponseGen:
            content = ""
            for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
                role = ROLES_FROM_GEMINI[top_candidate.content.role]
                raw = {
                    **(type(top_candidate).to_dict(top_candidate)),
                    **(
                        type(response.prompt_feedback).to_dict(response.prompt_feedback)
                    ),
                }
                # Accumulate the full message while reporting per-chunk deltas.
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=raw,
                )

        return gen()
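

# A minimal end-to-end sketch, not part of the original module: it assumes
# GOOGLE_API_KEY is set in the environment and `google-generativeai>=0.3.0`
# is installed; the prompts are placeholders.
#
#     from llama_index.legacy.core.llms.types import ChatMessage
#
#     llm = Gemini()  # defaults to "models/gemini-pro"
#     print(llm.complete("Write a haiku about the moon.").text)
#
#     response = llm.chat([ChatMessage(role="user", content="Hello!")])
#     print(response.message.content)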
