paddlenlp

Форк
0
161 строка · 6.6 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import argparse
16
import json
17
import os
18

19

20
# yapf: disable
21
def parse_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace with ``source_file_path`` (required), ``target_dir``
        (default ``'data'``), the three boolean ``do_*_prompt`` switches
        (default False) and ``domain`` (default None).
    """
    # `description=` must be passed by keyword: the first positional parameter
    # of ArgumentParser is `prog`, so ArgumentParser(__doc__) would turn a
    # module docstring into the program name.
    parser = argparse.ArgumentParser(description=__doc__)
    # Required: without it the script would crash later in
    # os.path.basename(None) with an unhelpful TypeError.
    parser.add_argument('--source_file_path', type=str, required=True, help='the source json file path')
    parser.add_argument('--target_dir', type=str, default='data', help='the target file path')
    parser.add_argument('--do_answer_prompt', action="store_true", help="is use answer prompt")
    parser.add_argument('--do_len_prompt', action="store_true", help="is use length prompt")
    parser.add_argument('--do_domain_prompt', action="store_true", help="is use domain prompt")
    parser.add_argument('--domain', type=str, default=None, help='the domain of the dataset when using domain prompt')
    args = parser.parse_args()
    return args
31
# yapf: enable
32

33

34
def convert_from_json_to_answer_extraction_format(
    json_file,
    output_path,
    domain=None,
    do_answer_prompt=True,
    do_len_prompt=False,
    do_domain_prompt=False,
    max_answer_len=300,
):
    """Convert a JSON-lines QA file into answer-extraction records.

    Every non-blank line of ``json_file`` must be a JSON object with at least
    ``"context"`` and ``"answer"`` keys. The answer is truncated to
    ``max_answer_len`` characters and located in the context with
    ``str.find``, so the first occurrence wins. For each input record one
    output line is written per enabled prompt, in this order:

    * ``do_answer_prompt``: prompt ``"答案"`` ("answer");
    * ``do_len_prompt``: a length-bucket prompt -- ``"短答案"`` (< 10 chars),
      ``"中短答案"`` (< 20), ``"中长答案"`` (< 30), otherwise ``"长答案"``;
    * ``do_domain_prompt``: the ``domain`` string, only if ``domain`` is truthy.

    Each output line is a JSON object with keys ``content`` (the context),
    ``result_list`` (a single ``{"text", "start", "end"}`` span whose offsets
    index into ``content``) and ``prompt``.

    Args:
        json_file: Path of the input JSON-lines file (read as UTF-8).
        output_path: Path of the output JSON-lines file (overwritten, UTF-8).
        domain: Prompt text used for the domain record.
        do_answer_prompt: Emit the generic answer-prompt record.
        do_len_prompt: Emit the answer-length-prompt record.
        do_domain_prompt: Emit the domain-prompt record.
        max_answer_len: Answers longer than this are cut to this many
            characters; ``None`` disables truncation.

    Raises:
        ValueError: If the (possibly truncated) answer does not occur in the
            context.
    """
    with open(json_file, "r", encoding="utf-8") as rf, open(output_path, "w", encoding="utf-8") as wf:
        for line in rf:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # crashing inside json.loads.
            if not line.strip():
                continue
            json_line = json.loads(line)
            context = json_line["context"]
            # Cut the abnormally long sample; slicing is a no-op when shorter.
            answer = json_line["answer"][:max_answer_len]

            begin_id = context.find(answer)
            # Explicit raise instead of `assert`: under `python -O` an assert
            # is stripped and a bogus start offset of -1 would be written.
            if begin_id == -1:
                raise ValueError("'" + answer + "' is not found in " + context)
            result = {"text": answer, "start": begin_id, "end": begin_id + len(answer)}

            prompts = []
            if do_answer_prompt:
                prompts.append("答案")  # "answer"
            if do_len_prompt:
                answer_len = len(answer)
                if answer_len < 10:
                    prompts.append("短答案")  # short answer
                elif answer_len < 20:
                    prompts.append("中短答案")  # medium-short answer
                elif answer_len < 30:
                    prompts.append("中长答案")  # medium-long answer
                else:
                    prompts.append("长答案")  # long answer
            if do_domain_prompt and domain:
                prompts.append(domain)

            # One record per enabled prompt; all share the same content and span.
            for prompt in prompts:
                outdict = {"content": context, "result_list": [result], "prompt": prompt}
                wf.write(json.dumps(outdict, ensure_ascii=False) + "\n")
81

82

83
def convert_from_json_to_question_generation_format(json_file, output_path, tokenizer=None, max_answer_len=300):
    """Convert a JSON-lines QA file into question-generation records.

    Every non-blank line of ``json_file`` must be a JSON object with
    ``"context"``, ``"answer"`` and ``"question"`` keys. One output line
    ``{"question", "answer", "context"}`` is written per input record, with
    the answer truncated to ``max_answer_len`` characters.

    Args:
        json_file: Path of the input JSON-lines file (read as UTF-8).
        output_path: Path of the output JSON-lines file (overwritten, UTF-8).
        tokenizer: Accepted but not used by this function.
        max_answer_len: Answers longer than this are cut to this many
            characters; ``None`` disables truncation.
    """
    with open(json_file, "r", encoding="utf-8") as rf, open(output_path, "w", encoding="utf-8") as wf:
        for line in rf:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # crashing inside json.loads.
            if not line.strip():
                continue
            json_line = json.loads(line)
            context = json_line["context"]
            # Cut the abnormally long sample; slicing is a no-op when shorter.
            answer = json_line["answer"][:max_answer_len]
            question = json_line["question"]

            outdict = {
                "question": question,
                "answer": answer,
                "context": context,
            }
            wf.write(json.dumps(outdict, ensure_ascii=False) + "\n")
101

102

103
def convert_from_json_to_filtration_format(json_file, output_path, tokenizer=None, max_answer_len=300):
    """Convert a JSON-lines QA file into filtration records.

    Every non-blank line of ``json_file`` must be a JSON object with
    ``"context"``, ``"answer"`` and ``"question"`` keys. Each output line has
    ``content = "问题:" + question + "上下文:" + context`` ("question: ...
    context: ..."), the prompt ``"答案"`` ("answer"), and a single answer span
    whose ``start``/``end`` offsets index into ``content`` (i.e. they are
    shifted past the question prefix).

    Args:
        json_file: Path of the input JSON-lines file (read as UTF-8).
        output_path: Path of the output JSON-lines file (overwritten, UTF-8).
        tokenizer: Accepted but not used by this function.
        max_answer_len: Answers longer than this are cut to this many
            characters; ``None`` disables truncation.

    Raises:
        ValueError: If the (possibly truncated) answer does not occur in the
            context.
    """
    with open(json_file, "r", encoding="utf-8") as rf, open(output_path, "w", encoding="utf-8") as wf:
        for line in rf:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # crashing inside json.loads.
            if not line.strip():
                continue
            json_line = json.loads(line)
            context = json_line["context"]
            # Cut the abnormally long sample; slicing is a no-op when shorter.
            answer = json_line["answer"][:max_answer_len]
            question = json_line["question"]

            prefix = "问题:" + question + "上下文:"

            begin_id = context.find(answer)
            # Explicit raise instead of `assert`: under `python -O` an assert
            # is stripped and a bogus offset would be written silently.
            if begin_id == -1:
                raise ValueError("'" + answer + "' is not found in " + context)
            # Offsets refer to `content`, which starts with the question prefix.
            start = len(prefix) + begin_id
            result = {"text": answer, "start": start, "end": start + len(answer)}

            outdict = {
                "content": prefix + context,
                "result_list": [result],
                "prompt": "答案",
            }
            wf.write(json.dumps(outdict, ensure_ascii=False) + "\n")
131

132

133
if __name__ == "__main__":
    args = parse_args()
    # Each task writes to <target_dir>/<task>/<basename of the source file>.
    source_name = os.path.basename(args.source_file_path)

    # NOTE: the --do_*_prompt flags are store_true (default False) and are
    # passed through explicitly, so without any of them the answer-extraction
    # output file is created but left empty.
    answer_extraction_dir = os.path.join(args.target_dir, "answer_extraction")
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(answer_extraction_dir, exist_ok=True)
    convert_from_json_to_answer_extraction_format(
        json_file=args.source_file_path,
        output_path=os.path.join(answer_extraction_dir, source_name),
        domain=args.domain,
        do_answer_prompt=args.do_answer_prompt,
        do_len_prompt=args.do_len_prompt,
        do_domain_prompt=args.do_domain_prompt,
    )

    question_generation_dir = os.path.join(args.target_dir, "question_generation")
    os.makedirs(question_generation_dir, exist_ok=True)
    convert_from_json_to_question_generation_format(
        json_file=args.source_file_path, output_path=os.path.join(question_generation_dir, source_name)
    )

    filtration_dir = os.path.join(args.target_dir, "filtration")
    os.makedirs(filtration_dir, exist_ok=True)
    convert_from_json_to_filtration_format(
        json_file=args.source_file_path, output_path=os.path.join(filtration_dir, source_name)
    )
162

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.