paddlenlp

Форк
0
334 строки · 14.9 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import argparse
16
import json
17
import os
18

19
from tqdm import tqdm
20

21
from paddlenlp import Taskflow
22

23

24
# yapf: disable
25
def parse_args():
    """Parse the command-line options of the synthetic QA-pair generation script.

    Options are grouped by prefix: ``a_*`` configure answer extraction,
    ``q_*`` configure question generation and ``f_*`` configure filtration.

    Returns:
        argparse.Namespace: the parsed options.
    """
    # The module docstring is the parser description; passing it positionally
    # would set ``prog`` (the program name shown in usage messages) instead.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--answer_generation_model_path', type=str, default=None, help='the model path to be loaded for answer extraction')
    parser.add_argument('--question_generation_model_path', type=str, default=None, help='the model path to be loaded for question generation')
    parser.add_argument('--filtration_model_path', type=str, default=None, help='the model path to be loaded for filtration')
    parser.add_argument('--source_file_path', type=str, default=None, help='the source file path')
    parser.add_argument('--target_file_path', type=str, default=None, help='the target json file path')
    parser.add_argument('--batch_size', type=int, default=1, help='the batch size when using taskflow')
    parser.add_argument("--do_debug", action='store_true', help="Whether to do debug")
    parser.add_argument('--a_prompt', type=str, default='答案', help='the prompt when using taskflow, separate by ,')
    parser.add_argument('--a_position_prob', type=float, default=0.01, help='confidence threshold for answer extraction')
    parser.add_argument('--a_max_answer_candidates', type=int, default=5, help='the max number of return answer candidate for each input')
    parser.add_argument('--q_num_return_sequences', type=int, default=3, help='the number of return sequences for each input sample, it should be less than num_beams')
    parser.add_argument('--q_max_question_length', type=int, default=50, help='the max decoding length')
    parser.add_argument('--q_decode_strategy', type=str, default='sampling', help='the decode strategy')
    parser.add_argument('--q_num_beams', type=int, default=6, help='the number of beams when using beam search')
    parser.add_argument('--q_num_beam_groups', type=int, default=1, help='the number of beam groups when using diverse beam search')
    parser.add_argument('--q_diversity_rate', type=float, default=0.0, help='the diversity_rate when using diverse beam search')
    # top_k is a count of candidate tokens, so it is parsed as an integer; a
    # float type would turn "--q_top_k 10" into 10.0.
    parser.add_argument('--q_top_k', type=int, default=5, help='the top_k when using sampling decoding strategy')
    parser.add_argument('--q_top_p', type=float, default=1.0, help='the top_p when using sampling decoding strategy')
    parser.add_argument('--q_temperature', type=float, default=1.0, help='the temperature when using sampling decoding strategy')
    parser.add_argument("--do_filtration", action='store_true', help="Whether to do filtration")
    parser.add_argument('--f_filtration_position_prob', type=float, default=0.1, help='confidence threshold for filtration')
    return parser.parse_args()
50
# yapf: enable
51

52

53
def answer_generation_from_paragraphs(
    paragraphs, batch_size=16, model=None, max_answer_candidates=5, schema=None, wf=None
):
    """Extract ranked answer candidates for each paragraph.

    Paragraphs are accumulated and passed to ``model`` once ``batch_size`` of
    them are pending or the last paragraph has been reached. For every
    paragraph the predictions of all prompts in ``schema`` are merged,
    de-duplicated as ``(text, probability)`` pairs, sorted by descending
    probability and cut to ``max_answer_candidates`` entries.

    Args:
        paragraphs: sized sequence of paragraphs handed to ``model``.
        batch_size: number of paragraphs per ``model`` call.
        model: callable taking a list of paragraphs and returning one
            prediction mapping per paragraph (prompt -> list of dicts with
            ``"text"`` and ``"probability"`` keys).
        max_answer_candidates: upper bound on candidates kept per paragraph.
        schema: iterable of prompt names to look up in each prediction.
        wf: optional writable text handle; when truthy, each record is also
            written to it as one JSON line.

    Returns:
        list of ``{"context": ..., "answer_candidates": [...]}`` dicts.
    """
    records = []
    pending = []
    total = len(paragraphs)
    for position, text in enumerate(tqdm(paragraphs), start=1):
        pending.append(text)
        # Keep collecting until the batch is full or the input is exhausted.
        if len(pending) != batch_size and position != total:
            continue

        predictions = model(pending)
        batch, pending = pending, []
        for prediction, context in zip(predictions, batch):
            texts = []
            scores = []
            for prompt in schema:
                if prompt not in prediction:
                    continue
                hits = prediction[prompt]
                texts.extend(hit["text"] for hit in hits)
                scores.extend(hit["probability"] for hit in hits)

            # Drop exact (text, probability) duplicates, best candidates first.
            ranked = sorted(set(zip(texts, scores)), key=lambda pair: -pair[1])
            if len(ranked) > max_answer_candidates:
                ranked = ranked[:max_answer_candidates]

            record = {"context": context, "answer_candidates": ranked}
            if wf:
                wf.write(json.dumps(record, ensure_ascii=False) + "\n")
            records.append(record)
    return records
