paddlenlp
334 строки · 14.9 Кб
1# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import argparse16import json17import os18
19from tqdm import tqdm20
21from paddlenlp import Taskflow22
23
# yapf: disable
def parse_args():
    """Parse command-line arguments for the synthetic QA-pair generation pipeline.

    Returns:
        argparse.Namespace: parsed arguments. ``a_*`` options control answer
        extraction, ``q_*`` options control question generation, and ``f_*``
        options control the optional filtration step.
    """
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--answer_generation_model_path', type=str, default=None, help='the model path to be loaded for answer extraction')
    parser.add_argument('--question_generation_model_path', type=str, default=None, help='the model path to be loaded for question generation')
    parser.add_argument('--filtration_model_path', type=str, default=None, help='the model path to be loaded for filtration')
    parser.add_argument('--source_file_path', type=str, default=None, help='the source file path')
    parser.add_argument('--target_file_path', type=str, default=None, help='the target json file path')
    parser.add_argument('--batch_size', type=int, default=1, help='the batch size when using taskflow')
    parser.add_argument("--do_debug", action='store_true', help="Whether to do debug")
    parser.add_argument('--a_prompt', type=str, default='答案', help='the prompt when using taskflow, separate by ,')
    parser.add_argument('--a_position_prob', type=float, default=0.01, help='confidence threshold for answer extraction')
    parser.add_argument('--a_max_answer_candidates', type=int, default=5, help='the max number of return answer candidate for each input')
    parser.add_argument('--q_num_return_sequences', type=int, default=3, help='the number of return sequences for each input sample, it should be less than num_beams')
    parser.add_argument('--q_max_question_length', type=int, default=50, help='the max decoding length')
    parser.add_argument('--q_decode_strategy', type=str, default='sampling', help='the decode strategy')
    parser.add_argument('--q_num_beams', type=int, default=6, help='the number of beams when using beam search')
    parser.add_argument('--q_num_beam_groups', type=int, default=1, help='the number of beam groups when using diverse beam search')
    parser.add_argument('--q_diversity_rate', type=float, default=0.0, help='the diversity_rate when using diverse beam search')
    # top_k is a vocabulary-size cutoff (a count), so parse it as int (was: float).
    parser.add_argument('--q_top_k', type=int, default=5, help='the top_k when using sampling decoding strategy')
    parser.add_argument('--q_top_p', type=float, default=1.0, help='the top_p when using sampling decoding strategy')
    parser.add_argument('--q_temperature', type=float, default=1.0, help='the temperature when using sampling decoding strategy')
    parser.add_argument("--do_filtration", action='store_true', help="Whether to do filtration")
    parser.add_argument('--f_filtration_position_prob', type=float, default=0.1, help='confidence threshold for filtration')
    args = parser.parse_args()
    return args
# yapf: enable
51
52
def answer_generation_from_paragraphs(
    paragraphs, batch_size=16, model=None, max_answer_candidates=5, schema=None, wf=None
):
    """Generate answer candidates from given paragraphs.

    Args:
        paragraphs (list[str]): input paragraphs.
        batch_size (int): number of paragraphs sent to ``model`` per call.
        model (callable): taskflow-style callable mapping a list of texts to a
            list of dicts keyed by prompt; each value is a list of
            ``{"text": ..., "probability": ...}`` dicts.
        max_answer_candidates (int): keep at most this many candidates per paragraph.
        schema (list[str]): prompts to look up in each prediction dict.
        wf (file or None): if given, each result dict is also written as a JSON line.

    Returns:
        list[dict]: one ``{"context": ..., "answer_candidates": [(text, prob), ...]}``
        per paragraph, candidates de-duplicated and sorted by descending probability.
    """
    result = []
    buffer = []
    total = len(paragraphs)
    for i, paragraph_tobe in enumerate(tqdm(paragraphs)):
        buffer.append(paragraph_tobe)
        # Flush on a full batch or at the final paragraph.
        if len(buffer) == batch_size or (i + 1) == total:
            predicts = model(buffer)
            paragraph_list = buffer
            buffer = []
            for predict_dict, paragraph in zip(predicts, paragraph_list):
                answers = []
                probabilities = []
                for prompt in schema:
                    # Prompts absent from the prediction contribute nothing
                    # (the original appended empty lists in an else branch).
                    for answer_dict in predict_dict.get(prompt, []):
                        answers.append(answer_dict["text"])
                        probabilities.append(answer_dict["probability"])
                # De-duplicate (text, prob) pairs, highest probability first.
                candidates = sorted(set(zip(answers, probabilities)), key=lambda x: -x[1])
                if len(candidates) > max_answer_candidates:
                    candidates = candidates[:max_answer_candidates]
                outdict = {
                    "context": paragraph,
                    "answer_candidates": candidates,
                }
                if wf:
                    wf.write(json.dumps(outdict, ensure_ascii=False) + "\n")
                result.append(outdict)
    return result
