# paddlenlp question-generation utility script
1# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import argparse16import json17
18from tqdm import tqdm19
20from paddlenlp import Taskflow21
22
# yapf: disable
def parse_args():
    """Parse command-line arguments for the fake-question generation script.

    Returns:
        argparse.Namespace: parsed arguments controlling the taskflow model,
        decoding strategy and the source/target JSONL file paths.
    """
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--model_path', type=str, default=None, help='the model path to be loaded for question_generation taskflow')
    parser.add_argument('--max_length', type=int, default=50, help='the max decoding length')
    parser.add_argument('--num_return_sequences', type=int, default=3, help='the number of return sequences for each input sample, it should be less than num_beams')
    parser.add_argument('--source_file_path', type=str, default=None, help='the source json file path')
    parser.add_argument('--target_file_path', type=str, default=None, help='the target json file path')
    parser.add_argument('--all_sample_num', type=int, default=None, help='the test sample number when convert_json_to_data')
    parser.add_argument('--batch_size', type=int, default=1, help='the batch size when using taskflow')
    parser.add_argument('--decode_strategy', type=str, default=None, help='the decode strategy')
    parser.add_argument('--num_beams', type=int, default=6, help='the number of beams when using beam search')
    parser.add_argument('--num_beam_groups', type=int, default=1, help='the number of beam groups when using diverse beam search')
    parser.add_argument('--diversity_rate', type=float, default=0.0, help='the diversity_rate when using diverse beam search')
    # top_k is a count of candidate tokens, so it must be an int (was float).
    parser.add_argument('--top_k', type=int, default=0, help='the top_k when using sampling decoding strategy')
    parser.add_argument('--top_p', type=float, default=1.0, help='the top_p when using sampling decoding strategy')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when using sampling decoding strategy')
    args = parser.parse_args()
    return args
# yapf: enable
43
44
45def create_fake_question(json_file, out_json, num_return_sequences, all_sample_num=None, batch_size=8):46with open(json_file, "r", encoding="utf-8") as rf, open(out_json, "w", encoding="utf-8") as wf:47all_lines = rf.readlines()48num_all_lines = len(all_lines)49context_buffer = []50answer_buffer = []51true_question_buffer = []52for i, json_line in enumerate(tqdm(all_lines)):53line_dict = json.loads(json_line)54q = line_dict["question"]55a = line_dict["answer"]56c = line_dict["context"]57
58context_buffer += [c]59answer_buffer += [a]60true_question_buffer += [q]61if (62(i + 1) % batch_size == 063or (all_sample_num and (i + 1) == all_sample_num or (i + 1))64or (i + 1) == num_all_lines65):66result_buffer = question_generation(67[{"context": context, "answer": answer} for context, answer in zip(context_buffer, answer_buffer)]68)69context_buffer_temp, answer_buffer_temp, true_question_buffer_temp = [], [], []70for context, answer, true_question in zip(context_buffer, answer_buffer, true_question_buffer):71context_buffer_temp += [context] * num_return_sequences72answer_buffer_temp += [answer] * num_return_sequences73true_question_buffer_temp += [true_question] * num_return_sequences74result_one_two_buffer = [(one, two) for one, two in zip(result_buffer[0], result_buffer[1])]75for context, answer, true_question, result in zip(76context_buffer_temp, answer_buffer_temp, true_question_buffer_temp, result_one_two_buffer77):78fake_quesitons_tokens = [result[0]]79fake_quesitons_scores = [result[1]]80for fake_quesitons_token, fake_quesitons_score in zip(81fake_quesitons_tokens, fake_quesitons_scores82):83out_dict = {84"context": context,85"answer": answer,86"question": fake_quesitons_token,87"true_question": true_question,88"score": fake_quesitons_score,89}90wf.write(json.dumps(out_dict, ensure_ascii=False) + "\n")91context_buffer = []92answer_buffer = []93true_question_buffer = []94
95if all_sample_num and (i + 1) >= all_sample_num:96break97
if __name__ == "__main__":
    args = parse_args()
    # Collect all decoding/runtime options in one place before building the
    # taskflow.
    taskflow_options = {
        "task_path": args.model_path,
        "output_scores": True,
        "max_length": args.max_length,
        "is_select_from_num_return_sequences": False,
        "num_return_sequences": args.num_return_sequences,
        "batch_size": args.batch_size,
        "decode_strategy": args.decode_strategy,
        "num_beams": args.num_beams,
        "num_beam_groups": args.num_beam_groups,
        "diversity_rate": args.diversity_rate,
        "top_k": args.top_k,
        "top_p": args.top_p,
        "temperature": args.temperature,
    }
    # Assigned at module level on purpose: create_fake_question reads the
    # global name `question_generation`.
    question_generation = Taskflow("question_generation", **taskflow_options)
    create_fake_question(
        args.source_file_path, args.target_file_path, args.num_return_sequences, args.all_sample_num, args.batch_size
    )