# stanford_alpaca/utils.py
import copy
import dataclasses
import io
import json
import logging
import math
import os
import sys
import time
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object
# Values the OpenAI client yields: raw completion text or a full response object.
StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

# Optionally route all API calls through a specific organization, taken from the
# OPENAI_ORG environment variable when it is set.
if (openai_org := os.getenv("OPENAI_ORG")) is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")
23
24@dataclasses.dataclass25class OpenAIDecodingArguments(object):26max_tokens: int = 180027temperature: float = 0.228top_p: float = 1.029n: int = 130stream: bool = False31stop: Optional[Sequence[str]] = None32presence_penalty: float = 0.033frequency_penalty: float = 0.034suffix: Optional[str] = None35logprobs: Optional[int] = None36echo: bool = False37
38
def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[StrOrOpenAIObject, Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]:
    """Decode with OpenAI API.

    Args:
        prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
            as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
            it can also be a dictionary (or list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        decoding_args: Decoding arguments.
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Time to sleep once the rate-limit is hit.
        batch_size: Number of prompts to send in a single request. Only for non chat model.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
        return_text: If True, return text instead of full completion object (which contains things like logprob).
        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

    Returns:
        A completion or a list of completions.
        Depending on return_text and decoding_args.n, the completion type can be one of
            - a string (if return_text is True)
            - an openai_object.OpenAIObject object (if return_text is False)
            - a list of objects of the above types (if decoding_args.n > 1)
    """
    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead."
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    prompt_batches = [
        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
    ]

    completions = []
    for prompt_batch in tqdm.tqdm(prompt_batches, desc="prompt_batches"):
        # Clone so per-batch retries can shrink max_tokens without mutating the caller's args.
        batch_decoding_args = copy.deepcopy(decoding_args)

        while True:
            try:
                shared_kwargs = dict(
                    model=model_name,
                    **batch_decoding_args.__dict__,
                    **decoding_kwargs,
                )
                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
                choices = completion_batch.choices

                # Attach the request-level token usage to each choice for downstream accounting.
                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    # Context-length overflow: shrink the target length and retry immediately.
                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
                else:
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)  # Annoying rate limit on requests.

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
    if is_single_prompt:
        # Return non-tuple if only 1 input and 1 generation.
        (completions,) = completions
    return completions
132
133def _make_w_io_base(f, mode: str):134if not isinstance(f, io.IOBase):135f_dirname = os.path.dirname(f)136if f_dirname != "":137os.makedirs(f_dirname, exist_ok=True)138f = open(f, mode=mode)139return f140
141
142def _make_r_io_base(f, mode: str):143if not isinstance(f, io.IOBase):144f = open(f, mode=mode)145return f146
147
def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk, or an already-open file object.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.

    Raises:
        ValueError: If `obj` is neither a dict/list nor a str.
    """
    f = _make_w_io_base(f, mode)
    try:
        if isinstance(obj, (dict, list)):
            json.dump(obj, f, indent=indent, default=default)
        elif isinstance(obj, str):
            f.write(obj)
        else:
            raise ValueError(f"Unexpected type: {type(obj)}")
    finally:
        # Close even when serialization fails so file handles are not leaked.
        f.close()
167
168def jload(f, mode="r"):169"""Load a .json file into a dictionary."""170f = _make_r_io_base(f, mode)171jdict = json.load(f)172f.close()173return jdict174