91
def create_fake_question(
    json_file_or_pair_list, out_json=None, num_return_sequences=1, all_sample_num=None, batch_size=8
):
    """Generate synthetic questions for (context, answer) candidate pairs.

    Args:
        json_file_or_pair_list (list or str): either a list of dicts with
            ``context`` and ``answer_candidates`` keys, or a path to a
            JSON-lines file with the same schema.
        out_json (str or None): optional path; each output dict is also
            written there as a JSON line.
        num_return_sequences (int): questions generated per (context, answer) pair.
        all_sample_num (int or None): soft cap on the number of buffered pairs.
        batch_size (int): pairs per call of the module-level
            ``question_generation`` taskflow (defined in ``__main__``).

    Returns:
        list[dict]: one dict per generated question, holding the context, the
        synthetic answer/question with their probabilities, and the true question.
    """
    if out_json:
        wf = open(out_json, "w", encoding="utf-8")
    if isinstance(json_file_or_pair_list, list):
        all_lines = json_file_or_pair_list
    else:
        rf = open(json_file_or_pair_list, "r", encoding="utf-8")
        all_lines = []
        for json_line in rf:
            all_lines.append(json.loads(json_line))
        rf.close()
    num_all_lines = len(all_lines)
    output = []
    context_buffer = []
    answer_buffer = []
    answer_probability_buffer = []
    true_question_buffer = []
    i = 0
    for index, line_dict in enumerate(tqdm(all_lines)):
        q = line_dict["question"] if "question" in line_dict else ""
        c = line_dict["context"]
        assert "answer_candidates" in line_dict
        answers = line_dict["answer_candidates"]
        if not answers:
            continue
        for j, (a, p) in enumerate(answers):
            context_buffer.append(c)
            answer_buffer.append(a)
            answer_probability_buffer.append(p)
            true_question_buffer.append(q)
            # Flush on a full batch, on exactly reaching the sample cap, or at
            # the very last candidate of the last input line.
            if (
                (i + 1) % batch_size == 0
                or (all_sample_num and (i + 1) == all_sample_num)
                or ((index + 1) == num_all_lines and j == len(answers) - 1)
            ):
                result_buffer = question_generation(
                    [{"context": context, "answer": answer} for context, answer in zip(context_buffer, answer_buffer)]
                )
                # Expand each buffered sample num_return_sequences times so it
                # lines up 1:1 with the generated (question, score) pairs.
                # NOTE(review): assumes question_generation returns a pair of
                # parallel lists (questions, scores) — matches the zip below.
                context_expanded = []
                answer_expanded = []
                answer_probability_expanded = []
                true_question_expanded = []
                for context, answer, answer_probability, true_question in zip(
                    context_buffer, answer_buffer, answer_probability_buffer, true_question_buffer
                ):
                    context_expanded += [context] * num_return_sequences
                    answer_expanded += [answer] * num_return_sequences
                    answer_probability_expanded += [answer_probability] * num_return_sequences
                    true_question_expanded += [true_question] * num_return_sequences
                question_score_pairs = list(zip(result_buffer[0], result_buffer[1]))
                for context, answer, answer_probability, true_question, (question_text, question_score) in zip(
                    context_expanded,
                    answer_expanded,
                    answer_probability_expanded,
                    true_question_expanded,
                    question_score_pairs,
                ):
                    out_dict = {
                        "context": context,
                        "synthetic_answer": answer,
                        "synthetic_answer_probability": answer_probability,
                        "synthetic_question": question_text,
                        "synthetic_question_probability": question_score,
                        "true_question": true_question,
                    }
                    if out_json:
                        wf.write(json.dumps(out_dict, ensure_ascii=False) + "\n")
                    output.append(out_dict)
                context_buffer = []
                answer_buffer = []
                # BUG FIX: this buffer was never cleared, so after the first
                # flush stale probabilities from earlier batches were zipped
                # against the new contexts/answers.
                answer_probability_buffer = []
                true_question_buffer = []
            # NOTE(review): this break only exits the inner candidate loop, so
            # later lines keep being processed — preserved from the original.
            if all_sample_num and (i + 1) >= all_sample_num:
                break
            i += 1
    if out_json:
        wf.close()
    return output
