llama-index

Форк
0
203 строки · 6.0 Кб
1
import logging
2
from abc import ABC, abstractmethod
3
from typing import Any, Callable, Optional, Sequence
4

5
from tenacity import (
6
    before_sleep_log,
7
    retry,
8
    retry_if_exception_type,
9
    stop_after_attempt,
10
    wait_exponential,
11
)
12

13
from llama_index.legacy.core.llms.types import ChatMessage
14
from llama_index.legacy.llms.anthropic_utils import messages_to_anthropic_prompt
15
from llama_index.legacy.llms.generic_utils import (
16
    prompt_to_messages,
17
)
18
from llama_index.legacy.llms.llama_utils import (
19
    completion_to_prompt as completion_to_llama_prompt,
20
)
21
from llama_index.legacy.llms.llama_utils import (
22
    messages_to_prompt as messages_to_llama_prompt,
23
)
24

25
# Turn markers for Anthropic text prompts: Anthropic models require the prompt
# to start with "Human:" and end with "Assistant:".
HUMAN_PREFIX = "\n\nHuman:"
ASSISTANT_PREFIX = "\n\nAssistant:"

# Per-model token limits keyed by Bedrock model ID (presumably used as the
# model's context window by callers -- confirm). Values taken from
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
#
# Text-completion models: their providers (amazon, ai21, cohere) define no
# prompt formatters.
COMPLETION_MODELS = {
    "amazon.titan-tg1-large": 8000,
    "amazon.titan-text-express-v1": 8000,
    "ai21.j2-grande-instruct": 8000,
    "ai21.j2-jumbo-instruct": 8000,
    "ai21.j2-mid": 8000,
    "ai21.j2-mid-v1": 8000,
    "ai21.j2-ultra": 8000,
    "ai21.j2-ultra-v1": 8000,
    "cohere.command-text-v14": 4096,
}

# Chat models whose prompt must be rendered by a provider-specific formatter
# (the provider's ``messages_to_prompt`` / ``completion_to_prompt``):
# - Anthropic models require the prompt to start with "Human:" and end with
#   "Assistant:".
# - Meta Llama 2 chat models are formatted with the ``llama_utils`` helpers.
CHAT_ONLY_MODELS = {
    "anthropic.claude-instant-v1": 100000,
    "anthropic.claude-v1": 100000,
    "anthropic.claude-v2": 100000,
    "anthropic.claude-v2:1": 200000,
    "meta.llama2-13b-chat-v1": 2048,
    "meta.llama2-70b-chat-v1": 4096,
}

# Every supported model ID mapped to its token limit.
BEDROCK_FOUNDATION_LLMS = {**COMPLETION_MODELS, **CHAT_ONLY_MODELS}

# Only the following models support streaming as
# per result of Bedrock.Client.list_foundation_models
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html
STREAMING_MODELS = {
    "amazon.titan-tg1-large",
    "amazon.titan-text-express-v1",
    "anthropic.claude-instant-v1",
    "anthropic.claude-v1",
    "anthropic.claude-v2",
    "anthropic.claude-v2:1",
    "meta.llama2-13b-chat-v1",
}
65

66

67
class Provider(ABC):
    """Adapter describing how to talk to one Bedrock model family.

    Subclasses name the request field that caps generated tokens and explain
    how to read generated text from a response. They may also supply
    formatters that turn chat messages or a bare completion string into a
    model-specific prompt.
    """

    # Optional prompt formatters; ``None`` means the provider defines none.
    messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None
    completion_to_prompt: Optional[Callable[[str], str]] = None

    @property
    @abstractmethod
    def max_tokens_key(self) -> str:
        """Name of the request-body field that limits generated tokens."""

    @abstractmethod
    def get_text_from_response(self, response: dict) -> str:
        """Return the generated text from a non-streaming response body."""

    def get_text_from_stream_response(self, response: dict) -> str:
        """Return the generated text from one streamed chunk.

        By default a chunk is parsed exactly like a full response; subclasses
        override this when the chunk layout differs.
        """
        return self.get_text_from_response(response)

    def get_request_body(self, prompt: str, inference_parameters: dict) -> dict:
        """Build the request payload for a plain text prompt.

        Inference parameters are merged after ``"prompt"``, so a ``"prompt"``
        entry among them takes precedence.
        """
        payload = {"prompt": prompt, **inference_parameters}
        return payload
85

86

87
class AmazonProvider(Provider):
    """Provider for Amazon models (``amazon.*`` model IDs, e.g. Titan text)."""

    max_tokens_key = "maxTokenCount"

    def get_request_body(self, prompt: str, inference_parameters: dict) -> dict:
        """Put the prompt under ``inputText`` and nest a shallow copy of the
        inference parameters under ``textGenerationConfig``."""
        generation_config = {**inference_parameters}
        return {"inputText": prompt, "textGenerationConfig": generation_config}

    def get_text_from_response(self, response: dict) -> str:
        """Return ``outputText`` of the first entry in ``results``."""
        first_result = response["results"][0]
        return first_result["outputText"]

    def get_text_from_stream_response(self, response: dict) -> str:
        # Streamed chunks carry ``outputText`` at the top level instead of
        # inside a ``results`` list.
        return response["outputText"]
101

