llama-index

Форк
0
112 строк · 3.1 Кб
1
import logging
2
from typing import Any, Callable, Dict, List, Optional, Sequence
3

4
from tenacity import (
5
    before_sleep_log,
6
    retry,
7
    retry_if_exception_type,
8
    stop_after_attempt,
9
    wait_exponential,
10
)
11

12
from llama_index.legacy.core.llms.types import ChatMessage
13

14
COMMAND_MODELS = {
15
    "command": 4096,
16
    "command-nightly": 4096,
17
    "command-light": 4096,
18
    "command-light-nightly": 4096,
19
}
20

21
GENERATION_MODELS = {"base": 2048, "base-light": 2048}
22

23
REPRESENTATION_MODELS = {
24
    "embed-english-light-v2.0": 512,
25
    "embed-english-v2.0": 512,
26
    "embed-multilingual-v2.0": 256,
27
}
28

29
ALL_AVAILABLE_MODELS = {**COMMAND_MODELS, **GENERATION_MODELS, **REPRESENTATION_MODELS}
30
CHAT_MODELS = {**COMMAND_MODELS}
31

32
logger = logging.getLogger(__name__)
33

34

35
def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
36
    min_seconds = 4
37
    max_seconds = 10
38
    # Wait 2^x * 1 second between each retry starting with
39
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
40
    try:
41
        import cohere
42
    except ImportError as e:
43
        raise ImportError(
44
            "You must install the `cohere` package to use Cohere."
45
            "Please `pip install cohere`"
46
        ) from e
47

48
    return retry(
49
        reraise=True,
50
        stop=stop_after_attempt(max_retries),
51
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
52
        retry=(retry_if_exception_type(cohere.error.CohereConnectionError)),
53
        before_sleep=before_sleep_log(logger, logging.WARNING),
54
    )
55

56

57
def completion_with_retry(
58
    client: Any, max_retries: int, chat: bool = False, **kwargs: Any
59
) -> Any:
60
    """Use tenacity to retry the completion call."""
61
    retry_decorator = _create_retry_decorator(max_retries=max_retries)
62

63
    @retry_decorator
64
    def _completion_with_retry(**kwargs: Any) -> Any:
65
        if chat:
66
            return client.chat(**kwargs)
67
        else:
68
            return client.generate(**kwargs)
69

70
    return _completion_with_retry(**kwargs)
71

72

73
async def acompletion_with_retry(
74
    aclient: Any,
75
    max_retries: int,
76
    chat: bool = False,
77
    **kwargs: Any,
78
) -> Any:
79
    """Use tenacity to retry the async completion call."""
80
    retry_decorator = _create_retry_decorator(max_retries=max_retries)
81

82
    @retry_decorator
83
    async def _completion_with_retry(**kwargs: Any) -> Any:
84
        if chat:
85
            return await aclient.chat(**kwargs)
86
        else:
87
            return await aclient.generate(**kwargs)
88

89
    return await _completion_with_retry(**kwargs)
90

91

92
def cohere_modelname_to_contextsize(modelname: str) -> int:
93
    context_size = ALL_AVAILABLE_MODELS.get(modelname, None)
94
    if context_size is None:
95
        raise ValueError(
96
            f"Unknown model: {modelname}. Please provide a valid Cohere model name."
97
            "Known models are: " + ", ".join(ALL_AVAILABLE_MODELS.keys())
98
        )
99

100
    return context_size
101

102

103
def is_chat_model(model: str) -> bool:
104
    return model in COMMAND_MODELS
105

106

107
def messages_to_cohere_history(
108
    messages: Sequence[ChatMessage],
109
) -> List[Dict[str, Optional[str]]]:
110
    return [
111
        {"user_name": message.role, "message": message.content} for message in messages
112
    ]
113

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.