stanford_alpaca

utils.py 
173 lines · 6.3 KB

import dataclasses
import logging
import math
import os
import io
import sys
import time
import json
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object
import copy

StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

openai_org = os.getenv("OPENAI_ORG")
if openai_org is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")


@dataclasses.dataclass
class OpenAIDecodingArguments(object):
    max_tokens: int = 1800
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Sequence[str]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    suffix: Optional[str] = None
    logprobs: Optional[int] = None
    echo: bool = False

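# Illustrative note, not in the original file: any field above can be
# overridden per call, e.g. OpenAIDecodingArguments(max_tokens=512,
# temperature=0.7, stop=["\n\n"]); the stop list shown is a hypothetical
# placeholder.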


def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[StrOrOpenAIObject, Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]:
    """Decode with the OpenAI API.

    Args:
        prompts: A string or a list of strings to complete. For chat models, strings should be formatted
            as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. Chat models
            also accept a dictionary (or list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        decoding_args: Decoding arguments.
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Time to sleep once the rate limit is hit.
        batch_size: Number of prompts to send in a single request. Only used for non-chat models.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
        return_text: If True, return text instead of the full completion object (which contains things like logprobs).
        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

    Returns:
        A completion or a list of completions.
        Depending on return_text and decoding_args.n, the completion type can be one of
            - a string (if return_text is True)
            - an openai_object.OpenAIObject object (if return_text is False)
            - a list of objects of the above types (if decoding_args.n > 1)
    """
    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    prompt_batches = [
        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
    ]

    completions = []
    for batch_id, prompt_batch in tqdm.tqdm(
        enumerate(prompt_batches),
        desc="prompt_batches",
        total=len(prompt_batches),
    ):
        batch_decoding_args = copy.deepcopy(decoding_args)  # Clone so per-batch adjustments don't leak across batches.

        while True:
            try:
                shared_kwargs = dict(
                    model=model_name,
                    **batch_decoding_args.__dict__,
                    **decoding_kwargs,
                )
                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
                choices = completion_batch.choices

                # Attach the batch-level token usage to every choice for downstream accounting.
                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    # The context window was exceeded; shrink the target length and retry.
                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, retrying...")
                else:
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)  # Annoying rate limit on requests.

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # Make completions a nested list, where each entry holds decoding_args.n consecutive original entries.
        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
    if is_single_prompt:
        # Return a single element rather than a one-element list when there was only one input.
        (completions,) = completions
    return completions

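
# A minimal usage sketch, not part of the original module: decode one prompt
# with a completion-style model and get back a single string. The prompt text,
# the parameter values, and the helper name `_example_openai_completion` are
# illustrative; a valid OpenAI API key must be configured in the environment.
def _example_openai_completion():
    decoding_args = OpenAIDecodingArguments(max_tokens=64, temperature=0.7)
    # A single string prompt with return_text=True yields a single string back.
    text = openai_completion(
        prompts="List three uses for a paperclip.",
        decoding_args=decoding_args,
        model_name="text-davinci-003",
        return_text=True,
    )
    print(text)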


def _make_w_io_base(f, mode: str):
    """Return a writable IO object, opening `f` (and creating parent directories) if it is a path."""
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def _make_r_io_base(f, mode: str):
    """Return a readable IO object, opening `f` if it is a path."""
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f


def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a dictionary, list, or string to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk, or an already-open IO object.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict

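
# Round-trip sketch, not in the original file: jdump writes a dict as JSON,
# creating parent directories as needed, and jload reads it back. The path
# below is an illustrative placeholder.
if __name__ == "__main__":
    example = {"instruction": "Say hi.", "output": "Hi!"}
    jdump(example, "/tmp/alpaca_utils_demo/example.json")
    assert jload("/tmp/alpaca_utils_demo/example.json") == example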