102

103
class Ai21Provider(Provider):
    """Provider for AI21 models (``ai21.*`` model IDs)."""

    max_tokens_key = "maxTokens"

    def get_text_from_response(self, response: dict) -> str:
        """Return the text of the first entry in ``completions``."""
        first_completion = response["completions"][0]
        return first_completion["data"]["text"]
108

109

110
def completion_to_anthropic_prompt(completion: str) -> str:
    """Format a bare completion string as an Anthropic prompt.

    The string is first wrapped into chat messages with
    ``prompt_to_messages`` and then rendered by
    ``messages_to_anthropic_prompt``.
    """
    return messages_to_anthropic_prompt(prompt_to_messages(completion))


# Backward-compatible alias for the historical, misspelled name.
completion_to_anthopic_prompt = completion_to_anthropic_prompt
112

113

114
class AnthropicProvider(Provider):
    """Provider for Anthropic models (``anthropic.*`` model IDs).

    Responses carry the generated text under ``completion``.
    """

    max_tokens_key = "max_tokens_to_sample"

    def __init__(self) -> None:
        # The formatters are stored on the instance on purpose: a plain
        # function stored on the class would become a bound method and
        # receive ``self`` as its first argument.
        self.completion_to_prompt = completion_to_anthopic_prompt
        self.messages_to_prompt = messages_to_anthropic_prompt

    def get_text_from_response(self, response: dict) -> str:
        """Return the ``completion`` field of the response body."""
        text = response["completion"]
        return text
123

124

125
class CohereProvider(Provider):
    """Provider for Cohere models (``cohere.*`` model IDs)."""

    max_tokens_key = "max_tokens"

    def get_text_from_response(self, response: dict) -> str:
        """Return the text of the first entry in ``generations``."""
        first_generation = response["generations"][0]
        return first_generation["text"]
130

131

132
class MetaProvider(Provider):
    """Provider for Meta Llama models (``meta.*`` model IDs).

    Prompts are rendered with the ``llama_utils`` formatters, and responses
    carry the generated text under ``generation``.
    """

    max_tokens_key = "max_gen_len"

    def __init__(self) -> None:
        # Instance attributes, not class attributes, so the plain formatter
        # functions are not turned into bound methods.
        self.completion_to_prompt = completion_to_llama_prompt
        self.messages_to_prompt = messages_to_llama_prompt

    def get_text_from_response(self, response: dict) -> str:
        """Return the ``generation`` field of the response body."""
        text = response["generation"]
        return text
141

142

143
# Registry of shared provider instances keyed by the model-ID prefix (the
# text before the first "." of a Bedrock model ID).
PROVIDERS = dict(
    amazon=AmazonProvider(),
    ai21=Ai21Provider(),
    anthropic=AnthropicProvider(),
    cohere=CohereProvider(),
    meta=MetaProvider(),
)
150

151

152
def get_provider(model: str) -> Provider:
    """Return the provider adapter for a Bedrock model ID.

    The provider is selected by the text before the first ``.`` in ``model``
    (e.g. ``"anthropic"`` for ``"anthropic.claude-v2"``).

    Raises:
        ValueError: If that prefix has no entry in ``PROVIDERS``.
    """
    prefix, _, _ = model.partition(".")
    if prefix in PROVIDERS:
        return PROVIDERS[prefix]
    raise ValueError(f"Provider {prefix} for model {model} is not supported")
157

158

159
logger = logging.getLogger(__name__)
160

161

162
def _create_retry_decorator(client: Any, max_retries: int) -> Callable[[Any], Any]:
163
    min_seconds = 4
164
    max_seconds = 10
165
    # Wait 2^x * 1 second between each retry starting with
166
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
167
    try:
168
        import boto3  # noqa
169
    except ImportError as e:
170
        raise ImportError(
171
            "You must install the `boto3` package to use Bedrock."
172
            "Please `pip install boto3`"
173
        ) from e
174

175
    return retry(
176
        reraise=True,
177
        stop=stop_after_attempt(max_retries),
178
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
179
        retry=(retry_if_exception_type(client.exceptions.ThrottlingException)),
180
        before_sleep=before_sleep_log(logger, logging.WARNING),
181
    )
182

183

184
def completion_with_retry(
    client: Any,
    model: str,
    request_body: str,
    max_retries: int,
    stream: bool = False,
    **kwargs: Any,
) -> Any:
    """Invoke a Bedrock model, retrying on throttling via tenacity.

    Args:
        client: Bedrock runtime client exposing ``invoke_model`` and
            ``invoke_model_with_response_stream``.
        model: Model ID, passed as ``modelId``.
        request_body: Serialized request payload, passed as ``body``.
        max_retries: Maximum number of attempts.
        stream: If true, call ``invoke_model_with_response_stream`` instead
            of ``invoke_model``.
        **kwargs: Accepted for compatibility but not forwarded to the client.

    Returns:
        The raw client response.
    """

    @_create_retry_decorator(client=client, max_retries=max_retries)
    def _invoke(**_unused: Any) -> Any:
        if not stream:
            return client.invoke_model(modelId=model, body=request_body)
        return client.invoke_model_with_response_stream(
            modelId=model, body=request_body
        )

    return _invoke(**kwargs)
204

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.