llama-index

Форк
0
209 строк · 6.5 Кб
1
import logging
2
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
3

4
from openai.resources import Completions
5
from tenacity import (
6
    before_sleep_log,
7
    retry,
8
    retry_if_exception_type,
9
    stop_after_attempt,
10
    wait_exponential,
11
)
12

13
from llama_index.legacy.bridge.pydantic import BaseModel
14
from llama_index.legacy.core.llms.types import ChatMessage
15

16
# Error text used when no API key for the LLM provider can be found.
# The example it gives refers to OpenAI's environment variable / module attribute.
MISSING_API_KEY_ERROR_MESSAGE = (
    "No API key found for LLM.\n"
    "E.g. to use openai Please set the OPENAI_API_KEY environment variable or "
    "openai.api_key prior to initialization.\n"
    "API keys can be found or created at "
    "https://platform.openai.com/account/api-keys\n"
)
# Generic error text, presumably for a key that was supplied but rejected
# (not referenced in the visible part of this module).
INVALID_API_KEY_ERROR_MESSAGE = "Invalid LLM API key."
23

24
# litellm is an optional dependency of this module. In the visible code,
# ``Message`` is only used as a type annotation, so fall back to ``Any`` when it
# cannot be imported. ``ImportError`` (the base class of ``ModuleNotFoundError``)
# also covers an installed litellm whose ``litellm.utils`` does not expose
# ``Message``.
try:
    from litellm.utils import Message
except ImportError:
    Message = Any

logger = logging.getLogger(__name__)

# Type alias for the OpenAI ``Completions`` resource class.
CompletionClientType = Type[Completions]
32

33

34
def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
    """Build a tenacity ``retry`` decorator for transient litellm failures.

    Args:
        max_retries: Passed to ``stop_after_attempt``, so it is the total number
            of attempts (the first call included), not the number of retries
            after the first failure.

    Returns:
        A decorator that re-invokes the wrapped callable while it raises one of
        litellm's Timeout / APIError / APIConnectionError / RateLimitError /
        ServiceUnavailableError exceptions. Once attempts are exhausted the last
        exception is re-raised (``reraise=True``). Each wait is logged at
        WARNING level on the module logger.
    """
    import litellm

    errors = litellm.exceptions
    transient_errors = (
        errors.Timeout,
        errors.APIError,
        errors.APIConnectionError,
        errors.RateLimitError,
        errors.ServiceUnavailableError,
    )

    # Exponential backoff (multiplier 1, base 2) clamped to the [4, 10] second
    # range: early waits are 4 seconds, later ones are capped at 10 seconds.
    backoff = wait_exponential(multiplier=1, min=4, max=10)

    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=backoff,
        retry=retry_if_exception_type(transient_errors),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
54

55