184
def filtration(paragraphs, batch_size=16, model=None, schema=None, wf=None, wf_debug=None):
    """Filter synthetic QA pairs with a round-trip extraction check.

    A pair is kept only if the filtration model, prompted with
    "问题:<question>上下文:<context>", re-extracts the synthetic answer among
    its candidates.

    Args:
        paragraphs (list[dict]): dicts with ``context``, ``synthetic_question``,
            ``synthetic_question_probability``, ``synthetic_answer`` and
            ``synthetic_answer_probability`` keys.
        batch_size (int): number of items sent to ``model`` per call.
        model (callable): taskflow-style extractor, same contract as in
            ``answer_generation_from_paragraphs``.
        schema (list[str]): prompts to look up in each prediction dict.
        wf (file or None): valid pairs are written here as JSON lines.
        wf_debug (file or None): rejected pairs are written here as JSON lines.

    Returns:
        list[dict]: the pairs that passed the check.
    """
    result = []
    buffer = []
    valid_num, invalid_num = 0, 0
    total = len(paragraphs)
    for i, paragraph_tobe in enumerate(tqdm(paragraphs)):
        buffer.append(paragraph_tobe)
        # Flush on a full batch or at the final item.
        if len(buffer) == batch_size or (i + 1) == total:
            model_inputs = []
            for d in buffer:
                # Pack question + context into one prompt for the extractor.
                prefix = "问题:" + d["synthetic_question"] + "上下文:"
                model_inputs.append(prefix + d["context"])
            predicts = model(model_inputs)
            paragraph_list = buffer
            buffer = []
            for predict_dict, paragraph in zip(predicts, paragraph_list):
                context = paragraph["context"]
                synthetic_question = paragraph["synthetic_question"]
                synthetic_question_probability = paragraph["synthetic_question_probability"]
                synthetic_answer = paragraph["synthetic_answer"]
                synthetic_answer_probability = paragraph["synthetic_answer_probability"]

                answers = []
                probabilities = []
                for prompt in schema:
                    # Prompts absent from the prediction contribute nothing.
                    for answer_dict in predict_dict.get(prompt, []):
                        answers.append(answer_dict["text"])
                        probabilities.append(answer_dict["probability"])
                candidates = [
                    text for text, _ in sorted(zip(answers, probabilities), key=lambda x: -x[1])
                ]
                out_dict = {
                    "context": context,
                    "synthetic_answer": synthetic_answer,
                    "synthetic_answer_probability": synthetic_answer_probability,
                    "synthetic_question": synthetic_question,
                    "synthetic_question_probability": synthetic_question_probability,
                }
                if synthetic_answer in candidates:
                    if wf:
                        wf.write(json.dumps(out_dict, ensure_ascii=False) + "\n")
                    result.append(out_dict)
                    valid_num += 1
                else:
                    if wf_debug:
                        wf_debug.write(json.dumps(out_dict, ensure_ascii=False) + "\n")
                    invalid_num += 1
    print("valid synthetic question-answer pairs number:", valid_num)
    print("invalid synthetic question-answer pairs number:", invalid_num)
    return result
