paddlenlp
118 строк · 4.9 Кб
1# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import argparse16import json17import os18
19
20# yapf: disable
def parse_args(argv=None):
    """Parse the command-line options of the qq-pair creation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call has always done.

    Returns:
        argparse.Namespace carrying ``do_create_test_qq_pair``,
        ``qq_pair_source_ori_file_path``, ``qq_pair_source_trans_file_path``,
        ``qq_pair_target_file_path``, ``trans_query_answer_path`` and
        ``dev_sample_num``.
    """
    # The first positional parameter of ArgumentParser is ``prog``; the module
    # docstring is meant to be shown as the help description instead.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--do_create_test_qq_pair", action='store_true', help="Whether to do create_test_qq_pair")
    parser.add_argument('--qq_pair_source_ori_file_path', type=str, default=None, help='the original source file path for qq-pair creating')
    parser.add_argument('--qq_pair_source_trans_file_path', type=str, default=None, help='the translated source file path for qq-pair creating')
    parser.add_argument('--qq_pair_target_file_path', type=str, default=None, help='the target file path for qq-pair creating')
    parser.add_argument('--trans_query_answer_path', type=str, default=None, help='the target query-answer file path for extract_trans_from_fake_question')
    parser.add_argument('--dev_sample_num', type=int, default=None, help='the test sample number when convert_json_to_data, if None, treat all lines as dev samples')
    return parser.parse_args(argv)
# yapf: enable
32
33
def _to_single_line(text):
    """Replace newlines and tabs in *text* with spaces and strip both ends."""
    return text.replace("\n", " ").replace("\t", " ").strip()


def extract_q_from_json_file(json_file, out_file=None, test_sample_num=None, query_answer_path=None):
    """Collect the questions stored in a JSON-lines file.

    Every line of ``json_file`` is parsed as a JSON object with a
    ``"question"`` entry (a string, or a list whose first element is used).

    Args:
        json_file: Path of the JSON-lines file to read.
        out_file: Optional path; when given, each question is written to it,
            one per line, with newlines/tabs flattened to spaces.
        test_sample_num: Optional cap on the number of leading lines to
            process. ``None`` or ``0`` means every line is processed.
        query_answer_path: Optional path; when given, a
            ``question<TAB>answer`` line is written to it for each processed
            row, which then must also contain an ``"answer"`` string.

    Returns:
        List of the processed questions, stripped of surrounding whitespace.

    Raises:
        KeyError: If a processed row has no ``"question"`` entry, or no
            ``"answer"`` entry while ``query_answer_path`` is set.
    """
    q_list = []
    # The input is opened first so that a missing input file does not leave
    # freshly truncated output files behind.
    with open(json_file, "r", encoding="utf-8") as rf:
        question_wf = None
        query_answer_wf = None
        try:
            if out_file:
                question_wf = open(out_file, "w", encoding="utf-8")
            if query_answer_path:
                query_answer_wf = open(query_answer_path, "w", encoding="utf-8")

            for i, json_line in enumerate(rf):
                # Stop before parsing anything past the requested sample count.
                if test_sample_num and i >= test_sample_num:
                    break
                line_dict = json.loads(json_line)
                question = line_dict["question"]
                if isinstance(question, list):
                    question = question[0]

                if query_answer_wf is not None:
                    # The answer is only needed for the query-answer file.
                    answer = line_dict["answer"]
                    query_answer_wf.write(_to_single_line(question) + "\t" + _to_single_line(answer) + "\n")
                if question_wf is not None:
                    question_wf.write(_to_single_line(question) + "\n")
                q_list.append(question.strip())
        finally:
            # Close whichever outputs were opened, even if a row failed.
            for handle in (query_answer_wf, question_wf):
                if handle is not None:
                    handle.close()
    return q_list
67
def _read_plain_questions(path, test_sample_num=None):
    """Return the stripped lines of a text file, capped at *test_sample_num* lines."""
    with open(path, "r", encoding="utf-8") as rf:
        return [line.strip() for i, line in enumerate(rf) if not test_sample_num or i < test_sample_num]


def create_test_qq_pair(
    ori_path=None, trans_path=None, write_path=None, trans_query_answer_path=None, test_sample_num=None
):
    """Write ``translated<TAB>original`` question pairs to ``write_path``.

    Questions are read from ``trans_path`` and ``ori_path``. A path ending in
    ``.json`` is read through ``extract_q_from_json_file``; any other path is
    read as plain text with one question per line.

    Args:
        ori_path: Path of the original questions. When empty or one of
            ``"NONE"``, ``"None"``, ``"none"``, every pair gets ``"-"`` as
            its original question.
        trans_path: Path of the translated questions (required).
        write_path: Path of the tab-separated output file.
        trans_query_answer_path: Optional path forwarded to
            ``extract_q_from_json_file`` when ``trans_path`` is a ``.json``
            file; it is not used for plain-text input.
        test_sample_num: Optional cap on the number of leading lines read
            from each input. ``None`` or ``0`` means all lines.

    Raises:
        ValueError: If ``trans_path`` is empty.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not trans_path:
        raise ValueError("trans_path is required to create qq pairs")

    # Both inputs are fully read before the target is opened, so a failure
    # while reading does not leave a truncated output file behind.
    if trans_path.endswith(".json"):
        trans_q_list = extract_q_from_json_file(trans_path, None, test_sample_num, trans_query_answer_path)
    else:
        trans_q_list = _read_plain_questions(trans_path, test_sample_num)

    if not ori_path or ori_path in ("NONE", "None", "none"):
        origin_q_list = ["-"] * len(trans_q_list)
    elif ori_path.endswith(".json"):
        origin_q_list = extract_q_from_json_file(ori_path, None, test_sample_num)
    else:
        origin_q_list = _read_plain_questions(ori_path, test_sample_num)

    with open(write_path, "w", encoding="utf-8") as wf:
        # zip() stops at the shorter list, so unmatched trailing questions
        # are dropped when the two inputs differ in length.
        for origin, trans in zip(origin_q_list, trans_q_list):
            wf.write(
                trans.replace("\n", " ").replace("\t", " ").strip()
                + "\t"
                + origin.replace("\n", " ").replace("\t", " ").strip()
                + "\n"
            )
108
if __name__ == "__main__":
    # Script entry point: build the qq-pair file only when the
    # --do_create_test_qq_pair flag was passed on the command line.
    cli_args = parse_args()
    if cli_args.do_create_test_qq_pair:
        pair_options = {
            "ori_path": cli_args.qq_pair_source_ori_file_path,
            "trans_path": cli_args.qq_pair_source_trans_file_path,
            "write_path": cli_args.qq_pair_target_file_path,
            "trans_query_answer_path": cli_args.trans_query_answer_path,
            "test_sample_num": cli_args.dev_sample_num,
        }
        create_test_qq_pair(**pair_options)