llama-index

Форк
0
184 строки · 6.9 Кб
1
from typing import Any, Callable, Dict, Optional, Sequence
2

3
import httpx
4
from openai import AsyncAzureOpenAI
5
from openai import AzureOpenAI as SyncAzureOpenAI
6

7
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr, root_validator
8
from llama_index.legacy.callbacks import CallbackManager
9
from llama_index.legacy.core.llms.types import ChatMessage
10
from llama_index.legacy.llms.generic_utils import get_from_param_or_env
11
from llama_index.legacy.llms.openai import OpenAI
12
from llama_index.legacy.llms.openai_utils import (
13
    refresh_openai_azuread_token,
14
    resolve_from_aliases,
15
)
16
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
17

18

19
class AzureOpenAI(OpenAI):
    """
    Azure OpenAI.

    To use this, you must first deploy a model on Azure OpenAI.
    Unlike OpenAI, you need to specify an `engine` parameter to identify
    your deployment (called "model deployment name" in Azure portal).
    `deployment_name`, `deployment_id`, `deployment` and `azure_deployment`
    are also accepted as aliases for `engine`.

    - model: Name of the model (e.g. `text-davinci-003`)
        This is only used to decide completion vs. chat endpoint.
    - engine: This will correspond to the custom name you chose
        for your deployment when you deployed a model.

    You must have the following environment variables set:
    - `OPENAI_API_VERSION`: set this to `2023-05-15`
        This may change in the future.
    - `AZURE_OPENAI_ENDPOINT`: your endpoint should look like the following
        https://YOUR_RESOURCE_NAME.openai.azure.com/
    - `AZURE_OPENAI_API_KEY`: your API key if the api type is `azure`

    Set `use_azure_ad=True` to authenticate with a Microsoft Entra ID
    (formerly Azure AD) token instead of an API key.

    More information can be found here:
        https://learn.microsoft.com/en-us/azure/cognitive-services/openai/quickstart?tabs=command-line&pivots=programming-language-python
    """

    engine: str = Field(description="The name of the deployed azure engine.")
    azure_endpoint: Optional[str] = Field(
        default=None, description="The Azure endpoint to use."
    )
    azure_deployment: Optional[str] = Field(
        default=None, description="The Azure deployment to use."
    )
    use_azure_ad: bool = Field(
        description="Indicates if Microsoft Entra ID (former Azure AD) is used for token authentication"
    )

    # Cached Entra ID token, passed to `refresh_openai_azuread_token` every time
    # credentials are built. It defaults to None because the first read happens
    # before any assignment; without a default that read would fail on an unset
    # private attribute.
    _azure_ad_token: Any = PrivateAttr(default=None)
    # Lazily created clients, cached only when `reuse_client` is True.
    _client: Optional[SyncAzureOpenAI] = PrivateAttr(default=None)
    _aclient: Optional[AsyncAzureOpenAI] = PrivateAttr(default=None)

    def __init__(
        self,
        model: str = "gpt-35-turbo",
        engine: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 3,
        timeout: float = 60.0,
        reuse_client: bool = True,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
        # azure specific
        azure_endpoint: Optional[str] = None,
        azure_deployment: Optional[str] = None,
        use_azure_ad: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        # aliases for engine
        deployment_name: Optional[str] = None,
        deployment_id: Optional[str] = None,
        deployment: Optional[str] = None,
        # custom httpx client
        http_client: Optional[httpx.Client] = None,
        # base class
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> None:
        """
        Create an Azure OpenAI LLM.

        `engine` is resolved from `engine` or one of its aliases
        (`deployment_name`, `deployment_id`, `deployment`, `azure_deployment`).
        `azure_endpoint` falls back to the `AZURE_OPENAI_ENDPOINT` environment
        variable, or to an empty string if that is unset. All arguments are
        then forwarded to the `OpenAI` base class.

        Raises:
            ValueError: If no engine / deployment name was supplied.
        """
        engine = resolve_from_aliases(
            engine, deployment_name, deployment_id, deployment, azure_deployment
        )

        if engine is None:
            raise ValueError("You must specify an `engine` parameter.")

        azure_endpoint = get_from_param_or_env(
            "azure_endpoint", azure_endpoint, "AZURE_OPENAI_ENDPOINT", ""
        )

        super().__init__(
            engine=engine,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            max_retries=max_retries,
            timeout=timeout,
            reuse_client=reuse_client,
            api_key=api_key,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            use_azure_ad=use_azure_ad,
            api_version=api_version,
            callback_manager=callback_manager,
            http_client=http_client,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            **kwargs,
        )

    @root_validator(pre=True)
    def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate necessary credentials are set.

        Raises:
            ValueError: If `api_base` is still the public OpenAI default and
                no Azure endpoint is configured, or if `api_version` is missing.
        """
        # `__init__` resolves `azure_endpoint` with an empty-string fallback,
        # so a missing endpoint arrives here as "" rather than None. Test for
        # falsiness; an `is None` check would never fire.
        if (
            values.get("api_base") == "https://api.openai.com/v1"
            and not values.get("azure_endpoint")
        ):
            raise ValueError(
                "You must set AZURE_OPENAI_ENDPOINT (or pass `azure_endpoint`) "
                "to your Azure endpoint. "
                "It should look like https://YOUR_RESOURCE_NAME.openai.azure.com/"
            )
        if values.get("api_version") is None:
            raise ValueError("You must set OPENAI_API_VERSION for Azure OpenAI.")

        return values

    def _get_client(self) -> SyncAzureOpenAI:
        """Return a sync Azure client, cached on the instance if `reuse_client`."""
        if not self.reuse_client:
            return SyncAzureOpenAI(**self._get_credential_kwargs())

        if self._client is None:
            self._client = SyncAzureOpenAI(**self._get_credential_kwargs())
        return self._client

    def _get_aclient(self) -> AsyncAzureOpenAI:
        """Return an async Azure client, cached on the instance if `reuse_client`."""
        if not self.reuse_client:
            return AsyncAzureOpenAI(**self._get_credential_kwargs())

        if self._aclient is None:
            self._aclient = AsyncAzureOpenAI(**self._get_credential_kwargs())
        return self._aclient

    def _get_credential_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Build the keyword arguments used to construct the Azure OpenAI clients.

        When `use_azure_ad` is set, the cached Entra ID token is passed through
        `refresh_openai_azuread_token`, and its `.token` string replaces
        `self.api_key`. Entries in `kwargs` override the computed values.
        """
        if self.use_azure_ad:
            self._azure_ad_token = refresh_openai_azuread_token(self._azure_ad_token)
            # NOTE(review): with `reuse_client=True` the cached client keeps the
            # key captured when it was created, so later token refreshes are not
            # seen by that client. Confirm how token expiry should be handled.
            self.api_key = self._azure_ad_token.token

        return {
            "api_key": self.api_key,
            "max_retries": self.max_retries,
            "timeout": self.timeout,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.azure_deployment,
            "api_version": self.api_version,
            "default_headers": self.default_headers,
            "http_client": self._http_client,
            **kwargs,
        }

    def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Return the base model kwargs with `model` replaced by the deployment (`engine`)."""
        model_kwargs = super()._get_model_kwargs(**kwargs)
        model_kwargs["model"] = self.engine
        return model_kwargs

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization name of this LLM class."""
        return "azure_openai_llm"

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.