# llama-index legacy: litellm utility helpers (retry decorators, model
# metadata lookups, and OpenAI <-> generic message conversions).
1import logging2from typing import Any, Callable, Dict, List, Optional, Sequence, Type3
4from openai.resources import Completions5from tenacity import (6before_sleep_log,7retry,8retry_if_exception_type,9stop_after_attempt,10wait_exponential,11)
12
13from llama_index.legacy.bridge.pydantic import BaseModel14from llama_index.legacy.core.llms.types import ChatMessage15
# Guidance shown when no credentials at all are configured for the LLM.
MISSING_API_KEY_ERROR_MESSAGE = """No API key found for LLM.
E.g. to use openai Please set the OPENAI_API_KEY environment variable or \
openai.api_key prior to initialization.
API keys can be found or created at \
https://platform.openai.com/account/api-keys
"""
# Shown when a key is present but rejected by the provider.
INVALID_API_KEY_ERROR_MESSAGE = """Invalid LLM API key."""

# litellm is an optional dependency: fall back to ``Any`` so annotations that
# reference ``Message`` (see ``from_litellm_message``) still resolve when
# litellm is not installed.
try:
    from litellm.utils import Message
except ModuleNotFoundError:
    Message = Any

logger = logging.getLogger(__name__)

# Alias for the OpenAI completions client type, used as a return annotation.
CompletionClientType = Type[Completions]
33
def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
    """Build a tenacity ``retry`` decorator for transient litellm failures.

    Retries on timeout, API, connection, rate-limit, and service-unavailable
    errors, waiting 2^x seconds between attempts (clamped to the 4-10 second
    range), and re-raises the last exception once ``max_retries`` attempts
    are exhausted.
    """
    import litellm

    # Transient litellm exceptions worth retrying; passing the tuple to a
    # single retry_if_exception_type is equivalent to OR-ing one per type.
    transient_errors = (
        litellm.exceptions.Timeout,
        litellm.exceptions.APIError,
        litellm.exceptions.APIConnectionError,
        litellm.exceptions.RateLimitError,
        litellm.exceptions.ServiceUnavailableError,
    )
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        # Wait 2^x * 1 second between each retry starting with
        # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(transient_errors),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
