llama-index
193 lines · 6.6 KB
1"""Google's hosted Gemini API."""
2
3import os
4import typing
5from typing import Any, Dict, Optional, Sequence
6
7from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
8from llama_index.legacy.callbacks import CallbackManager
9from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
10from llama_index.legacy.core.llms.types import (
11ChatMessage,
12ChatResponse,
13ChatResponseGen,
14CompletionResponse,
15CompletionResponseGen,
16LLMMetadata,
17)
18from llama_index.legacy.llms.base import (
19llm_chat_callback,
20llm_completion_callback,
21)
22from llama_index.legacy.llms.custom import CustomLLM
23from llama_index.legacy.llms.gemini_utils import (
24ROLES_FROM_GEMINI,
25chat_from_gemini_response,
26chat_message_to_gemini,
27completion_from_gemini_response,
28merge_neighboring_same_role_messages,
29)
30
31if typing.TYPE_CHECKING:
32import google.generativeai as genai
33
34
# Hosted Gemini model identifiers accepted by the API; the first entry is the
# default `model_name` for the `Gemini` class below.
GEMINI_MODELS = (
    "models/gemini-pro",
    "models/gemini-ultra",
)
39
40
class Gemini(CustomLLM):
    """LLM interface to Google's hosted Gemini API.

    Authentication is optional at construction time: the SDK authorises via
    OAuth (detected environmentally) or via the ``GOOGLE_API_KEY``
    environment variable when no ``api_key`` is supplied.
    """

    model_name: str = Field(
        default=GEMINI_MODELS[0], description="The Gemini model to use."
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use during generation.",
        # FIX: pydantic's constraint keywords are `ge`/`le`; the previous
        # `gte`/`lte` spellings were unknown kwargs and silently ignored,
        # so the [0, 1] bounds were never enforced.
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    generate_kwargs: dict = Field(
        default_factory=dict, description="Kwargs for generation."
    )

    # Handles to the google.generativeai SDK objects; the string annotations
    # resolve against the TYPE_CHECKING-guarded import at the top of the file.
    _model: "genai.GenerativeModel" = PrivateAttr()
    _model_meta: "genai.types.Model" = PrivateAttr()

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = GEMINI_MODELS[0],
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = None,
        generation_config: Optional["genai.types.GenerationConfigDict"] = None,
        safety_settings: Optional["genai.types.SafetySettingOptions"] = None,
        callback_manager: Optional[CallbackManager] = None,
        api_base: Optional[str] = None,
        transport: Optional[str] = None,
        **generate_kwargs: Any,
    ):
        """Creates a new Gemini model interface.

        Args:
            api_key: API key. If omitted, the SDK falls back to OAuth or the
                ``GOOGLE_API_KEY`` environment variable.
            model_name: Hosted model to use (see ``GEMINI_MODELS``).
            temperature: Sampling temperature in [0, 1].
            max_tokens: Output token budget; clamped to the model's reported
                ``output_token_limit``.
            generation_config: Base SDK generation config dict.
            safety_settings: SDK safety settings, passed through unchanged.
            callback_manager: LlamaIndex callback manager.
            api_base: Alternate API endpoint.
            transport: One of ``rest``, ``grpc``, ``grpc_asyncio``.
            **generate_kwargs: Extra kwargs stored on the instance.

        Raises:
            ValueError: If google-generativeai is not installed, or the model
                does not support the ``generateContent`` method.
        """
        try:
            import google.generativeai as genai
        except ImportError:
            raise ValueError(
                "Gemini is not installed. Please install it with "
                "`pip install 'google-generativeai>=0.3.0'`."
            )

        # API keys are optional. The API can be authorised via OAuth (detected
        # environmentally) or by the GOOGLE_API_KEY environment variable.
        config_params: Dict[str, Any] = {
            "api_key": api_key or os.getenv("GOOGLE_API_KEY"),
        }
        if api_base:
            config_params["client_options"] = {"api_endpoint": api_base}
        if transport:
            # transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
            config_params["transport"] = transport
        genai.configure(**config_params)

        base_gen_config = generation_config if generation_config else {}
        # NOTE(review): with this merge order, keys present in
        # `generation_config` override the explicitly passed `temperature`
        # (later dict entries win). The original comment claimed the opposite;
        # behavior is intentionally preserved as-is here.
        final_gen_config = {"temperature": temperature, **base_gen_config}

        self._model = genai.GenerativeModel(
            model_name=model_name,
            generation_config=final_gen_config,
            safety_settings=safety_settings,
        )

        self._model_meta = genai.get_model(model_name)

        # Refuse models that cannot generate content at all.
        supported_methods = self._model_meta.supported_generation_methods
        if "generateContent" not in supported_methods:
            raise ValueError(
                f"Model {model_name} does not support content generation, only "
                f"{supported_methods}."
            )

        # Clamp the requested budget to the model's advertised output limit;
        # default to that limit when no budget was requested.
        if not max_tokens:
            max_tokens = self._model_meta.output_token_limit
        else:
            max_tokens = min(max_tokens, self._model_meta.output_token_limit)

        super().__init__(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            generate_kwargs=generate_kwargs,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialized identifier for this LLM class."""
        return "Gemini_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata; context window is input limit plus output budget."""
        total_tokens = self._model_meta.input_token_limit + self.max_tokens
        return LLMMetadata(
            context_window=total_tokens,
            num_output=self.max_tokens,
            model_name=self.model_name,
            is_chat_model=True,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Generate a one-shot completion for `prompt`."""
        result = self._model.generate_content(prompt, **kwargs)
        return completion_from_gemini_response(result)

    # FIX: callback decorator was missing here while present on `complete`,
    # so completion callbacks never fired for streamed requests.
    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream a completion for `prompt`, one response per chunk."""
        it = self._model.generate_content(prompt, stream=True, **kwargs)
        yield from map(completion_from_gemini_response, it)

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Send a chat history and return the model's reply."""
        # Merge adjacent same-role messages first; the last message is sent,
        # everything before it seeds the chat history.
        merged_messages = merge_neighboring_same_role_messages(messages)
        *history, next_msg = map(chat_message_to_gemini, merged_messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(next_msg)
        return chat_from_gemini_response(response)

    # FIX: callback decorator was missing here while present on `chat`,
    # so chat callbacks never fired for streamed requests.
    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat reply, yielding cumulative-content responses."""
        merged_messages = merge_neighboring_same_role_messages(messages)
        *history, next_msg = map(chat_message_to_gemini, merged_messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(next_msg, stream=True)

        def gen() -> ChatResponseGen:
            # Accumulate deltas so each yielded message holds the full
            # content so far, while `delta` carries just the new chunk.
            content = ""
            for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
                role = ROLES_FROM_GEMINI[top_candidate.content.role]
                raw = {
                    **(type(top_candidate).to_dict(top_candidate)),
                    **(
                        type(response.prompt_feedback).to_dict(response.prompt_feedback)
                    ),
                }
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=raw,
                )

        return gen()
194