llama-index

Форк
0
230 строк · 7.5 Кб
# utils script

# generation with retry
import logging
from typing import Any, Callable, Optional

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from llama_index.legacy.core.llms.types import ChatMessage, MessageRole

# Bison model names, grouped by family (chat, text, code, code-chat).
CHAT_MODELS = ["chat-bison", "chat-bison-32k", "chat-bison@001"]
TEXT_MODELS = ["text-bison", "text-bison-32k", "text-bison@001"]
CODE_MODELS = ["code-bison", "code-bison-32k", "code-bison@001"]
CODE_CHAT_MODELS = ["codechat-bison", "codechat-bison-32k", "codechat-bison@001"]


# Module logger; the retry decorator below logs each retry through it.
logger = logging.getLogger(__name__)


def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
    """Build a tenacity ``retry`` decorator for transient Google API errors.

    The decorated call is retried when it raises ``ServiceUnavailable``,
    ``ResourceExhausted``, ``Aborted`` or ``DeadlineExceeded`` from
    ``google.api_core.exceptions``. Waits between attempts grow exponentially
    and are clamped to the 4-10 second range. Each retry is logged at WARNING
    level. Once the attempts are used up, the last exception is re-raised as-is
    (``reraise=True``) rather than wrapped in ``tenacity.RetryError``.

    Args:
        max_retries: Total number of attempts, including the first one
            (passed to ``stop_after_attempt``).

    Returns:
        A decorator to apply to the function that performs the API call.
    """
    # Import the submodule explicitly: ``import google.api_core`` alone does not
    # guarantee that ``google.api_core.exceptions`` has been loaded.
    from google.api_core import exceptions as google_exceptions

    min_seconds = 4
    max_seconds = 10

    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        # retry_if_exception_type accepts a tuple of exception classes.
        retry=retry_if_exception_type(
            (
                google_exceptions.ServiceUnavailable,
                google_exceptions.ResourceExhausted,
                google_exceptions.Aborted,
                google_exceptions.DeadlineExceeded,
            )
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


def completion_with_retry(
    client: Any,
    prompt: Optional[Any],
    max_retries: int = 5,
    chat: bool = False,
    stream: bool = False,
    is_gemini: bool = False,
    params: Optional[Any] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call.

    Dispatches to one of three call styles on ``client``:

    * ``is_gemini``: start a chat seeded with ``params["message_history"]``
      (empty list if absent) and call ``send_message`` with ``kwargs`` passed
      as ``generation_config``.
    * ``chat``: start a chat with ``**params`` and call ``send_message`` or
      ``send_message_streaming`` with ``**kwargs``.
    * otherwise: call ``predict`` or ``predict_streaming`` with ``**kwargs``.

    Args:
        client: The model client to call.
        prompt: The prompt / message to send.
        max_retries: Total number of attempts before the last error is re-raised.
        chat: Use the chat API. Ignored when ``is_gemini`` is set.
        stream: Request a streaming response.
        is_gemini: Use the Gemini chat API.
        params: Keyword arguments for ``client.start_chat`` (chat path), or a
            mapping that may hold ``"message_history"`` (Gemini path).
            ``None`` means "no parameters".
        **kwargs: Forwarded to the send / predict call.

    Returns:
        Whatever the underlying client call returns.
    """
    # ``None`` replaces the former shared mutable ``{}`` default; same behavior.
    params = {} if params is None else params
    retry_decorator = _create_retry_decorator(max_retries=max_retries)

    @retry_decorator
    def _completion_with_retry(**call_kwargs: Any) -> Any:
        if is_gemini:
            history = params.get("message_history", [])
            generation = client.start_chat(history=history)
            return generation.send_message(
                prompt, stream=stream, generation_config=dict(call_kwargs)
            )
        if chat:
            generation = client.start_chat(**params)
            if stream:
                return generation.send_message_streaming(prompt, **call_kwargs)
            return generation.send_message(prompt, **call_kwargs)
        if stream:
            return client.predict_streaming(prompt, **call_kwargs)
        return client.predict(prompt, **call_kwargs)

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry(
    client: Any,
    prompt: Optional[str],
    max_retries: int = 5,
    chat: bool = False,
    is_gemini: bool = False,
    params: Optional[Any] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call.

    Async counterpart of ``completion_with_retry``, without streaming:

    * ``is_gemini``: start a chat seeded with ``params["message_history"]``
      (empty list if absent) and await ``send_message_async`` with ``kwargs``
      passed as ``generation_config``.
    * ``chat``: start a chat with ``**params`` and await
      ``send_message_async(prompt, **kwargs)``.
    * otherwise: await ``client.predict_async(prompt, **kwargs)``.

    Args:
        client: The model client to call.
        prompt: The prompt / message to send.
        max_retries: Total number of attempts before the last error is re-raised.
        chat: Use the chat API. Ignored when ``is_gemini`` is set.
        is_gemini: Use the Gemini chat API.
        params: Keyword arguments for ``client.start_chat`` (chat path), or a
            mapping that may hold ``"message_history"`` (Gemini path).
            ``None`` means "no parameters".
        **kwargs: Forwarded to the send / predict call.

    Returns:
        Whatever the awaited client call returns.
    """
    # ``None`` replaces the former shared mutable ``{}`` default; same behavior.
    params = {} if params is None else params
    retry_decorator = _create_retry_decorator(max_retries=max_retries)

    # tenacity's ``retry`` decorator also supports coroutine functions.
    @retry_decorator
    async def _completion_with_retry(**call_kwargs: Any) -> Any:
        if is_gemini:
            history = params.get("message_history", [])
            generation = client.start_chat(history=history)
            return await generation.send_message_async(
                prompt, generation_config=dict(call_kwargs)
            )
        if chat:
            generation = client.start_chat(**params)
            return await generation.send_message_async(prompt, **call_kwargs)
        return await client.predict_async(prompt, **call_kwargs)

    return await _completion_with_retry(**kwargs)


def init_vertexai(
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[Any] = None,
) -> None:
    """Init vertexai.

    Args:
        project: The default GCP project to use when making Vertex API calls.
        location: The default location to use when making API calls.
        credentials: The default custom credentials to use when making API
            calls. If not provided, credentials will be ascertained from the
            environment.

    Raises:
        ValueError: If the vertexai SDK cannot be imported. The original
            ``ImportError`` is chained as ``__cause__``.
    """
    try:
        import vertexai
    except ImportError as exc:
        # ValueError (not ImportError) is kept so existing callers that catch
        # ValueError keep working; the import failure is chained for debugging.
        raise ValueError(
            "Please install vertex AI client by following the steps: "
            "pip install google-cloud-aiplatform"
        ) from exc

    vertexai.init(project=project, location=location, credentials=credentials)


def _parse_message(message: ChatMessage, is_gemini: bool) -> Any:
145
    if is_gemini:
146
        from llama_index.legacy.llms.vertex_gemini_utils import (
147
            convert_chat_message_to_gemini_content,
148
        )
149

150
        return convert_chat_message_to_gemini_content(message=message, is_history=False)
151
    else:
152
        return message.content
153

154

155
def _parse_chat_history(history: Any, is_gemini: bool) -> Any:
    """Parse a sequence of messages into history.

    Args:
        history: The list of messages to re-create the history of the chat.
        is_gemini: If true, convert messages with
            ``convert_chat_message_to_gemini_content``. Otherwise, build
            ``vertexai.language_models.ChatMessage`` objects (author ``"bot"``
            for assistant messages, ``"user"`` for user messages).

    Returns:
        A dict with ``"context"`` (the content of a leading system message, or
        ``None``) and ``"message_history"`` (the converted user/assistant
        messages, in order).

    Raises:
        ValueError: If a system message appears anywhere but first, if a
            leading system message is given for a Gemini model, if a message
            has a role other than system/user/assistant, or if the number of
            user/assistant messages is odd.
    """
    if is_gemini:
        from llama_index.legacy.llms.vertex_gemini_utils import (
            convert_chat_message_to_gemini_content,
        )
    else:
        # Aliased so it does not shadow the module-level llama-index ChatMessage.
        from vertexai.language_models import ChatMessage as VertexChatMessage

    vertex_messages = []
    context = None
    for i, message in enumerate(history):
        if i == 0 and message.role == MessageRole.SYSTEM:
            if is_gemini:
                raise ValueError("Gemini model don't support system messages")
            context = message.content
        elif message.role in (MessageRole.ASSISTANT, MessageRole.USER):
            if is_gemini:
                vertex_messages.append(
                    convert_chat_message_to_gemini_content(
                        message=message, is_history=True
                    )
                )
            else:
                vertex_messages.append(
                    VertexChatMessage(
                        content=message.content,
                        author=(
                            "bot" if message.role == MessageRole.ASSISTANT else "user"
                        ),
                    )
                )
        else:
            # Report the offending role; the message's type is the same for
            # every message and is therefore not useful here.
            raise ValueError(
                f"Unexpected message with role {message.role} at the position {i}."
            )
    if len(vertex_messages) % 2 != 0:
        raise ValueError("total no of messages should be even")

    return {"context": context, "message_history": vertex_messages}


def _parse_examples(examples: Any) -> Any:
204
    from vertexai.language_models import InputOutputTextPair
205

206
    if len(examples) % 2 != 0:
207
        raise ValueError(
208
            f"Expect examples to have an even amount of messages, got {len(examples)}."
209
        )
210
    example_pairs = []
211
    input_text = None
212
    for i, example in enumerate(examples):
213
        if i % 2 == 0:
214
            if not example.role == MessageRole.USER:
215
                raise ValueError(
216
                    f"Expected the first message in a part to be from user, got "
217
                    f"{type(example)} for the {i}th message."
218
                )
219
            input_text = example.content
220
        if i % 2 == 1:
221
            if not example.role == MessageRole.ASSISTANT:
222
                raise ValueError(
223
                    f"Expected the second message in a part to be from AI, got "
224
                    f"{type(example)} for the {i}th message."
225
                )
226
            pair = InputOutputTextPair(
227
                input_text=input_text, output_text=example.content
228
            )
229
            example_pairs.append(pair)
230
    return example_pairs
231

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.