paddlenlp

Форк
0
118 строк · 4.9 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import argparse
16
import json
17
import os
18

19

20
# yapf: disable
21
def parse_args():
22
    parser = argparse.ArgumentParser(__doc__)
23
    parser.add_argument("--do_create_test_qq_pair", action='store_true', help="Whether to do create_test_qq_pair")
24
    parser.add_argument('--qq_pair_source_ori_file_path', type=str, default=None, help='the original source file path for qq-pair creating')
25
    parser.add_argument('--qq_pair_source_trans_file_path', type=str, default=None, help='the translated source file path for qq-pair creating')
26
    parser.add_argument('--qq_pair_target_file_path', type=str, default=None, help='the target file path for qq-pair creating')
27
    parser.add_argument('--trans_query_answer_path', type=str, default=None, help='the target query-answer file path for extract_trans_from_fake_question')
28
    parser.add_argument('--dev_sample_num', type=int, default=None, help='the test sample number when convert_json_to_data, if None, treat all lines as dev samples')
29
    args = parser.parse_args()
30
    return args
31
# yapf: enable
32

33

34
def extract_q_from_json_file(json_file, out_file=None, test_sample_num=None, query_answer_path=None):
35
    with open(json_file, "r", encoding="utf-8") as rf:
36
        if out_file:
37
            wf = open(os.path.join(out_file), "w", encoding="utf-8")
38
        if query_answer_path:
39
            qeury_answer_wf = open(query_answer_path, "w", encoding="utf-8")
40
        q_list = []
41
        for i, json_line in enumerate(rf.readlines()):
42
            line_dict = json.loads(json_line)
43
            if isinstance(line_dict["question"], list):
44
                question = line_dict["question"][0]
45
            else:
46
                question = line_dict["question"]
47
            answer = line_dict["answer"]
48
            if not test_sample_num or i < test_sample_num:
49
                if query_answer_path:
50
                    qeury_answer_wf.write(
51
                        question.replace("\n", " ").replace("\t", " ").strip()
52
                        + "\t"
53
                        + answer.replace("\n", " ").replace("\t", " ").strip()
54
                        + "\n"
55
                    )
56
                if out_file:
57
                    wf.write(question.replace("\n", " ").replace("\t", " ").strip() + "\n")
58
                q_list.append(question.strip())
59
            else:
60
                break
61
        if query_answer_path:
62
            qeury_answer_wf.close()
63
        if out_file:
64
            wf.colse()
65
        return q_list
66

67

68
def create_test_qq_pair(
69
    ori_path=None, trans_path=None, write_path=None, trans_query_answer_path=None, test_sample_num=None
70
):
71
    assert trans_path
72
    trans_rf = open(trans_path, "r", encoding="utf-8")
73
    wf = open(write_path, "w", encoding="utf-8")
74
    if trans_path.endswith(".json"):
75
        trans_q_list = extract_q_from_json_file(trans_path, None, test_sample_num, trans_query_answer_path)
76
    else:
77
        trans_q_list = [
78
            line.strip() for i, line in enumerate(trans_rf.readlines()) if not test_sample_num or i < test_sample_num
79
        ]
80

81
    if not ori_path or ori_path in ["NONE", "None", "none"]:
82
        origin_q_list = ["-" for _ in range(len(trans_q_list))]
83
    else:
84
        origin_rf = open(ori_path, "r", encoding="utf-8")
85
        if ori_path.endswith(".json"):
86
            origin_q_list = extract_q_from_json_file(ori_path, None, test_sample_num)
87
        else:
88
            origin_q_list = [
89
                line.strip()
90
                for i, line in enumerate(origin_rf.readlines())
91
                if not test_sample_num or i < test_sample_num
92
            ]
93

94
    for origin, trans in zip(origin_q_list, trans_q_list):
95
        wf.write(
96
            trans.replace("\n", " ").replace("\t", " ").strip()
97
            + "\t"
98
            + origin.replace("\n", " ").replace("\t", " ").strip()
99
            + "\n"
100
        )
101
    if not ori_path or ori_path in ["NONE", "None", "none"]:
102
        pass
103
    else:
104
        origin_rf.close()
105
    trans_rf.close()
106
    wf.close()
107

108

109
if __name__ == "__main__":
110
    args = parse_args()
111
    if args.do_create_test_qq_pair:
112
        create_test_qq_pair(
113
            ori_path=args.qq_pair_source_ori_file_path,
114
            trans_path=args.qq_pair_source_trans_file_path,
115
            write_path=args.qq_pair_target_file_path,
116
            trans_query_answer_path=args.trans_query_answer_path,
117
            test_sample_num=args.dev_sample_num,
118
        )
119

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.