# Source: stanford_alpaca / generate_instruction.py (forks: 0; 217 lines, 8.2 KB)
"""
batch_selfinstruct_generate.py

run:
python -m generate_instruction generate_instruction_following_data \
  --output_dir ./ \
  --num_instructions_to_generate 10 \
  --model_name="text-davinci-003"
"""
10
import time
11
import json
12
import os
13
import random
14
import re
15
import string
16
from functools import partial
17
from multiprocessing import Pool
18

19
import numpy as np
20
import tqdm
21
from rouge_score import rouge_scorer
22
import utils
23

24
import fire
25

26

27
def encode_prompt(prompt_instructions):
    """Encode multiple prompt instructions into a single string.

    Args:
        prompt_instructions: iterable of dicts, each providing the string
            keys "instruction", "input" and "output".

    Returns:
        The contents of ``./prompt.txt`` followed by one numbered,
        ``###``-separated section per task, and finally an open-ended
        ``"<n + 1>. Instruction:"`` header. With no tasks the header is
        numbered 1.
    """
    # Use a context manager so the template file handle is always closed.
    with open("./prompt.txt") as prompt_file:
        parts = [prompt_file.read() + "\n"]

    # Tracks the number of the last task written; stays 0 for empty input so
    # the trailing header is still well defined.
    number = 0
    for number, task in enumerate(prompt_instructions, start=1):
        # Collapse all whitespace runs and drop a trailing colon.
        instruction = re.sub(r"\s+", " ", task["instruction"]).strip().rstrip(":")
        # An empty input is spelled out with an explicit placeholder.
        task_input = "<noinput>" if task["input"] == "" else task["input"]
        parts.append("###\n")
        parts.append(f"{number}. Instruction: {instruction}\n")
        parts.append(f"{number}. Input:\n{task_input}\n")
        parts.append(f"{number}. Output:\n{task['output']}\n")

    parts.append("###\n")
    parts.append(f"{number + 1}. Instruction:")
    return "".join(parts)
42

43

44
# Keywords marking instructions that are not suitable for language models.
_BLACKLIST = (
    "image",
    "images",
    "graph",
    "graphs",
    "picture",
    "pictures",
    "file",
    "files",
    "map",
    "maps",
    "draw",
    "plot",
    "go to",
    "video",
    "audio",
    "music",
    "flowchart",
    "diagram",
)
# One case-insensitive whole-word pattern, compiled once, instead of building
# the list and compiling one regex per keyword for every candidate.
_BLACKLIST_RE = re.compile(r"\b(" + "|".join(_BLACKLIST) + r")\b", flags=re.IGNORECASE)


def post_process_gpt3_response(num_prompt_instructions, response):
    """Parse one completion into a list of filtered instruction dicts.

    Args:
        num_prompt_instructions: number of example tasks that were in the
            prompt; the first generated task is numbered one above it.
        response: ``None``, or a mapping with the keys "text" (the completion
            text) and "finish_reason".

    Returns:
        A list of ``{"instruction", "input", "output"}`` dicts. A candidate is
        dropped when it does not split into exactly the three numbered fields,
        has 3 or fewer / more than 150 words, contains a blacklisted keyword,
        starts with "Write a program", or starts with punctuation or a
        non-ASCII character. Returns ``[]`` when ``response`` is ``None``.
    """
    if response is None:
        return []

    # Re-attach the header the prompt ended with so the first generated task
    # parses like the others.
    raw_text = f"{num_prompt_instructions + 1}. Instruction:" + response["text"]
    segments = raw_text.split("###")
    last_index = len(segments) - 1
    truncated = response["finish_reason"] == "length"

    instructions = []
    for offset, segment in enumerate(segments):
        # if the decoding stops due to length, the last example is likely truncated so we discard it
        if truncated and offset == last_index:
            continue

        number = offset + num_prompt_instructions + 1
        # Raw f-string: "\." and "\s" must reach the regex engine untouched.
        fields = re.split(rf"{number}\.\s+(Instruction|Input|Output):", segment)
        # Three captured labels plus the text around them give exactly 7 items.
        if len(fields) != 7:
            continue

        inst = fields[2].strip()
        task_input = fields[4].strip()
        if task_input.lower() == "<noinput>":
            task_input = ""
        output = fields[6].strip()

        # filter out too short or too long instructions
        word_count = len(inst.split())
        if word_count <= 3 or word_count > 150:
            continue
        # filter based on keywords that are not suitable for language models.
        if _BLACKLIST_RE.search(inst):
            continue
        # The model tends to add "write a program" to existing instructions,
        # and it is unclear whether a program or the result itself is wanted.
        # Note this is not a comprehensive filter for programming instructions.
        if inst.startswith("Write a program"):
            continue
        # filter those starting with punctuation (inst is non-empty here
        # because it has more than three words)
        if inst[0] in string.punctuation:
            continue
        # filter those starting with non-english character
        if not inst[0].isascii():
            continue

        instructions.append({"instruction": inst, "input": task_input, "output": output})
    return instructions
