# llama-index (legacy) — PaLM LLM integration (144 lines, 4.6 KB)
"""Palm API."""

import os
from typing import Any, Callable, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import llm_completion_callback
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
# Default PaLM text-completion model id, used when no model_name is supplied.
DEFAULT_PALM_MODEL = "models/text-bison-001"
class PaLM(CustomLLM):
    """PaLM LLM.

    Thin wrapper around Google's ``google.generativeai`` text-completion API.
    Requires the ``google-generativeai`` package and an API key (passed
    explicitly or read from the ``PALM_API_KEY`` environment variable).
    """

    model_name: str = Field(
        default=DEFAULT_PALM_MODEL, description="The PaLM model to use."
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    generate_kwargs: dict = Field(
        default_factory=dict, description="Kwargs for generation."
    )

    # Model descriptor resolved from palm.list_models() in __init__.
    _model: Any = PrivateAttr()

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = DEFAULT_PALM_MODEL,
        num_output: Optional[int] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **generate_kwargs: Any,
    ) -> None:
        """Initialize params.

        Args:
            api_key: PaLM API key; falls back to the ``PALM_API_KEY``
                environment variable when not provided.
            model_name: Name of the PaLM model to use; must be listed by
                ``palm.list_models()``.
            num_output: Max tokens to generate; defaults to the model's own
                output token limit.
            **generate_kwargs: Extra generation kwargs stored on the instance
                and applied on every ``complete()`` call.

        Raises:
            ValueError: If ``google-generativeai`` is not installed, or if
                ``model_name`` is not a known model.
        """
        try:
            import google.generativeai as palm
        except ImportError:
            raise ValueError(
                "PaLM is not installed. "
                "Please install it with `pip install google-generativeai`."
            )
        # Fall back to the environment variable when no key is passed.
        api_key = api_key or os.environ.get("PALM_API_KEY")
        palm.configure(api_key=api_key)

        # Validate the requested model against the live model listing.
        models = palm.list_models()
        models_dict = {m.name: m for m in models}
        if model_name not in models_dict:
            raise ValueError(
                f"Model name {model_name} not found in {models_dict.keys()}"
            )

        self._model = models_dict[model_name]

        # Default num_output to the model's own output token limit.
        num_output = num_output or self._model.output_token_limit

        generate_kwargs = generate_kwargs or {}
        super().__init__(
            model_name=model_name,
            num_output=num_output,
            generate_kwargs=generate_kwargs,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "PaLM_llm"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        # TODO: google palm actually separates input and output token limits
        total_tokens = self._model.input_token_limit + self.num_output
        return LLMMetadata(
            context_window=total_tokens,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Predict the answer to a query.

        Args:
            prompt (str): Prompt to use for prediction.

        Returns:
            CompletionResponse: The predicted answer (first candidate in
                ``raw``).

        """
        import google.generativeai as palm

        # Apply the generation kwargs stored at construction time; per-call
        # kwargs take precedence. (Previously self.generate_kwargs was
        # collected in __init__ but silently ignored here.)
        completion = palm.generate_text(
            model=self.model_name,
            prompt=prompt,
            **{**self.generate_kwargs, **kwargs},
        )
        return CompletionResponse(text=completion.result, raw=completion.candidates[0])

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream the answer to a query.

        NOTE: this is a beta feature. Will try to build or use
        better abstractions about response handling.

        Args:
            prompt (str): Prompt to use for prediction.

        Raises:
            NotImplementedError: Streaming is not supported for PaLM.

        """
        raise NotImplementedError(
            "PaLM does not support streaming completion in LlamaIndex currently."
        )