55
def completion_with_retry(is_chat_model: bool, max_retries: int, **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call.

    Args:
        is_chat_model: Unused; kept for interface compatibility.
        max_retries: Maximum number of attempts before giving up.
        **kwargs: Forwarded verbatim to ``litellm.completion``.

    Returns:
        The litellm completion response.
    """
    # NOTE: in the original the docstring was placed *after* the import,
    # making it a no-op string expression rather than a docstring.
    from litellm import completion

    retry_decorator = _create_retry_decorator(max_retries=max_retries)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        return completion(**kwargs)

    return _completion_with_retry(**kwargs)
68
async def acompletion_with_retry(
    is_chat_model: bool, max_retries: int, **kwargs: Any
) -> Any:
    """Use tenacity to retry the async completion call.

    Args:
        is_chat_model: Unused; kept for interface compatibility.
        max_retries: Maximum number of attempts before giving up.
        **kwargs: Forwarded verbatim to ``litellm.acompletion``.

    Returns:
        The litellm completion response.
    """
    # NOTE: in the original the docstring was placed *after* the import,
    # making it a no-op string expression rather than a docstring.
    from litellm import acompletion

    retry_decorator = _create_retry_decorator(max_retries=max_retries)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        # Use litellm's async API for non-blocking completion calls.
        return await acompletion(**kwargs)

    return await _completion_with_retry(**kwargs)
84
def openai_modelname_to_contextsize(modelname: str) -> int:
    """Calculate the maximum number of tokens possible to generate for a model.

    Args:
        modelname: The modelname we want to know the context size for.

    Returns:
        The maximum context size

    Example:
        .. code-block:: python

            max_tokens = openai.modelname_to_contextsize("text-davinci-003")

    Modified from:
        https://github.com/hwchase17/langchain/blob/master/langchain/llms/openai.py
    """
    import litellm

    # Handle fine-tuned model names: "ft:<base>:<org>:..." (current format)
    # carries the base model in field 1; "<base>:ft-..." (legacy) in field 0.
    if modelname.startswith("ft:"):
        modelname = modelname.split(":")[1]
    elif ":ft-" in modelname:  # legacy fine-tuning
        modelname = modelname.split(":")[0]

    try:
        context_size = int(litellm.get_max_tokens(modelname))
    except Exception:
        # Unknown model or a non-numeric litellm result: fall back to a
        # conservative default instead of failing.
        context_size = 2048  # by default assume models have at least 2048 tokens

    # NOTE: the original's ``if context_size is None: raise ValueError(...)``
    # branch was unreachable — ``int(...)`` never yields None and the except
    # branch assigns 2048 — so it has been removed.
    return context_size
126
def is_chat_model(model: str) -> bool:
    """Return True when ``model`` is one of the models litellm knows about."""
    import litellm

    known_models = litellm.model_list
    return model in known_models
132
def is_function_calling_model(model: str) -> bool:
    """Return True when ``model`` supports OpenAI-style function calling.

    A model qualifies when litellm recognizes it as a chat model and it is
    not one of the old 0314/0301 snapshots, which predate function calling.
    """
    if not is_chat_model(model):
        return False
    outdated_snapshots = ("0314", "0301")
    return not any(tag in model for tag in outdated_snapshots)
138
def get_completion_endpoint(is_chat_model: bool) -> CompletionClientType:
    """Return the litellm completion callable.

    ``is_chat_model`` is accepted for interface compatibility; litellm's
    single ``completion`` entry point serves both chat and text models.
    """
    import litellm

    return litellm.completion
144
def to_openai_message_dict(message: ChatMessage) -> dict:
    """Convert generic message to OpenAI message dict."""
    # OpenAI messages carry extra arguments beyond role/content — function
    # messages have `name`, assistant messages an optional `function_call` —
    # which ride along in additional_kwargs (and may override role/content,
    # matching the original dict.update semantics).
    return {
        "role": message.role,
        "content": message.content,
        **message.additional_kwargs,
    }
159
def to_openai_message_dicts(messages: Sequence[ChatMessage]) -> List[dict]:
    """Convert generic messages to OpenAI message dicts."""
    return list(map(to_openai_message_dict, messages))
164
def from_openai_message_dict(message_dict: dict) -> ChatMessage:
    """Convert openai message dict to generic message."""
    # Everything except role/content is preserved as additional kwargs
    # (e.g. `name`, `function_call`).
    extras = {
        key: value
        for key, value in message_dict.items()
        if key not in ("role", "content")
    }
    return ChatMessage(
        role=message_dict["role"],
        # NOTE: Azure OpenAI returns function calling messages without a
        # content key, hence the .get with a None default.
        content=message_dict.get("content", None),
        additional_kwargs=extras,
    )
177
def from_litellm_message(message: Message) -> ChatMessage:
    """Convert litellm.utils.Message instance to generic message."""
    # NOTE: Azure OpenAI returns function calling messages without a content
    # key, hence the .get with a None default.
    return ChatMessage(
        role=message.get("role"),
        content=message.get("content", None),
    )
186
def from_openai_message_dicts(message_dicts: Sequence[dict]) -> List[ChatMessage]:
    """Convert openai message dicts to generic messages."""
    return [from_openai_message_dict(entry) for entry in message_dicts]
191
def to_openai_function(pydantic_class: Type[BaseModel]) -> Dict[str, Any]:
    """Convert pydantic class to OpenAI function spec.

    Args:
        pydantic_class: Pydantic model describing the function's parameters.

    Returns:
        A dict with ``name``, ``description``, and ``parameters`` keys as
        expected by the OpenAI function-calling API.
    """
    # Build the JSON schema once — the original called pydantic_class.schema()
    # twice (once for "description", once for "parameters").
    schema = pydantic_class.schema()
    return {
        "name": schema["title"],
        # Models without a docstring have no "description" key; default to ""
        # instead of raising KeyError as the original did.
        "description": schema.get("description", ""),
        "parameters": schema,
    }
201
def validate_litellm_api_key(
    api_key: Optional[str] = None, api_type: Optional[str] = None
) -> None:
    """Validate that litellm can find LLM credentials, raising otherwise.

    Raises:
        ValueError: when no API key/environment credentials are detected.
    """
    import litellm

    # NOTE(review): the ``api_key`` and ``api_type`` parameters are ignored —
    # ``api_key`` is immediately overwritten below.  Also,
    # ``litellm.validate_environment()`` returns a status dict in current
    # litellm releases, so the ``is None`` check may never fire; confirm
    # against the pinned litellm version.
    api_key = litellm.validate_environment()
    if api_key is None:
        raise ValueError(MISSING_API_KEY_ERROR_MESSAGE)