import json
from typing import Any, Callable, Dict, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import (
    DEFAULT_TEMPERATURE,
)
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.bedrock_utils import (
    BEDROCK_FOUNDATION_LLMS,
    CHAT_ONLY_MODELS,
    STREAMING_MODELS,
    Provider,
    completion_with_retry,
    get_provider,
)
from llama_index.legacy.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode

class Bedrock(LLM):
    model: str = Field(description="The modelId of the Bedrock model to use.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_tokens: int = Field(description="The maximum number of tokens to generate.")
    context_size: int = Field(
        description="The maximum number of tokens available for input."
    )
    profile_name: Optional[str] = Field(
        description="The name of the AWS profile to use. If not given, the default profile is used."
    )
    aws_access_key_id: Optional[str] = Field(
        description="AWS Access Key ID to use", exclude=True
    )
    aws_secret_access_key: Optional[str] = Field(
        description="AWS Secret Access Key to use", exclude=True
    )
    aws_session_token: Optional[str] = Field(
        description="AWS Session Token to use", exclude=True
    )
    region_name: Optional[str] = Field(
        description="AWS region name to use. Uses the region configured in the AWS CLI if not passed.",
        exclude=True,
    )
    botocore_session: Optional[Any] = Field(
        description="Use this Botocore session instead of creating a new default one.",
        exclude=True,
    )
    botocore_config: Optional[Any] = Field(
        description="Custom configuration object to use instead of the default generated one.",
        exclude=True,
    )
    max_retries: int = Field(
        default=10, description="The maximum number of API retries.", gt=0
    )
    timeout: float = Field(
        default=60.0,
        description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional kwargs for the Bedrock invoke_model request.",
    )

    _client: Any = PrivateAttr()
    _aclient: Any = PrivateAttr()
    _provider: Provider = PrivateAttr()

    def __init__(
        self,
        model: str,
        temperature: Optional[float] = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = 512,
        context_size: Optional[int] = None,
        profile_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region_name: Optional[str] = None,
        botocore_session: Optional[Any] = None,
        client: Optional[Any] = None,
        timeout: Optional[float] = 60.0,
        max_retries: Optional[int] = 10,
        botocore_config: Optional[Any] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> None:
        if context_size is None and model not in BEDROCK_FOUNDATION_LLMS:
            raise ValueError(
                "`context_size` argument not provided and "
                "model provided refers to a non-foundation model. "
                "Please specify the context_size."
            )
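
        # Illustrative note (not in the original source): for models listed in
        # BEDROCK_FOUNDATION_LLMS the context size is looked up automatically
        # further below; anything else (e.g. a hypothetical provisioned/custom
        # model ARN) must pass `context_size` explicitly or the check above
        # raises:
        #
        #   Bedrock(model="arn:aws:bedrock:...", context_size=8192)  # hypothetical ARN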

        session_kwargs = {
            "profile_name": profile_name,
            "region_name": region_name,
            "aws_access_key_id": aws_access_key_id,
            "aws_secret_access_key": aws_secret_access_key,
            "aws_session_token": aws_session_token,
            "botocore_session": botocore_session,
        }
        config = None
        try:
            import boto3
            from botocore.config import Config

            config = (
                Config(
                    retries={"max_attempts": max_retries, "mode": "standard"},
                    connect_timeout=timeout,
                    read_timeout=timeout,
                )
                if botocore_config is None
                else botocore_config
            )
            session = boto3.Session(**session_kwargs)
        except ImportError:
            raise ImportError(
                "boto3 package not found, install with 'pip install boto3'"
            )

        # Prior to general availability, custom boto3 wheel files were
        # distributed that used the bedrock service to invokeModel.
        # This check prevents any services still using those wheel files
        # from breaking.
        if client is not None:
            self._client = client
        elif "bedrock-runtime" in session.get_available_services():
            self._client = session.client("bedrock-runtime", config=config)
        else:
            self._client = session.client("bedrock", config=config)

        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])
        context_size = context_size or BEDROCK_FOUNDATION_LLMS[model]
        self._provider = get_provider(model)
        messages_to_prompt = messages_to_prompt or self._provider.messages_to_prompt
        completion_to_prompt = (
            completion_to_prompt or self._provider.completion_to_prompt
        )
        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            context_size=context_size,
            profile_name=profile_name,
            timeout=timeout,
            max_retries=max_retries,
            botocore_config=config,
            additional_kwargs=additional_kwargs,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "Bedrock_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_size,
            num_output=self.max_tokens,
            is_chat_model=self.model in CHAT_ONLY_MODELS,
            model_name=self.model,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        base_kwargs = {
            "temperature": self.temperature,
            self._provider.max_tokens_key: self.max_tokens,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }
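
    # Illustrative sketch (an assumption, not taken from this file): each
    # Provider maps `max_tokens` onto its own request key. For an Anthropic
    # Claude model the legacy key is "max_tokens_to_sample", so the dict
    # built above would look roughly like
    #
    #   {"temperature": 0.5, "max_tokens_to_sample": 512}
    #
    # merged with any `additional_kwargs`.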

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        return {
            **self._model_kwargs,
            **kwargs,
        }

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        if not formatted:
            prompt = self.completion_to_prompt(prompt)
        all_kwargs = self._get_all_kwargs(**kwargs)
        request_body = self._provider.get_request_body(prompt, all_kwargs)
        request_body_str = json.dumps(request_body)
        response = completion_with_retry(
            client=self._client,
            model=self.model,
            request_body=request_body_str,
            max_retries=self.max_retries,
            **all_kwargs,
        )["body"].read()
        response = json.loads(response)
        return CompletionResponse(
            text=self._provider.get_text_from_response(response), raw=response
        )
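
    # Sketch of the serialized request for an Anthropic Claude text-completion
    # model (field names are assumptions based on the legacy Bedrock API, not
    # taken from this file):
    #
    #   {"prompt": "\n\nHuman: Hi\n\nAssistant:",
    #    "max_tokens_to_sample": 512,
    #    "temperature": 0.5}
    #
    # invoke_model returns a dict whose "body" is a botocore StreamingBody,
    # hence the read() before json.loads above.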

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        if self.model in BEDROCK_FOUNDATION_LLMS and self.model not in STREAMING_MODELS:
            raise ValueError(f"Model {self.model} does not support streaming")

        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        all_kwargs = self._get_all_kwargs(**kwargs)
        request_body = self._provider.get_request_body(prompt, all_kwargs)
        request_body_str = json.dumps(request_body)
        response = completion_with_retry(
            client=self._client,
            model=self.model,
            request_body=request_body_str,
            max_retries=self.max_retries,
            stream=True,
            **all_kwargs,
        )["body"]

        def gen() -> CompletionResponseGen:
            content = ""
            for r in response:
                r = json.loads(r["chunk"]["bytes"])
                content_delta = self._provider.get_text_from_stream_response(r)
                content += content_delta
                yield CompletionResponse(text=content, delta=content_delta, raw=r)

        return gen()
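
    # Minimal consumption sketch (hypothetical prompt; any model present in
    # STREAMING_MODELS works the same way). Each yielded CompletionResponse
    # carries the accumulated text in `text` and the new piece in `delta`:
    #
    #   for chunk in llm.stream_complete("Tell me a story"):
    #       print(chunk.delta, end="", flush=True)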

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Chat asynchronously."""
        # TODO: do synchronous chat for now
        return self.chat(messages, **kwargs)

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        raise NotImplementedError

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError
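

# A minimal usage sketch, not part of the original module. The model id and
# region are assumptions; substitute any Bedrock model your AWS account can
# invoke. Requires boto3 and valid AWS credentials.
if __name__ == "__main__":
    llm = Bedrock(
        model="amazon.titan-text-express-v1",  # assumed example model id
        region_name="us-east-1",  # assumed region
        temperature=0.1,
        max_tokens=256,
    )
    print(llm.complete("What is Amazon Bedrock?").text)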