llama-index
184 строки · 6.9 Кб
1from typing import Any, Callable, Dict, Optional, Sequence2
3import httpx4from openai import AsyncAzureOpenAI5from openai import AzureOpenAI as SyncAzureOpenAI6
7from llama_index.legacy.bridge.pydantic import Field, PrivateAttr, root_validator8from llama_index.legacy.callbacks import CallbackManager9from llama_index.legacy.core.llms.types import ChatMessage10from llama_index.legacy.llms.generic_utils import get_from_param_or_env11from llama_index.legacy.llms.openai import OpenAI12from llama_index.legacy.llms.openai_utils import (13refresh_openai_azuread_token,14resolve_from_aliases,15)
16from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode17
18
class AzureOpenAI(OpenAI):
    """
    Azure OpenAI.

    To use this, you must first deploy a model on Azure OpenAI.
    Unlike OpenAI, you need to specify an `engine` parameter to identify
    your deployment (called "model deployment name" in Azure portal).
    The aliases `deployment_name`, `deployment_id`, `deployment` and
    `azure_deployment` are also accepted in place of `engine`.

    - model: Name of the model (e.g. `text-davinci-003`)
        This is only used to decide completion vs. chat endpoint.
    - engine: This will correspond to the custom name you chose
        for your deployment when you deployed a model.

    You must have the following environment variables set:
    - `OPENAI_API_VERSION`: set this to `2023-05-15`
        This may change in the future.
    - `AZURE_OPENAI_ENDPOINT`: your endpoint should look like the following
        https://YOUR_RESOURCE_NAME.openai.azure.com/
    - `AZURE_OPENAI_API_KEY`: your API key if the api type is `azure`

    More information can be found here:
        https://learn.microsoft.com/en-us/azure/cognitive-services/openai/quickstart?tabs=command-line&pivots=programming-language-python
    """

    engine: str = Field(description="The name of the deployed azure engine.")
    azure_endpoint: Optional[str] = Field(
        default=None, description="The Azure endpoint to use."
    )
    azure_deployment: Optional[str] = Field(
        default=None, description="The Azure deployment to use."
    )
    use_azure_ad: bool = Field(
        description="Indicates if Microsoft Entra ID (former Azure AD) is used for token authentication"
    )

    # Cached Entra ID / Azure AD token, refreshed in `_get_credential_kwargs`.
    # It needs an explicit default: a private attribute without one is left
    # unset by pydantic, and the first read (use_azure_ad=True) would raise
    # AttributeError because nothing in this class assigns it beforehand.
    _azure_ad_token: Any = PrivateAttr(default=None)
    # Lazily created clients, cached only when `reuse_client` is True.
    _client: Optional[SyncAzureOpenAI] = PrivateAttr(default=None)
    _aclient: Optional[AsyncAzureOpenAI] = PrivateAttr(default=None)

    def __init__(
        self,
        model: str = "gpt-35-turbo",
        engine: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 3,
        timeout: float = 60.0,
        reuse_client: bool = True,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
        # azure specific
        azure_endpoint: Optional[str] = None,
        azure_deployment: Optional[str] = None,
        use_azure_ad: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        # aliases for engine
        deployment_name: Optional[str] = None,
        deployment_id: Optional[str] = None,
        deployment: Optional[str] = None,
        # custom httpx client
        http_client: Optional[httpx.Client] = None,
        # base class
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> None:
        """Create an Azure OpenAI LLM.

        The deployment name is taken from `engine` or, failing that, one of its
        aliases (`deployment_name`, `deployment_id`, `deployment`,
        `azure_deployment`). The endpoint comes from `azure_endpoint` or the
        `AZURE_OPENAI_ENDPOINT` environment variable, defaulting to "".

        Raises:
            ValueError: if no deployment name could be resolved.
        """
        engine = resolve_from_aliases(
            engine, deployment_name, deployment_id, deployment, azure_deployment
        )

        if engine is None:
            raise ValueError("You must specify an `engine` parameter.")

        # Note: an unset endpoint becomes "" here, not None.
        azure_endpoint = get_from_param_or_env(
            "azure_endpoint", azure_endpoint, "AZURE_OPENAI_ENDPOINT", ""
        )

        super().__init__(
            engine=engine,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            max_retries=max_retries,
            timeout=timeout,
            reuse_client=reuse_client,
            api_key=api_key,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            use_azure_ad=use_azure_ad,
            api_version=api_version,
            callback_manager=callback_manager,
            http_client=http_client,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            **kwargs,
        )

    @root_validator(pre=True)
    def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate necessary credentials are set.

        Runs on the raw constructor values, before field validation.

        Raises:
            ValueError: if no Azure endpoint is configured while `api_base`
                still points at the public OpenAI URL, or if no API version
                is configured.
        """
        # `__init__` normalises a missing endpoint to "" rather than None, so
        # test for falsiness; an `is None` check could never fire. `.get` avoids
        # a raw KeyError if a value was not supplied at all.
        if (
            values.get("api_base") == "https://api.openai.com/v1"
            and not values.get("azure_endpoint")
        ):
            raise ValueError(
                "You must set OPENAI_API_BASE to your Azure endpoint. "
                "It should look like https://YOUR_RESOURCE_NAME.openai.azure.com/"
            )
        # An empty version string is treated the same as a missing one.
        if not values.get("api_version"):
            raise ValueError("You must set OPENAI_API_VERSION for Azure OpenAI.")

        return values

    def _get_client(self) -> SyncAzureOpenAI:
        """Return a sync Azure client, cached on the instance if `reuse_client`."""
        if not self.reuse_client:
            return SyncAzureOpenAI(**self._get_credential_kwargs())

        if self._client is None:
            self._client = SyncAzureOpenAI(
                **self._get_credential_kwargs(),
            )
        return self._client

    def _get_aclient(self) -> AsyncAzureOpenAI:
        """Return an async Azure client, cached on the instance if `reuse_client`."""
        if not self.reuse_client:
            return AsyncAzureOpenAI(**self._get_credential_kwargs())

        if self._aclient is None:
            self._aclient = AsyncAzureOpenAI(
                **self._get_credential_kwargs(),
            )
        return self._aclient

    def _get_credential_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Build the constructor kwargs for the Azure OpenAI SDK clients.

        When `use_azure_ad` is set, the cached Entra ID token is refreshed and
        its `.token` replaces `self.api_key` (this mutates the instance). On
        first use the cached token is None; `refresh_openai_azuread_token` is
        expected to acquire a fresh token in that case -- verify in that helper.

        Extra `kwargs` are merged last and override the defaults built here.
        """
        if self.use_azure_ad:
            self._azure_ad_token = refresh_openai_azuread_token(self._azure_ad_token)
            self.api_key = self._azure_ad_token.token

        return {
            "api_key": self.api_key,
            "max_retries": self.max_retries,
            "timeout": self.timeout,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.azure_deployment,
            "api_version": self.api_version,
            "default_headers": self.default_headers,
            "http_client": self._http_client,
            **kwargs,
        }

    def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Return the parent's request kwargs with `model` set to the deployment.

        Azure requests identify the target by deployment name (`engine`), so
        it overrides whatever model name the base class supplied.
        """
        model_kwargs = super()._get_model_kwargs(**kwargs)
        model_kwargs["model"] = self.engine
        return model_kwargs

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier used for this component class."""
        return "azure_openai_llm"