# paddlenlp
# 161 lines · 6.6 KB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

# yapf: disable
def parse_args():
    """Parse the command-line options of this conversion script.

    Returns:
        argparse.Namespace with ``source_file_path``, ``target_dir``,
        ``do_answer_prompt``, ``do_len_prompt``, ``do_domain_prompt`` and ``domain``.
    """
    # ``__doc__`` must be passed as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, so a module docstring would become the program name.
    parser = argparse.ArgumentParser(description=__doc__)
    # Required: every output file name is derived from this path's basename, so running
    # without it could only fail later with a TypeError from os.path.basename(None).
    parser.add_argument('--source_file_path', type=str, required=True, help='the source json file path')
    parser.add_argument('--target_dir', type=str, default='data', help='the target file path')
    parser.add_argument('--do_answer_prompt', action="store_true", help="is use answer prompt")
    parser.add_argument('--do_len_prompt', action="store_true", help="is use length prompt")
    parser.add_argument('--do_domain_prompt', action="store_true", help="is use domain prompt")
    parser.add_argument('--domain', type=str, default=None, help='the domain of the dataset when using domain prompt')
    args = parser.parse_args()
    return args
# yapf: enable


def convert_from_json_to_answer_extraction_format(
    json_file, output_path, domain=None, do_answer_prompt=True, do_len_prompt=False, do_domain_prompt=False
):
    """Convert a JSON-lines QA file into answer-extraction samples.

    Every non-blank line of ``json_file`` must be a JSON object with ``context`` and
    ``answer`` fields. The answer (truncated to 300 characters) is located in the
    context with ``str.find``. For each enabled prompt type, one JSON line of the form
    ``{"content": context, "result_list": [{"text", "start", "end"}], "prompt": ...}``
    is written to ``output_path``, in this order:

    * ``do_answer_prompt``: prompt ``"答案"`` ("answer").
    * ``do_len_prompt``: a length bucket of the answer -- ``"短答案"`` (short, < 10
      chars), ``"中短答案"`` (medium-short, < 20), ``"中长答案"`` (medium-long, < 30),
      otherwise ``"长答案"`` (long).
    * ``do_domain_prompt``: the ``domain`` string itself, only if ``domain`` is non-empty.

    ``start``/``end`` are character offsets into the context, with ``end`` exclusive.

    Args:
        json_file: path of the UTF-8 JSON-lines input file.
        output_path: path of the output file; it is overwritten.
        domain: prompt text used when ``do_domain_prompt`` is set.
        do_answer_prompt: emit the generic answer prompt sample.
        do_len_prompt: emit the length-bucket prompt sample.
        do_domain_prompt: emit the domain prompt sample.

    Raises:
        ValueError: if a (truncated) answer does not occur in its context.
    """
    with open(json_file, "r", encoding="utf-8") as rf, open(output_path, "w", encoding="utf-8") as wf:
        for line in rf:
            # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
            if not line.strip():
                continue
            json_line = json.loads(line)
            context = json_line["context"]
            # Cut the abnormally long sample
            answer = json_line["answer"][:300]

            begin_id = context.find(answer)
            # An explicit raise (not ``assert``) so the check survives ``python -O``;
            # otherwise a missing answer would silently produce start=-1 spans.
            if begin_id == -1:
                raise ValueError("'" + answer + "' is not found in " + context)
            result = {"text": answer, "start": begin_id, "end": begin_id + len(answer)}

            prompts = []
            if do_answer_prompt:
                prompts.append("答案")
            if do_len_prompt:
                if len(answer) < 10:
                    prompts.append("短答案")
                elif len(answer) < 20:
                    prompts.append("中短答案")
                elif len(answer) < 30:
                    prompts.append("中长答案")
                else:
                    prompts.append("长答案")
            if do_domain_prompt and domain:
                prompts.append(domain)

            for prompt in prompts:
                outdict = {
                    "content": context,
                    "result_list": [result],
                    "prompt": prompt,
                }
                wf.write(json.dumps(outdict, ensure_ascii=False) + "\n")


