utils.py
import copy
import dataclasses
import io
import json
import logging
import math
import os
import sys
import time
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object

StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

openai_org = os.getenv("OPENAI_ORG")
if openai_org is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")


@dataclasses.dataclass
class OpenAIDecodingArguments(object):
    max_tokens: int = 1800
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Sequence[str]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    suffix: Optional[str] = None
    logprobs: Optional[int] = None
    echo: bool = False
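

# --- Illustrative sketch (added commentary, not part of the original BELLE code) ---
# The dataclass above simply bundles the sampling parameters that `openai_completion`
# forwards to the API, so callers build one instance per decoding style. The values
# here are assumptions chosen for illustration, not defaults used by this repository.
def _example_decoding_args() -> OpenAIDecodingArguments:
    # Near-deterministic decoding: temperature 0, a single sample, a modest token cap.
    return OpenAIDecodingArguments(temperature=0.0, max_tokens=512, n=1)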


def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    api="completion",
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[StrOrOpenAIObject, Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]:
    """Decode with OpenAI API.

    Args:
        prompts: A string or a list of strings to complete. If it is a chat model, the strings should be formatted
            as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model,
            it can also be a dictionary (or list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        decoding_args: Decoding arguments.
        api: 'completion' for openai.Completion or 'chat' for openai.ChatCompletion.
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Time to sleep once the rate limit is hit.
        batch_size: Number of prompts to send in a single request. Only for non-chat models.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
        return_text: If True, return text instead of the full completion object (which contains things like logprob).
        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

    Returns:
        A completion or a list of completions.
        Depending on return_text and decoding_args.n, the completion type can be one of
            - a string (if return_text is True)
            - an openai_object.OpenAIObject object (if return_text is False)
            - a list of objects of the above types (if decoding_args.n > 1)
    """
    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if api not in {"completion", "chat"}:
        raise ValueError(f"Unsupported API type: {api}")
    if api == "chat":
        logging.warning(f"Chat API only supports batch size = 1, overriding requested batch size of {batch_size}.")
        batch_size = 1

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    prompt_batches = [
        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
    ]

    completions = []
    for batch_id, prompt_batch in tqdm.tqdm(
        enumerate(prompt_batches),
        desc="prompt_batches",
        total=len(prompt_batches),
    ):
        # Clone the decoding args so retries can shrink max_tokens for this batch only.
        batch_decoding_args = copy.deepcopy(decoding_args)

        while True:
            try:
                shared_kwargs = dict(
                    model=model_name,
                    **batch_decoding_args.__dict__,
                    **decoding_kwargs,
                )
                if api == "completion":
                    completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
                elif api == "chat":
                    # batch_size was forced to 1 above, so the batch holds a single user prompt.
                    messages = [
                        {
                            "role": "system",
                            "content": "You are ChatGPT, a large language model trained by OpenAI. You answer as concisely as possible for each response (e.g. don\u2019t be verbose). It is very important that you answer as concisely as possible, so please remember this. If you are generating a list, do not have too many items. Keep the number of items short.",
                        },
                        {
                            "role": "user",
                            "content": prompt_batch[0],
                        },
                    ]
                    # The chat endpoint rejects completion-only parameters, so drop them.
                    for unused_key in ["suffix", "logprobs", "echo"]:
                        shared_kwargs.pop(unused_key, None)
                    completion_batch = openai.ChatCompletion.create(messages=messages, **shared_kwargs)
                choices = completion_batch.choices

                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    # The prompt plus max_tokens exceeded the context window; shrink the target length and retry.
                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}; retrying...")
                else:
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)  # Annoying rate limit on requests.

    if return_text:
        if api == "completion":
            completions = [completion.text for completion in completions]
        elif api == "chat":
            completions = [completion.message.content for completion in completions]
    if decoding_args.n > 1:
        # Make completions a nested list, where each entry holds the decoding_args.n samples for one prompt.
        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
    if is_single_prompt:
        # Return a bare completion if there was only one input.
        (completions,) = completions
    return completions
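

# --- Illustrative sketch (added commentary, not part of the original BELLE code) ---
# A hedged example of calling `openai_completion` in completion mode. The prompt
# text, model name, and decoding values are assumptions for illustration, and a
# valid OPENAI_API_KEY must be set for this to actually run.
def _example_openai_completion() -> None:
    args = OpenAIDecodingArguments(max_tokens=256, temperature=0.7)
    # With return_text=True a plain string comes back instead of an OpenAIObject.
    text = openai_completion(
        prompts="List three uses of the `json` module.",
        decoding_args=args,
        api="completion",
        model_name="text-davinci-003",
        return_text=True,
    )
    print(text)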


def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode, encoding="utf-8")
    return f


def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        # Read with utf-8 to match _make_w_io_base, since jdump writes non-ASCII characters verbatim.
        f = open(f, mode=mode, encoding="utf-8")
    return f


def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default, ensure_ascii=False)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict
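

# --- Illustrative sketch (added commentary, not part of the original BELLE code) ---
# A small round-trip example for jdump/jload; the path and payload are made up
# for illustration.
if __name__ == "__main__":
    sample = {"instruction": "翻译成英文", "output": "Translate into English"}
    jdump(sample, "tmp/sample.json")  # creates tmp/ if needed, writes UTF-8 JSON
    assert jload("tmp/sample.json") == sample  # parses back into an equal dict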