245
if __name__ == "__main__":
    args = parse_args()

    # Stage 1: extract answer candidates from the source paragraphs.
    assert args.a_prompt
    schema = args.a_prompt.strip().split(",")
    answer_generator = Taskflow(
        "information_extraction",
        schema=schema,
        task_path=args.answer_generation_model_path,
        batch_size=args.batch_size,
        position_prob=args.a_position_prob,
    )
    assert args.source_file_path
    paragraphs = []
    if args.source_file_path.endswith(".json"):
        # JSON-lines input: one {"context": ...} or {"content": ...} per line.
        with open(args.source_file_path, "r", encoding="utf-8") as rf:
            for json_line in rf:
                line_dict = json.loads(json_line)
                assert "context" in line_dict or "content" in line_dict
                if "context" in line_dict:
                    paragraphs.append(line_dict["context"].strip())
                elif "content" in line_dict:
                    paragraphs.append(line_dict["content"].strip())
    else:
        # Plain-text input: one paragraph per line.
        with open(args.source_file_path, "r", encoding="utf-8") as rf:
            for line in rf:
                paragraphs.append(line.strip())

    synthetic_context_answer_pairs = answer_generation_from_paragraphs(
        paragraphs,
        batch_size=args.batch_size,
        model=answer_generator,
        max_answer_candidates=args.a_max_answer_candidates,
        schema=schema,
        wf=None,
    )
    print("create synthetic answers successfully!")

    # Stage 2: generate questions for each (context, answer) candidate.
    question_generation = Taskflow(
        "question_generation",
        task_path=args.question_generation_model_path,
        output_scores=True,
        max_length=args.q_max_question_length,
        is_select_from_num_return_sequences=False,
        num_return_sequences=args.q_num_return_sequences,
        batch_size=args.batch_size,
        decode_strategy=args.q_decode_strategy,
        num_beams=args.q_num_beams,
        num_beam_groups=args.q_num_beam_groups,
        diversity_rate=args.q_diversity_rate,
        top_k=args.q_top_k,
        top_p=args.q_top_p,
        temperature=args.q_temperature,
    )
    synthetic_answer_question_pairs = create_fake_question(
        synthetic_context_answer_pairs,
        # When filtering, write only the filtered output to the target file.
        None if args.do_filtration else args.target_file_path,
        args.q_num_return_sequences,
        None,
        args.batch_size,
    )
    print("create synthetic question-answer pairs successfully!")

    # Stage 3: optional round-trip filtration of the synthetic pairs.
    wf = None
    wf_debug = None
    if args.target_file_path:
        target_dir = os.path.dirname(args.target_file_path)
        # dirname is "" for a bare filename; only create real directories
        # (os.makedirs("") would raise FileNotFoundError).
        if target_dir and not os.path.exists(target_dir):
            os.makedirs(target_dir)
        wf = open(args.target_file_path, "w", encoding="utf-8")
        if args.do_debug:
            wf_debug = open(args.target_file_path + ".debug.json", "w", encoding="utf-8")
    if args.do_filtration:
        filtration_model = Taskflow(
            "information_extraction",
            schema=["答案"],
            task_path=args.filtration_model_path,
            batch_size=args.batch_size,
            position_prob=args.f_filtration_position_prob,
        )
        filtration(
            synthetic_answer_question_pairs,
            batch_size=args.batch_size,  # was hard-coded to 16, ignoring --batch_size
            model=filtration_model,
            schema=["答案"],
            wf=wf,
            wf_debug=wf_debug,
        )
        print("filter synthetic question-answer pairs successfully!")
    # Close output handles only if they were opened; the input handle rf is
    # already closed by the with-blocks above (the original called wf.close()
    # unconditionally, crashing when no --target_file_path was given).
    if wf:
        wf.close()
    if wf_debug:
        wf_debug.close()