belle / generate_instruction.py (224 lines · 9.9 KB)

"""
2
batch_selfinstruct_generate.py
3

4
run:
5
python -m generate_instruction generate_instruction_following_data \
6
  --output_dir ./ \
7
  --num_instructions_to_generate 10 \
8
  --model_name="text-davinci-003" \
9
"""
10
import time
11
import json
12
import os
13
import random
14
import re
15
import string
16
from functools import partial
17
from multiprocessing import Pool
18

19
import numpy as np
20
import tqdm
21
from rouge_score import rouge_scorer
22
import utils
23

24
import fire
25
from gensim.summarization import bm25
26
from  transformers import AutoTokenizer
27
checkpoint = "bigscience/bloomz-7b1"
28
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
29

30

31

32
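# Minimal sketch (added; not part of the original file) of how the tokenizer and
# the BM25 index are combined below for near-duplicate filtering; the seed
# strings are hypothetical:
#
#   corpus = [tokenizer.tokenize(t) for t in ["把这句话翻译成英文", "写一首关于春天的诗"]]
#   model = bm25.BM25(corpus)
#   scores = model.get_scores(tokenizer.tokenize("把下面的句子翻译成英文"))
#   # a large max(scores) flags the query as a near-duplicate of the corpus
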
def encode_prompt(prompt_instructions):
    """Encode multiple prompt instructions into a single string."""
    prompt = open("./prompt_cn.txt").read() + "\n"

    for idx, task_dict in enumerate(prompt_instructions):
        (instruction, input, output) = task_dict["instruction"], task_dict["input"], task_dict["output"]
        instruction = re.sub(r"\s+", " ", instruction).strip().rstrip(":")
        input = "<无输入>" if input.lower() == "" else input
        prompt += "###\n"
        prompt += f"{idx + 1}. 指令: {instruction}\n"
        prompt += f"{idx + 1}. 输入:\n{input}\n"
        prompt += f"{idx + 1}. 输出:\n{output}\n"
    prompt += "###\n"
    prompt += f"{idx + 2}. 指令:"
    return prompt

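# Illustration (added; the seed task shown is hypothetical): with one seed task,
# encode_prompt() appends the following to the contents of prompt_cn.txt:
#
#   ###
#   1. 指令: 将下面的句子翻译成英文
#   1. 输入:
#   今天天气很好
#   1. 输出:
#   The weather is nice today.
#   ###
#   2. 指令:
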
def post_process_gpt3_response(num_prompt_instructions, response):
    if response is None:
        return []
    try:  # for gpt-3.5-turbo
        raw_instructions = response["message"]["content"]
    except:
        try:
            raw_instructions = response["text"]  # for text-davinci-003
        except:
            print("ERROR parse!")
            return []  # without this, raw_instructions would be undefined below
    # both full-width (:) and half-width (:) colons occur in model output
    if '指令:' not in raw_instructions[0:10] and '指令:' not in raw_instructions[0:10]:
        raw_instructions = f"{num_prompt_instructions+1}. 指令:" + raw_instructions
    raw_instructions = re.split("###", raw_instructions)
    instructions = []
    blacklist = ["图像", "图片", "照片", "文件", "图表", "图层", "曲线图", "折线图", "直线图", "柱形图", "饼状图",
                 "链接", "http", 'OpenAI', 'chatgpt', 'gpt-3', 'gpt-3.5', 'gpt-4']
    replace_empty_list = ['要求GPT模型能够', '要求GPT能够', '要求GPT模型', '让GPT模型', '使用GPT模型', '请向GPT模型',
                          'GPT模型应', 'GPT模型应该', '请求GPT模型', '需要GPT模型回答', '请GPT模型', '请让GPT模型',
                          '训练GPT模型', 'GPT模型需要', '要求GPT', '让GPT', '使用GPT', '请向GPT', 'GPT应', 'GPT应该',
                          '请求GPT', '需要GPT回答', '请GPT', '请让GPT', '训练GPT', 'GPT需要', '希望GPT模型能够',
                          '希望GPT能够', '以便GPT模型能够', '以便GPT能够', '使得GPT模型能够', '使得GPT能够',
                          '使GPT模型能够', '使GPT能够', '由GPT模型', '使GPT模型']
    for idx, inst in enumerate(raw_instructions):
        # if the decoding stops due to length, the last example is likely truncated, so we discard it
        if idx == len(raw_instructions) - 1 and response["finish_reason"] == "length":
            continue
        # filter based on keywords that are not suitable for language models
        if any(find_word_in_string(word, inst) for word in blacklist):
            continue
        instruction_pattern = re.compile(r"(?<=(?:" + '|'.join(['指令:', '指令:']) + r"))[\s\S]*?(?=" + '|'.join(['输入:', '输入:']) + ")")
        input_pattern = re.compile(r"(?<=(?:" + '|'.join(['输入:', '输入:']) + r"))[\s\S]*?(?=" + '|'.join(['输出:', '输出:']) + ")")
        output_pattern = re.compile(r"(?<=(?:" + '|'.join(['输出:', '输出:']) + r"))[\s\S]*?(?=$)")
        instruction_match = instruction_pattern.search(inst)
        input_match = input_pattern.search(inst)
        output_match = output_pattern.search(inst)
        if instruction_match and input_match and output_match:
            inst = re.sub(r'\d+\.$', '', instruction_match.group().strip()).strip('\n')
            input = re.sub(r'\d+\.$', '', input_match.group().strip()).strip('\n')
            input = "" if "无输入" in input else input
            output = output_match.group().strip().strip('\n')
            if '指令:' in output and '输入:' in output and '输出:' in output:  # if the response is not separated by "###", keep only the first item
                output_pattern_new = re.compile(r"[\s\S]*?(?=" + '|'.join(['指令:', '指令:']) + ")")
                output_match_new = output_pattern_new.search(output)
                if output_match_new:
                    output = re.sub(r'\d+\.$', '', output_match_new.group().strip()).strip('\n')
            # drop unreasonable instructions
            if len(inst) <= 3:
                continue

            for item in replace_empty_list:
                inst = inst.replace(item, "")

            if "GPT" in inst or 'GPT' in input:
                continue

            if len(input) == 0:  # no input provided
                instructions.append({"instruction": inst, "input": input, "output": output})
            else:
                if '示例' in inst or '例子' in inst:  # the instruction embeds an example
                    if len(inst) < 150:
                        instructions.append({"instruction": inst, "input": input, "output": output})
                else:  # no example given
                    if len(inst) < 100:
                        instructions.append({"instruction": inst, "input": input, "output": output})
    return instructions

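# Illustration (added; the completion text is hypothetical): a raw completion such as
#
#   4. 指令: 写一首关于春天的诗
#   4. 输入:
#   <无输入>
#   4. 输出:
#   春风拂面来...
#   ###
#   5. 指令: ...
#
# is split on "###", and each chunk is parsed by the three regexes above into
# {"instruction": ..., "input": "", "output": ...} (a "无输入" input becomes "").
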
def find_word_in_string(w, s):
    return w in s

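# Note (added): plain substring matching is deliberate here; the English Alpaca
# original used a \b word-boundary regex, which is not useful for unsegmented
# Chinese text.
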
def generate_instruction_following_data(
    output_dir="./",
    seed_tasks_path="./zh_seed_tasks.json",
    num_instructions_to_generate=1,
    api="completion",
    model_name="text-davinci-003",
    num_prompt_instructions=3,
    request_batch_size=1,
    temperature=1.0,
    top_p=1.0,
    num_cpus=16,
):
    seed_tasks = [json.loads(l) for l in open(seed_tasks_path, "r")]
    seed_instruction_data = [
        {"instruction": t["instruction"], "input": t["instances"][0]["input"], "output": t["instances"][0]["output"]}
        for t in seed_tasks
    ]
    print(f"Loaded {len(seed_instruction_data)} human-written seed instructions")

    os.makedirs(output_dir, exist_ok=True)
    request_idx = 0
    # load the LM-generated instructions
    machine_instruction_data = []
    if os.path.exists(os.path.join(output_dir, "Belle.train.json")):
        machine_instruction_data = utils.jload(os.path.join(output_dir, "Belle.train.json"))
        print(f"Loaded {len(machine_instruction_data)} machine-generated instructions")

    # now let's generate new instructions!
    progress_bar = tqdm.tqdm(total=num_instructions_to_generate)
    if machine_instruction_data:
        progress_bar.update(len(machine_instruction_data))

    # first we tokenize all the seed instructions and generated machine instructions
    all_instructions = [d["instruction"] for d in seed_instruction_data] + [
        d["instruction"] for d in machine_instruction_data
    ]
    all_instruction_tokens = [tokenizer.tokenize(inst) for inst in all_instructions]
    bm25Model = bm25.BM25(all_instruction_tokens)

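    # Note (added): this fork filters near-duplicates with BM25 scores instead of
    # the ROUGE-L similarity used by the Stanford Alpaca pipeline; rouge_scorer,
    # Pool, partial, and num_cpus above are unused leftovers of that pipeline.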
    while len(machine_instruction_data) < num_instructions_to_generate:
        request_idx += 1

        batch_inputs = []
        for _ in range(request_batch_size):
            # only sampling from the seed tasks
            prompt_instructions = random.sample(seed_instruction_data, num_prompt_instructions)
            prompt = encode_prompt(prompt_instructions)
            batch_inputs.append(prompt)
        decoding_args = utils.OpenAIDecodingArguments(
            temperature=temperature,
            n=1,
            max_tokens=1024,  # hard-code to maximize the length. the requests will be automatically adjusted
            top_p=top_p,
            stop=["\n20", "20.", "20."],  # half- and (presumably) full-width "20." variants, as elsewhere in this file
        )
        request_start = time.time()
        results = utils.openai_completion(
            prompts=batch_inputs,
            api=api,
            model_name=model_name,
            batch_size=request_batch_size,
            decoding_args=decoding_args,
            logit_bias={"50256": -100},  # prevent the <|endoftext|> token from being generated
        )
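        # Note (added): each element of `results` is expected to carry
        # "finish_reason" and either "text" (text-davinci-003) or
        # "message" -> "content" (gpt-3.5-turbo), which is exactly what
        # post_process_gpt3_response() reads.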
        request_duration = time.time() - request_start

        process_start = time.time()
        instruction_data = []
        for result in results:
            new_instructions = post_process_gpt3_response(num_prompt_instructions, result)
            instruction_data += new_instructions

        total = len(instruction_data)
        keep = 0
        for instruction_data_entry in instruction_data:
            # computing similarity with the pre-tokenized instructions
            new_instruction_tokens = tokenizer.tokenize(instruction_data_entry["instruction"])
            bm25_scores = bm25Model.get_scores(new_instruction_tokens)

            most_similar_instructions = {
                all_instructions[i]: bm25_scores[i] for i in np.argsort(bm25_scores)[-10:][::-1]
            }
            # discard candidates that score too close to an existing instruction (empirical BM25 threshold)
            if max(bm25_scores) > 18:
                continue
            else:
                keep += 1
            instruction_data_entry["most_similar_instructions"] = most_similar_instructions
            instruction_data_entry["avg_similarity_score"] = float(np.mean(bm25_scores))
            machine_instruction_data.append(instruction_data_entry)
            all_instructions.append(instruction_data_entry["instruction"])
            all_instruction_tokens.append(new_instruction_tokens)
            progress_bar.update(1)
        process_duration = time.time() - process_start
        print(f"Request {request_idx} took {request_duration:.2f}s, processing took {process_duration:.2f}s")
        print(f"Generated {total} instructions, kept {keep} instructions")
        # note: resumption above reads Belle.train.json, while results are saved to regen.json (as in the original)
        utils.jdump(machine_instruction_data, os.path.join(output_dir, "regen.json"))


def main(task, **kwargs):
    globals()[task](**kwargs)


if __name__ == "__main__":
    fire.Fire(main)
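
# Example invocation (added; flag values are illustrative):
#   python generate_instruction.py generate_instruction_following_data \
#     --num_instructions_to_generate=100 \
#     --model_name="text-davinci-003" \
#     --request_batch_size=5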