llama-index
112 lines · 3.1 KB
1import logging2from typing import Any, Callable, Dict, List, Optional, Sequence3
4from tenacity import (5before_sleep_log,6retry,7retry_if_exception_type,8stop_after_attempt,9wait_exponential,10)
11
12from llama_index.legacy.core.llms.types import ChatMessage13
# Maximum context window size (in tokens) for each Cohere chat/command model.
COMMAND_MODELS = {
    "command": 4096,
    "command-nightly": 4096,
    "command-light": 4096,
    "command-light-nightly": 4096,
}

# Base text-generation models and their context window sizes (tokens).
GENERATION_MODELS = {"base": 2048, "base-light": 2048}

# Embedding ("representation") models and their maximum input sizes (tokens).
REPRESENTATION_MODELS = {
    "embed-english-light-v2.0": 512,
    "embed-english-v2.0": 512,
    "embed-multilingual-v2.0": 256,
}

# Union of every known model name -> context size, used for validation/lookup.
ALL_AVAILABLE_MODELS = {**COMMAND_MODELS, **GENERATION_MODELS, **REPRESENTATION_MODELS}
# Models that support the chat endpoint (currently the command family only).
CHAT_MODELS = {**COMMAND_MODELS}

# Module-level logger, also used by tenacity's before_sleep_log below.
logger = logging.getLogger(__name__)
34
35def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:36min_seconds = 437max_seconds = 1038# Wait 2^x * 1 second between each retry starting with39# 4 seconds, then up to 10 seconds, then 10 seconds afterwards40try:41import cohere42except ImportError as e:43raise ImportError(44"You must install the `cohere` package to use Cohere."45"Please `pip install cohere`"46) from e47
48return retry(49reraise=True,50stop=stop_after_attempt(max_retries),51wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),52retry=(retry_if_exception_type(cohere.error.CohereConnectionError)),53before_sleep=before_sleep_log(logger, logging.WARNING),54)55
56
def completion_with_retry(
    client: Any, max_retries: int, chat: bool = False, **kwargs: Any
) -> Any:
    """Use tenacity to retry the completion call."""
    # Build the retry decorator once, then apply it to a small wrapper that
    # dispatches to the chat or generate endpoint.
    wrapped = _create_retry_decorator(max_retries=max_retries)

    @wrapped
    def _invoke(**call_kwargs: Any) -> Any:
        endpoint = client.chat if chat else client.generate
        return endpoint(**call_kwargs)

    return _invoke(**kwargs)
72
async def acompletion_with_retry(
    aclient: Any,
    max_retries: int,
    chat: bool = False,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call."""
    decorator = _create_retry_decorator(max_retries=max_retries)

    @decorator
    async def _retryable(**call_kwargs: Any) -> Any:
        # Dispatch to the async chat endpoint when requested,
        # otherwise fall through to plain generation.
        if not chat:
            return await aclient.generate(**call_kwargs)
        return await aclient.chat(**call_kwargs)

    return await _retryable(**kwargs)
91
def cohere_modelname_to_contextsize(modelname: str) -> int:
    """Return the maximum context size (in tokens) for a Cohere model.

    Args:
        modelname: A Cohere model name; must be a key of
            ``ALL_AVAILABLE_MODELS``.

    Returns:
        The model's context window size in tokens.

    Raises:
        ValueError: If ``modelname`` is not a known Cohere model.
    """
    context_size = ALL_AVAILABLE_MODELS.get(modelname, None)
    if context_size is None:
        # Fix: the original f-string and the following literal were joined
        # without a space, yielding "...model name.Known models are: ...".
        raise ValueError(
            f"Unknown model: {modelname}. Please provide a valid Cohere model name. "
            "Known models are: " + ", ".join(ALL_AVAILABLE_MODELS.keys())
        )

    return context_size
102
def is_chat_model(model: str) -> bool:
    """Return True when *model* is one of the chat-capable command models."""
    return COMMAND_MODELS.get(model) is not None
106
def messages_to_cohere_history(
    messages: Sequence[ChatMessage],
) -> List[Dict[str, Optional[str]]]:
    """Convert ChatMessages into the dict form Cohere expects for chat history.

    Each message becomes ``{"user_name": <role>, "message": <content>}``.
    """
    history: List[Dict[str, Optional[str]]] = []
    for msg in messages:
        history.append({"user_name": msg.role, "message": msg.content})
    return history