def convert_from_json_to_question_generation_format(json_file, output_path, tokenizer=None):
    """Convert a JSON-lines QA file into question-generation samples.

    Every non-blank line of ``json_file`` must be a JSON object with ``context``,
    ``answer`` and ``question`` fields. For each one, a JSON line
    ``{"question": ..., "answer": ..., "context": ...}`` is written to ``output_path``,
    with the answer truncated to 300 characters. The answer is not checked against
    the context here.

    Args:
        json_file: path of the UTF-8 JSON-lines input file.
        output_path: path of the output file; it is overwritten.
        tokenizer: accepted but unused by this function.
    """
    with open(json_file, "r", encoding="utf-8") as rf, open(output_path, "w", encoding="utf-8") as wf:
        for line in rf:
            # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
            if not line.strip():
                continue
            json_line = json.loads(line)
            context = json_line["context"]
            # Cut the abnormally long sample
            answer = json_line["answer"][:300]
            question = json_line["question"]

            outdict = {
                "question": question,
                "answer": answer,
                "context": context,
            }
            wf.write(json.dumps(outdict, ensure_ascii=False) + "\n")


def convert_from_json_to_filtration_format(json_file, output_path, tokenizer=None):
    """Convert a JSON-lines QA file into filtration (answer-span) samples.

    Every non-blank line of ``json_file`` must be a JSON object with ``context``,
    ``answer`` and ``question`` fields. The model input is built as
    ``"问题:" + question + "上下文:" + context`` ("question: ... context: ..."), and one
    JSON line ``{"content": ..., "result_list": [{"text", "start", "end"}], "prompt": "答案"}``
    is written per input line. ``start``/``end`` are character offsets into ``content``
    (end exclusive). The answer is truncated to 300 characters.

    Args:
        json_file: path of the UTF-8 JSON-lines input file.
        output_path: path of the output file; it is overwritten.
        tokenizer: accepted but unused by this function.

    Raises:
        ValueError: if a (truncated) answer does not occur in its context.
    """
    with open(json_file, "r", encoding="utf-8") as rf, open(output_path, "w", encoding="utf-8") as wf:
        for line in rf:
            # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
            if not line.strip():
                continue
            json_line = json.loads(line)
            context = json_line["context"]
            # Cut the abnormally long sample
            answer = json_line["answer"][:300]
            question = json_line["question"]

            prefix = "问题:" + question + "上下文:"
            content = prefix + context

            # Search the bare context, not ``content``, so an answer that also appears in
            # the question is not matched there; then shift the offsets past the prefix.
            begin_id = context.find(answer)
            # An explicit raise (not ``assert``) so the check survives ``python -O``.
            if begin_id == -1:
                raise ValueError("'" + answer + "' is not found in " + context)
            begin_id += len(prefix)
            end_id = begin_id + len(answer)

            result = {"text": answer, "start": begin_id, "end": end_id}
            outdict = {
                "content": content,
                "result_list": [result],
                "prompt": "答案",
            }
            wf.write(json.dumps(outdict, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    args = parse_args()
    # Every output path is derived from the source file name, so fail early with a
    # clear message rather than a TypeError from os.path.basename(None).
    if not args.source_file_path:
        raise SystemExit("error: --source_file_path is required")
    source_file_name = os.path.basename(args.source_file_path)

    def _task_output_path(task_name):
        """Return <target_dir>/<task_name>/<source file name>, creating the directory if needed."""
        task_dir = os.path.join(args.target_dir, task_name)
        # exist_ok=True replaces the race-prone exists()-then-makedirs() check.
        os.makedirs(task_dir, exist_ok=True)
        return os.path.join(task_dir, source_file_name)

    # NOTE(review): the --do_* options are store_true flags (default False), so
    # running without any of them creates an empty answer_extraction file.
    convert_from_json_to_answer_extraction_format(
        json_file=args.source_file_path,
        output_path=_task_output_path("answer_extraction"),
        domain=args.domain,
        do_answer_prompt=args.do_answer_prompt,
        do_len_prompt=args.do_len_prompt,
        do_domain_prompt=args.do_domain_prompt,
    )

    convert_from_json_to_question_generation_format(
        json_file=args.source_file_path,
        output_path=_task_output_path("question_generation"),
    )

    convert_from_json_to_filtration_format(
        json_file=args.source_file_path,
        output_path=_task_output_path("filtration"),
    )