56
def completion_with_retry(is_chat_model: bool, max_retries: int, **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call.

    Args:
        is_chat_model: Unused by this implementation; kept for signature
            compatibility.
        max_retries: Total number of attempts handed to the retry decorator
            (see ``_create_retry_decorator``).
        **kwargs: Forwarded unchanged to ``litellm.completion``.

    Returns:
        Whatever ``litellm.completion`` returns.
    """
    # Imported lazily: litellm is treated as an optional dependency of this module.
    from litellm import completion

    retry_decorator = _create_retry_decorator(max_retries=max_retries)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        return completion(**kwargs)

    return _completion_with_retry(**kwargs)
67

68

69
async def acompletion_with_retry(
    is_chat_model: bool, max_retries: int, **kwargs: Any
) -> Any:
    """Use tenacity to retry the async completion call.

    Args:
        is_chat_model: Unused by this implementation; kept for signature
            compatibility.
        max_retries: Total number of attempts handed to the retry decorator
            (see ``_create_retry_decorator``).
        **kwargs: Forwarded unchanged to ``litellm.acompletion``.

    Returns:
        The awaited result of ``litellm.acompletion``.
    """
    # Imported lazily: litellm is treated as an optional dependency of this module.
    from litellm import acompletion

    retry_decorator = _create_retry_decorator(max_retries=max_retries)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        # litellm's async completion API; tenacity's ``retry`` wraps coroutine
        # functions so retries are awaited rather than blocking.
        return await acompletion(**kwargs)

    return await _completion_with_retry(**kwargs)
83

84

85
def openai_modelname_to_contextsize(modelname: str) -> int:
    """Return the maximum token count litellm reports for a model.

    Fine-tuned model names are first reduced to their base model name, then
    the value is looked up with ``litellm.get_max_tokens``.

    Args:
        modelname: The modelname we want to know the context size for.

    Returns:
        The maximum context size reported by litellm as an ``int``. Returns
        2048 when litellm cannot provide one: an unknown model, a missing
        value, or any other ``Exception`` raised during the lookup.

    Example:
        .. code-block:: python

            max_tokens = openai_modelname_to_contextsize("gpt-3.5-turbo")

    Modified from:
        https://github.com/hwchase17/langchain/blob/master/langchain/llms/openai.py
    """
    import litellm

    # Handle fine-tuned models by recovering the base model name.
    if modelname.startswith("ft:"):
        # Names of the form "ft:<base-model>:..."
        modelname = modelname.split(":")[1]
    elif ":ft-" in modelname:
        # Legacy fine-tuning names of the form "<base-model>:ft-..."
        modelname = modelname.split(":")[0]

    try:
        # If litellm reports no value (None), int() raises TypeError. That case
        # shares the fallback below with unknown models.
        return int(litellm.get_max_tokens(modelname))
    except Exception:
        # Deliberate best-effort fallback: assume models have at least 2048
        # tokens instead of failing the caller.
        logger.debug(
            "Could not determine context size for model %r; defaulting to 2048.",
            modelname,
        )
        return 2048
125

126

127
def is_chat_model(model: str) -> bool:
    """Return whether ``model`` appears in ``litellm.model_list``.

    NOTE(review): nothing in this module shows ``litellm.model_list`` being
    limited to chat models, so this may effectively be a "known to litellm"
    check. Confirm against litellm.
    """
    import litellm

    known_models = litellm.model_list
    return model in known_models
131

132

133
def is_function_calling_model(model: str) -> bool:
    """Return whether ``model`` is a chat model that supports function calling.

    Chat models whose name contains the ``0314`` or ``0301`` version tags are
    treated as old and reported as not supporting function calling.
    """
    if not is_chat_model(model):
        return False
    old_version_tags = ("0314", "0301")
    return not any(tag in model for tag in old_version_tags)
137

138

139
def get_completion_endpoint(is_chat_model: bool) -> Callable[..., Any]:
    """Return litellm's ``completion`` function.

    Args:
        is_chat_model: Unused; the same function is returned either way. Kept
            for signature compatibility.

    Returns:
        ``litellm.completion``. This is a plain function, not the OpenAI
        ``Completions`` resource class, so the return type is a ``Callable``.
    """
    from litellm import completion

    return completion
143

144

145
def to_openai_message_dict(message: ChatMessage) -> dict:
    """Build an OpenAI-style message dict from a generic ``ChatMessage``.

    The result holds ``role`` and ``content``, followed by every entry of
    ``message.additional_kwargs``. Extra OpenAI fields travel that way, e.g.
    ``name`` on function messages or an optional ``function_call`` on
    assistant messages. A key in ``additional_kwargs`` that matches
    ``role``/``content`` overrides it.
    """
    return {
        "role": message.role,
        "content": message.content,
        # Later keys win, so additional_kwargs take precedence.
        **message.additional_kwargs,
    }
158

159

160
def to_openai_message_dicts(messages: Sequence[ChatMessage]) -> List[dict]:
    """Convert each ``ChatMessage`` to an OpenAI message dict, keeping order."""
    return list(map(to_openai_message_dict, messages))
163

164

165
def from_openai_message_dict(message_dict: dict) -> ChatMessage:
    """Build a generic ``ChatMessage`` from an OpenAI message dict.

    ``role`` is required (``KeyError`` if absent). ``content`` is optional and
    defaults to ``None``. Every other key is kept in ``additional_kwargs``.
    The input dict is copied, never modified.
    """
    role = message_dict["role"]
    # Azure OpenAI returns function-calling messages without a "content" key.
    content = message_dict.get("content", None)

    extras = message_dict.copy()
    for reserved_key in ("role", "content"):
        extras.pop(reserved_key, None)

    return ChatMessage(role=role, content=content, additional_kwargs=extras)
176

177

178
def from_litellm_message(message: Message) -> ChatMessage:
    """Build a generic ``ChatMessage`` from a ``litellm.utils.Message``.

    Only ``role`` and ``content`` are copied. No other field of the litellm
    message is carried into ``additional_kwargs``.

    NOTE(review): unlike ``from_openai_message_dict``, any function-call data
    is therefore dropped. Confirm this is intended.
    """
    # Azure OpenAI returns function-calling messages without a content key,
    # so a missing content is read as None.
    return ChatMessage(
        role=message.get("role"),
        content=message.get("content", None),
    )
185

186

187
def from_openai_message_dicts(message_dicts: Sequence[dict]) -> List[ChatMessage]:
    """Convert each OpenAI message dict to a generic ``ChatMessage``, keeping order."""
    return list(map(from_openai_message_dict, message_dicts))
190

191

192
def to_openai_function(pydantic_class: Type[BaseModel]) -> Dict[str, Any]:
    """Convert pydantic class to OpenAI function.

    The function ``name`` is the schema ``title``, and ``parameters`` is the
    full JSON schema of the class. ``description`` is copied from the schema's
    ``description``; pydantic fills that from the class docstring. When the
    schema has no description, the key is omitted instead of raising
    ``KeyError``.

    Args:
        pydantic_class: The pydantic model class to describe.

    Returns:
        A dict with ``name``, ``parameters`` and, when available, ``description``.
    """
    # Build the schema once and reuse it for both the metadata and parameters.
    schema = pydantic_class.schema()

    function: Dict[str, Any] = {"name": schema["title"]}
    if "description" in schema:
        function["description"] = schema["description"]
    function["parameters"] = schema
    return function
200

201

202
def validate_litellm_api_key(
    api_key: Optional[str] = None, api_type: Optional[str] = None
) -> None:
    """Raise if ``litellm.validate_environment()`` returns ``None``.

    Both ``api_key`` and ``api_type`` are currently ignored. The check relies
    only on the result of ``litellm.validate_environment()``.

    NOTE(review): newer litellm versions appear to return a dict from
    ``validate_environment``. If so, this check never raises; confirm the
    intended contract.

    Raises:
        ValueError: With ``MISSING_API_KEY_ERROR_MESSAGE`` when
            ``validate_environment()`` returns ``None``.
    """
    import litellm

    environment_status = litellm.validate_environment()
    if environment_status is None:
        raise ValueError(MISSING_API_KEY_ERROR_MESSAGE)
210

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.