llama-index

Форк
0
117 строк · 4.0 Кб
1
from typing import Any, Dict, Optional
2

3
import httpx
4
from openai import AsyncAzureOpenAI, AzureOpenAI
5

6
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr, root_validator
7
from llama_index.legacy.callbacks.base import CallbackManager
8
from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE
9
from llama_index.legacy.embeddings.openai import (
10
    OpenAIEmbedding,
11
    OpenAIEmbeddingMode,
12
    OpenAIEmbeddingModelType,
13
)
14
from llama_index.legacy.llms.generic_utils import get_from_param_or_env
15
from llama_index.legacy.llms.openai_utils import resolve_from_aliases
16

17

18
class AzureOpenAIEmbedding(OpenAIEmbedding):
19
    azure_endpoint: Optional[str] = Field(
20
        default=None, description="The Azure endpoint to use."
21
    )
22
    azure_deployment: Optional[str] = Field(
23
        default=None, description="The Azure deployment to use."
24
    )
25

26
    _client: AzureOpenAI = PrivateAttr()
27
    _aclient: AsyncAzureOpenAI = PrivateAttr()
28

29
    def __init__(
30
        self,
31
        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
32
        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
33
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
34
        additional_kwargs: Optional[Dict[str, Any]] = None,
35
        api_key: Optional[str] = None,
36
        api_version: Optional[str] = None,
37
        # azure specific
38
        azure_endpoint: Optional[str] = None,
39
        azure_deployment: Optional[str] = None,
40
        deployment_name: Optional[str] = None,
41
        max_retries: int = 10,
42
        reuse_client: bool = True,
43
        callback_manager: Optional[CallbackManager] = None,
44
        # custom httpx client
45
        http_client: Optional[httpx.Client] = None,
46
        **kwargs: Any,
47
    ):
48
        azure_endpoint = get_from_param_or_env(
49
            "azure_endpoint", azure_endpoint, "AZURE_OPENAI_ENDPOINT", ""
50
        )
51

52
        azure_deployment = resolve_from_aliases(
53
            azure_deployment,
54
            deployment_name,
55
        )
56

57
        super().__init__(
58
            mode=mode,
59
            model=model,
60
            embed_batch_size=embed_batch_size,
61
            additional_kwargs=additional_kwargs,
62
            api_key=api_key,
63
            api_version=api_version,
64
            azure_endpoint=azure_endpoint,
65
            azure_deployment=azure_deployment,
66
            max_retries=max_retries,
67
            reuse_client=reuse_client,
68
            callback_manager=callback_manager,
69
            http_client=http_client,
70
            **kwargs,
71
        )
72

73
    @root_validator(pre=True)
74
    def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
75
        """Validate necessary credentials are set."""
76
        if (
77
            values["api_base"] == "https://api.openai.com/v1"
78
            and values["azure_endpoint"] is None
79
        ):
80
            raise ValueError(
81
                "You must set OPENAI_API_BASE to your Azure endpoint. "
82
                "It should look like https://YOUR_RESOURCE_NAME.openai.azure.com/"
83
            )
84
        if values["api_version"] is None:
85
            raise ValueError("You must set OPENAI_API_VERSION for Azure OpenAI.")
86

87
        return values
88

89
    def _get_client(self) -> AzureOpenAI:
90
        if not self.reuse_client:
91
            return AzureOpenAI(**self._get_credential_kwargs())
92

93
        if self._client is None:
94
            self._client = AzureOpenAI(**self._get_credential_kwargs())
95
        return self._client
96

97
    def _get_aclient(self) -> AsyncAzureOpenAI:
98
        if not self.reuse_client:
99
            return AsyncAzureOpenAI(**self._get_credential_kwargs())
100

101
        if self._aclient is None:
102
            self._aclient = AsyncAzureOpenAI(**self._get_credential_kwargs())
103
        return self._aclient
104

105
    def _get_credential_kwargs(self) -> Dict[str, Any]:
106
        return {
107
            "api_key": self.api_key,
108
            "azure_endpoint": self.azure_endpoint,
109
            "azure_deployment": self.azure_deployment,
110
            "api_version": self.api_version,
111
            "default_headers": self.default_headers,
112
            "http_client": self._http_client,
113
        }
114

115
    @classmethod
116
    def class_name(cls) -> str:
117
        return "AzureOpenAIEmbedding"
118

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.