90

91

92
def create_fake_question(
    json_file_or_pair_list, out_json=None, num_return_sequences=1, all_sample_num=None, batch_size=8, model=None
):
    """Generate synthetic questions for every (context, answer candidate) pair.

    Args:
        json_file_or_pair_list: either a list of dicts or the path of a
            JSON-lines file holding such dicts. Each dict needs a ``"context"``
            and an ``"answer_candidates"`` entry (iterable of
            ``(answer, probability)`` pairs) and may carry a ``"question"``.
        out_json: optional path; when given, every generated record is also
            written there as one JSON line.
        num_return_sequences: how many generated questions the model returns
            per input sample; used to align inputs with model outputs.
        all_sample_num: optional cap on the number of (context, answer)
            samples that are fed to the model.
        batch_size: number of samples per model call; must be positive.
        model: question generation callable. Defaults to the module-level
            ``question_generation`` object created by the script entry point.
            It is called with a list of ``{"context", "answer"}`` dicts and is
            indexed as ``result[0]`` (questions) and ``result[1]`` (scores).

    Returns:
        list of dicts with the keys ``context``, ``synthetic_answer``,
        ``synthetic_answer_probability``, ``synthetic_question``,
        ``synthetic_question_probability`` and ``true_question``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        KeyError: if an input dict lacks ``"context"`` or
            ``"answer_candidates"``.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if model is None:
        # Fall back to the generator the script creates at module level.
        model = question_generation

    if isinstance(json_file_or_pair_list, list):
        all_lines = json_file_or_pair_list
    else:
        with open(json_file_or_pair_list, "r", encoding="utf-8") as rf:
            all_lines = [json.loads(json_line) for json_line in rf]

    # Flatten the input into one sample per answer candidate. Keeping the four
    # fields together in one tuple guarantees they can never drift apart
    # between batches, and batching the flat list guarantees that a trailing
    # partial batch is processed even when the last lines have no candidates.
    samples = []
    for line_dict in all_lines:
        true_question = line_dict.get("question", "")
        context = line_dict["context"]
        answer_candidates = line_dict["answer_candidates"]
        if not answer_candidates:
            continue
        for answer, probability in answer_candidates:
            samples.append((context, answer, probability, true_question))
    if all_sample_num:
        samples = samples[:all_sample_num]

    output = []
    wf = open(out_json, "w", encoding="utf-8") if out_json else None
    try:
        for start in tqdm(range(0, len(samples), batch_size)):
            batch = samples[start : start + batch_size]
            generated = model([{"context": context, "answer": answer} for context, answer, _, _ in batch])
            # The model returns num_return_sequences results per input, in
            # input order, so repeat every sample that many times to align.
            repeated = [sample for sample in batch for _ in range(num_return_sequences)]
            for (context, answer, probability, true_question), question, score in zip(
                repeated, generated[0], generated[1]
            ):
                out_dict = {
                    "context": context,
                    "synthetic_answer": answer,
                    "synthetic_answer_probability": probability,
                    "synthetic_question": question,
                    "synthetic_question_probability": score,
                    "true_question": true_question,
                }
                if wf is not None:
                    wf.write(json.dumps(out_dict, ensure_ascii=False) + "\n")
                output.append(out_dict)
    finally:
        if wf is not None:
            wf.close()
    return output
183

184

185
def filtration(paragraphs, batch_size=16, model=None, schema=None, wf=None, wf_debug=None):
    """Keep the synthetic QA pairs whose answer ``model`` extracts again.

    Each record is turned into a single model input made of its synthetic
    question followed by its context. A record counts as valid when its
    ``synthetic_answer`` is among the texts the model returns for any prompt
    in ``schema``. The numbers of valid and invalid records are printed.

    Args:
        paragraphs: sized sequence of dicts with the keys ``context``,
            ``synthetic_question``, ``synthetic_question_probability``,
            ``synthetic_answer`` and ``synthetic_answer_probability``.
        batch_size: number of records per ``model`` call.
        model: callable taking a list of input strings and returning one
            prediction mapping per input (prompt -> list of dicts with
            ``"text"`` and ``"probability"`` keys).
        schema: iterable of prompt names to look up in each prediction.
        wf: optional writable text handle receiving valid records as JSON lines.
        wf_debug: optional writable text handle receiving rejected records.

    Returns:
        list of the valid records (without the ``true_question`` field).
    """
    kept = []
    pending = []
    valid_num = 0
    invalid_num = 0
    total = len(paragraphs)
    for position, sample in enumerate(tqdm(paragraphs), start=1):
        pending.append(sample)
        # Keep collecting until the batch is full or the input is exhausted.
        if len(pending) != batch_size and position != total:
            continue

        model_inputs = []
        for item in pending:
            passage, question = item["context"], item["synthetic_question"]
            model_inputs.append("问题:" + question + "上下文:" + passage)

        predictions = model(model_inputs)
        batch, pending = pending, []
        for prediction, item in zip(predictions, batch):
            context = item["context"]
            synthetic_question = item["synthetic_question"]
            synthetic_question_probability = item["synthetic_question_probability"]
            synthetic_answer = item["synthetic_answer"]
            synthetic_answer_probability = item["synthetic_answer_probability"]

            texts = []
            scores = []
            for prompt in schema:
                if prompt not in prediction:
                    continue
                hits = prediction[prompt]
                texts.extend(hit["text"] for hit in hits)
                scores.extend(hit["probability"] for hit in hits)
            ranked = sorted(zip(texts, scores), key=lambda pair: -pair[1])
            recovered_answers = [text for text, _ in ranked]

            record = {
                "context": context,
                "synthetic_answer": synthetic_answer,
                "synthetic_answer_probability": synthetic_answer_probability,
                "synthetic_question": synthetic_question,
                "synthetic_question_probability": synthetic_question_probability,
            }
            if synthetic_answer in recovered_answers:
                if wf:
                    wf.write(json.dumps(record, ensure_ascii=False) + "\n")
                kept.append(record)
                valid_num += 1
                continue

            # Rejected pair: only recorded when a debug handle was supplied.
            if wf_debug:
                wf_debug.write(json.dumps(record, ensure_ascii=False) + "\n")
            invalid_num += 1
    print("valid synthetic question-answer pairs number:", valid_num)
    print("invalid synthetic question-answer pairs number:", invalid_num)
    return kept
244

245

246
if __name__ == "__main__":
    # Pipeline: extract answer candidates from the source paragraphs, generate
    # questions for them, then optionally filter the resulting QA pairs.
    args = parse_args()
    if not args.a_prompt:
        raise ValueError("--a_prompt must contain at least one prompt")
    if not args.source_file_path:
        raise ValueError("--source_file_path is required")

    # Create the output directory before anything writes to the target file.
    # A bare file name has an empty dirname, which os.makedirs would reject.
    if args.target_file_path:
        target_dir = os.path.dirname(args.target_file_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

    schema = args.a_prompt.strip().split(",")
    answer_generator = Taskflow(
        "information_extraction",
        schema=schema,
        task_path=args.answer_generation_model_path,
        batch_size=args.batch_size,
        position_prob=args.a_position_prob,
    )

    # A ".json" source is read as JSON lines carrying the paragraph under
    # "context" (preferred) or "content"; any other file is one paragraph per line.
    paragraphs = []
    with open(args.source_file_path, "r", encoding="utf-8") as rf:
        if args.source_file_path.endswith(".json"):
            for json_line in rf:
                line_dict = json.loads(json_line)
                if "context" in line_dict:
                    paragraphs.append(line_dict["context"].strip())
                elif "content" in line_dict:
                    paragraphs.append(line_dict["content"].strip())
                else:
                    raise KeyError("each source json line needs a 'context' or 'content' field")
        else:
            for line in rf:
                paragraphs.append(line.strip())

    synthetic_context_answer_pairs = answer_generation_from_paragraphs(
        paragraphs,
        batch_size=args.batch_size,
        model=answer_generator,
        max_answer_candidates=args.a_max_answer_candidates,
        schema=schema,
        wf=None,
    )
    print("create synthetic answers successfully!")

    # NOTE: create_fake_question looks this object up by its module-level name,
    # so the variable must stay a global called "question_generation".
    question_generation = Taskflow(
        "question_generation",
        task_path=args.question_generation_model_path,
        output_scores=True,
        max_length=args.q_max_question_length,
        is_select_from_num_return_sequences=False,
        num_return_sequences=args.q_num_return_sequences,
        batch_size=args.batch_size,
        decode_strategy=args.q_decode_strategy,
        num_beams=args.q_num_beams,
        num_beam_groups=args.q_num_beam_groups,
        diversity_rate=args.q_diversity_rate,
        top_k=args.q_top_k,
        top_p=args.q_top_p,
        temperature=args.q_temperature,
    )
    # Without filtration the generated pairs are the final result, so they are
    # written straight to the target file here.
    synthetic_answer_question_pairs = create_fake_question(
        synthetic_context_answer_pairs,
        None if args.do_filtration else args.target_file_path,
        args.q_num_return_sequences,
        None,
        args.batch_size,
    )
    print("create synthetic question-answer pairs successfully!")

    if args.do_filtration:
        filtration_model = Taskflow(
            "information_extraction",
            schema=["答案"],
            task_path=args.filtration_model_path,
            batch_size=args.batch_size,
            position_prob=args.f_filtration_position_prob,
        )
        # The target file is opened only in this branch: opening it in "w" mode
        # when filtration is disabled would truncate the pairs written above.
        wf = None
        wf_debug = None
        try:
            if args.target_file_path:
                wf = open(args.target_file_path, "w", encoding="utf-8")
                if args.do_debug:
                    wf_debug = open(args.target_file_path + ".debug.json", "w", encoding="utf-8")
            filtration(
                synthetic_answer_question_pairs,
                batch_size=16,
                model=filtration_model,
                schema=["答案"],
                wf=wf,
                wf_debug=wf_debug,
            )
        finally:
            for handle in (wf, wf_debug):
                if handle is not None:
                    handle.close()
        print("filter synthetic question-answer pairs successfully!")
335

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.