import copy
import dataclasses
import io
import json
import logging
import math
import os
import sys
import time
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object
16
StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]
18
openai_org = os.getenv("OPENAI_ORG")
19
if openai_org is not None:
20
openai.organization = openai_org
21
logging.warning(f"Switching to organization: {openai_org} for OAI API key.")
25
class OpenAIDecodingArguments(object):
26
max_tokens: int = 1800
27
temperature: float = 0.2
31
stop: Optional[Sequence[str]] = None
32
presence_penalty: float = 0.0
33
frequency_penalty: float = 0.0
34
suffix: Optional[str] = None
35
logprobs: Optional[int] = None
40
prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
41
decoding_args: OpenAIDecodingArguments,
43
model_name="text-davinci-003",
46
max_instances=sys.maxsize,
47
max_batches=sys.maxsize,
50
) -> Union[Union[StrOrOpenAIObject], Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]],]:
51
"""Decode with OpenAI API.
54
prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
55
as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
56
it can also be a dictionary (or list thereof) as explained here:
57
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
58
decoding_args: Decoding arguments.
59
api: 'completion' for openai.Completion or 'chat' for openai.ChatCompletion
60
model_name: Model name. Can be either in the format of "org/model" or just "model".
61
sleep_time: Time to sleep once the rate-limit is hit.
62
batch_size: Number of prompts to send in a single request. Only for non chat model.
63
max_instances: Maximum number of prompts to decode.
64
max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
65
return_text: If True, return text instead of full completion object (which contains things like logprob).
66
decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.
69
A completion or a list of completions.
70
Depending on return_text, return_openai_object, and decoding_args.n, the completion type can be one of
71
- a string (if return_text is True)
72
- an openai_object.OpenAIObject object (if return_text is False)
73
- a list of objects of the above types (if decoding_args.n > 1)
75
is_single_prompt = isinstance(prompts, (str, dict))
79
if api not in {"completion", "chat"}:
80
raise ValueError(f"Unsupported API type: {api}")
82
logging.warning(f"Chat API only supports batch size = 1, overriding requested batch size of {batch_size} .")
85
if max_batches < sys.maxsize:
87
"`max_batches` will be deprecated in the future, please use `max_instances` instead."
88
"Setting `max_instances` to `max_batches * batch_size` for now."
90
max_instances = max_batches * batch_size
92
prompts = prompts[:max_instances]
93
num_prompts = len(prompts)
95
prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
96
for batch_id in range(int(math.ceil(num_prompts / batch_size)))
100
for batch_id, prompt_batch in tqdm.tqdm(
101
enumerate(prompt_batches),
102
desc="prompt_batches",
103
total=len(prompt_batches),
105
batch_decoding_args = copy.deepcopy(decoding_args) # cloning the decoding_args
109
shared_kwargs = dict(
111
**batch_decoding_args.__dict__,
114
if api == "completion":
115
completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
120
"content": "You are ChatGPT, a large language model trained by OpenAI. You answer as concisely as possible for each response (e.g. don\u2019t be verbose). It is very important that you answer as concisely as possible, so please remember this. If you are generating a list, do not have too many items. Keep the number of items short."
124
"content": prompt_batch[0]
127
for unused_key in ["suffix", "logprobs", "echo"]:
128
shared_kwargs.pop(unused_key, None)
129
completion_batch = openai.ChatCompletion.create(messages=messages, **shared_kwargs)
130
choices = completion_batch.choices
132
for choice in choices:
133
choice["total_tokens"] = completion_batch.usage.total_tokens
134
completions.extend(choices)
136
except openai.error.OpenAIError as e:
137
logging.warning(f"OpenAIError: {e}.")
138
if "Please reduce your prompt" in str(e):
139
batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
140
logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
142
logging.warning("Hit request rate limit; retrying...")
143
time.sleep(sleep_time) # Annoying rate limit on requests.
146
if api == "completion":
147
completions = [completion.text.encode('utf-8').decode('utf-8') for completion in completions]
149
completions = [completion.message.content.encode('utf-8').decode('utf-8') for completion in completions]
150
if decoding_args.n > 1:
151
# make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
152
completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
154
# Return non-tuple if only 1 input and 1 generation.
155
(completions,) = completions
159
def _make_w_io_base(f, mode: str):
160
if not isinstance(f, io.IOBase):
161
f_dirname = os.path.dirname(f)
163
os.makedirs(f_dirname, exist_ok=True)
164
f = open(f, mode=mode, encoding="utf-8")
168
def _make_r_io_base(f, mode: str):
169
if not isinstance(f, io.IOBase):
170
f = open(f, mode=mode)
174
def jdump(obj, f, mode="w", indent=4, default=str):
175
"""Dump a str or dictionary to a file in json format.
178
obj: An object to be written.
179
f: A string path to the location on disk.
180
mode: Mode for opening the file.
181
indent: Indent for storing json dictionaries.
182
default: A function to handle non-serializable entries; defaults to `str`.
184
f = _make_w_io_base(f, mode)
185
if isinstance(obj, (dict, list)):
186
json.dump(obj, f, indent=indent, default=default, ensure_ascii=False)
187
elif isinstance(obj, str):
190
raise ValueError(f"Unexpected type: {type(obj)}")
194
def jload(f, mode="r"):
195
"""Load a .json file into a dictionary."""
196
f = _make_r_io_base(f, mode)