104

105

106
def find_word_in_string(w, s):
    """Search ``s`` for ``w`` as a whole word, ignoring case.

    Args:
        w: the word to look for. It is interpolated into the pattern as-is,
            so regex metacharacters in ``w`` are interpreted, not escaped.
        s: the string to search.

    Returns:
        The first ``re.Match`` (truthy) if ``w`` occurs between word
        boundaries in ``s``, otherwise ``None``.
    """
    return re.search(r"\b({0})\b".format(w), s, flags=re.IGNORECASE)
108

109

110
def generate_instruction_following_data(
    output_dir="./",
    seed_tasks_path="./seed_tasks.jsonl",
    num_instructions_to_generate=100,
    model_name="text-davinci-003",
    num_prompt_instructions=3,
    request_batch_size=5,
    temperature=1.0,
    top_p=1.0,
    num_cpus=16,
):
    """Generate new instruction data by prompting a model with seed tasks.

    Repeatedly samples seed tasks, requests completions through
    ``utils.openai_completion``, parses them with
    ``post_process_gpt3_response`` and keeps every new instruction whose
    ROUGE-L F-measure against all known instructions is at most 0.7. The
    accumulated data is written to ``<output_dir>/regen.json`` after each
    request batch; an existing ``regen.json`` is loaded first and extended.

    Args:
        output_dir: directory for ``regen.json`` (created if missing).
        seed_tasks_path: JSON-lines file; each record supplies "instruction"
            and ``instances[0]["input"]`` / ``instances[0]["output"]``.
        num_instructions_to_generate: stop once this many machine-generated
            instructions have been collected.
        model_name: model name passed to ``utils.openai_completion``.
        num_prompt_instructions: number of seed tasks sampled per prompt.
        request_batch_size: number of prompts sent per request batch.
        temperature: sampling temperature for decoding.
        top_p: nucleus-sampling parameter for decoding.
        num_cpus: worker processes used for the similarity computation.
    """
    # Close the seed file deterministically instead of leaking the handle.
    with open(seed_tasks_path, "r") as seed_file:
        seed_tasks = [json.loads(line) for line in seed_file]
    seed_instruction_data = [
        {"instruction": t["instruction"], "input": t["instances"][0]["input"], "output": t["instances"][0]["output"]}
        for t in seed_tasks
    ]
    print(f"Loaded {len(seed_instruction_data)} human-written seed instructions")

    os.makedirs(output_dir, exist_ok=True)
    regen_path = os.path.join(output_dir, "regen.json")
    request_idx = 0
    # load the LM-generated instructions
    machine_instruction_data = []
    if os.path.exists(regen_path):
        machine_instruction_data = utils.jload(regen_path)
        print(f"Loaded {len(machine_instruction_data)} machine-generated instructions")

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)

    # first we tokenize all the seed instructions and generated machine instructions
    all_instructions = [d["instruction"] for d in seed_instruction_data] + [
        d["instruction"] for d in machine_instruction_data
    ]
    all_instruction_tokens = [scorer._tokenizer.tokenize(inst) for inst in all_instructions]

    # The decoding arguments do not change between requests, so build them once.
    decoding_args = utils.OpenAIDecodingArguments(
        temperature=temperature,
        n=1,
        max_tokens=3072,  # hard-code to maximize the length. the requests will be automatically adjusted
        top_p=top_p,
        stop=["\n20", "20.", "20."],
    )

    # now let's generate new instructions!
    # One worker pool serves the whole run (spawning a pool per candidate
    # instruction is costly); both the pool and the progress bar are released
    # on exit, including on error.
    with tqdm.tqdm(total=num_instructions_to_generate) as progress_bar, Pool(num_cpus) as pool:
        if machine_instruction_data:
            progress_bar.update(len(machine_instruction_data))

        while len(machine_instruction_data) < num_instructions_to_generate:
            request_idx += 1

            # only sampling from the seed tasks
            batch_inputs = [
                encode_prompt(random.sample(seed_instruction_data, num_prompt_instructions))
                for _ in range(request_batch_size)
            ]

            request_start = time.time()
            results = utils.openai_completion(
                prompts=batch_inputs,
                model_name=model_name,
                batch_size=request_batch_size,
                decoding_args=decoding_args,
                logit_bias={"50256": -100},  # prevent the <|endoftext|> token from being generated
            )
            request_duration = time.time() - request_start

            process_start = time.time()
            instruction_data = []
            for result in results:
                instruction_data += post_process_gpt3_response(num_prompt_instructions, result)

            total = len(instruction_data)
            keep = 0
            for entry in instruction_data:
                # computing similarity with the pre-tokenized instructions
                new_instruction_tokens = scorer._tokenizer.tokenize(entry["instruction"])
                rouge_scores = pool.map(
                    partial(rouge_scorer._score_lcs, new_instruction_tokens),
                    all_instruction_tokens,
                )
                rouge_scores = [score.fmeasure for score in rouge_scores]
                # Too similar to something we already have: discard.
                if max(rouge_scores) > 0.7:
                    continue
                keep += 1

                # Record the ten closest existing instructions, best first.
                entry["most_similar_instructions"] = {
                    all_instructions[i]: rouge_scores[i] for i in np.argsort(rouge_scores)[-10:][::-1]
                }
                entry["avg_similarity_score"] = float(np.mean(rouge_scores))
                machine_instruction_data.append(entry)
                # Later candidates are also compared against this new one.
                all_instructions.append(entry["instruction"])
                all_instruction_tokens.append(new_instruction_tokens)
                progress_bar.update(1)
            process_duration = time.time() - process_start
            print(f"Request {request_idx} took {request_duration:.2f}s, processing took {process_duration:.2f}s")
            print(f"Generated {total} instructions, kept {keep} instructions")
            # Persist after every batch so an interrupted run can resume.
            utils.jdump(machine_instruction_data, regen_path)
210

211

212
def main(task, **kwargs):
    """Dispatch to the module-level callable named ``task``.

    Args:
        task: name of a global in this module, e.g.
            "generate_instruction_following_data".
        **kwargs: forwarded unchanged as keyword arguments to that callable.

    Raises:
        KeyError: if this module has no global named ``task``.
    """
    # NOTE(review): the lookup is not restricted to task functions; any
    # module-level name can be selected from the command line.
    globals()[task](**kwargs)
214

215

216
if __name__ == "__main__":
    # Expose `main` on the command line: the first argument selects the task,
    # the remaining flags become its keyword arguments.
    fire.Fire(main)
218

# --- Web-page residue from the code viewer (not part of the program) ---
# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking the "Accept" button, you give JSC SberTech your consent to the
# processing of your personal data for the purpose of improving our website
# and the GitVerse Service and making them more convenient to use.
#
# You can disable the use of cookies yourself in your browser settings.