# paddlenlp
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

from tqdm import tqdm

from paddlenlp import Taskflow


# yapf: disable
def parse_args():
    """Parse command-line arguments for the fake-question generation script.

    Returns:
        argparse.Namespace: parsed options covering the model path, decoding
        configuration and the source/target JSON-lines file locations.
    """
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--model_path', type=str, default=None, help='the model path to be loaded for question_generation taskflow')
    parser.add_argument('--max_length', type=int, default=50, help='the max decoding length')
    parser.add_argument('--num_return_sequences', type=int, default=3, help='the number of return sequences for each input sample, it should be less than num_beams')
    parser.add_argument('--source_file_path', type=str, default=None, help='the source json file path')
    parser.add_argument('--target_file_path', type=str, default=None, help='the target json file path')
    parser.add_argument('--all_sample_num', type=int, default=None, help='the test sample number when convert_json_to_data')
    parser.add_argument('--batch_size', type=int, default=1, help='the batch size when using taskflow')
    parser.add_argument('--decode_strategy', type=str, default=None, help='the decode strategy')
    parser.add_argument('--num_beams', type=int, default=6, help='the number of beams when using beam search')
    parser.add_argument('--num_beam_groups', type=int, default=1, help='the number of beam groups when using diverse beam search')
    parser.add_argument('--diversity_rate', type=float, default=0.0, help='the diversity_rate when using diverse beam search')
    # top_k counts candidate tokens, so it must be an integer (was type=float).
    parser.add_argument('--top_k', type=int, default=0, help='the top_k when using sampling decoding strategy')
    parser.add_argument('--top_p', type=float, default=1.0, help='the top_p when using sampling decoding strategy')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when using sampling decoding strategy')
    args = parser.parse_args()
    return args
# yapf: enable


def create_fake_question(json_file, out_json, num_return_sequences, all_sample_num=None, batch_size=8, question_generator=None):
    """Generate fake questions for (context, answer) pairs from a JSON-lines file.

    Each input line must be a JSON object with "context", "answer" and
    "question" keys. For every consumed sample, ``num_return_sequences``
    generated questions (with their scores) are written to ``out_json``,
    one JSON object per line.

    Args:
        json_file (str): path of the source JSON-lines file.
        out_json (str): path of the target JSON-lines file (overwritten).
        num_return_sequences (int): generated questions per input sample.
        all_sample_num (int, optional): stop after consuming this many
            input samples. ``None`` means consume the whole file.
        batch_size (int, optional): number of samples per generation call.
        question_generator (callable, optional): maps a list of
            ``{"context": ..., "answer": ...}`` dicts to a pair
            ``(questions, scores)``, each of length
            ``len(batch) * num_return_sequences``. Defaults to the
            module-level ``question_generation`` Taskflow.
    """
    generate = question_generator if question_generator is not None else question_generation
    with open(json_file, "r", encoding="utf-8") as rf, open(out_json, "w", encoding="utf-8") as wf:
        all_lines = rf.readlines()
        num_all_lines = len(all_lines)
        context_buffer = []
        answer_buffer = []
        true_question_buffer = []
        for i, json_line in enumerate(tqdm(all_lines)):
            line_dict = json.loads(json_line)
            context_buffer.append(line_dict["context"])
            answer_buffer.append(line_dict["answer"])
            true_question_buffer.append(line_dict["question"])
            # Flush when a full batch is collected, when the sample cap is
            # reached, or at the end of the file.
            # BUGFIX: the original condition contained a stray "or (i + 1)"
            # which is always truthy, so the buffer flushed on every line and
            # batch_size was effectively ignored.
            if (
                (i + 1) % batch_size == 0
                or (all_sample_num is not None and (i + 1) == all_sample_num)
                or (i + 1) == num_all_lines
            ):
                questions, scores = generate(
                    [{"context": context, "answer": answer} for context, answer in zip(context_buffer, answer_buffer)]
                )
                # Each input sample yields num_return_sequences outputs, so
                # repeat the sample metadata to align with the flat results.
                expanded = []
                for context, answer, true_question in zip(context_buffer, answer_buffer, true_question_buffer):
                    expanded.extend([(context, answer, true_question)] * num_return_sequences)
                for (context, answer, true_question), question, score in zip(expanded, questions, scores):
                    out_dict = {
                        "context": context,
                        "answer": answer,
                        "question": question,
                        "true_question": true_question,
                        "score": score,
                    }
                    wf.write(json.dumps(out_dict, ensure_ascii=False) + "\n")
                context_buffer = []
                answer_buffer = []
                true_question_buffer = []

            if all_sample_num and (i + 1) >= all_sample_num:
                break


if __name__ == "__main__":
    args = parse_args()
    # Decoding options forwarded verbatim from the CLI to the Taskflow.
    decoding_config = {
        "max_length": args.max_length,
        "num_return_sequences": args.num_return_sequences,
        "decode_strategy": args.decode_strategy,
        "num_beams": args.num_beams,
        "num_beam_groups": args.num_beam_groups,
        "diversity_rate": args.diversity_rate,
        "top_k": args.top_k,
        "top_p": args.top_p,
        "temperature": args.temperature,
    }
    # NOTE: `question_generation` must remain a module-level name;
    # create_fake_question looks it up as a global.
    question_generation = Taskflow(
        "question_generation",
        task_path=args.model_path,
        output_scores=True,
        is_select_from_num_return_sequences=False,
        batch_size=args.batch_size,
        **decoding_config,
    )
    create_fake_question(
        args.source_file_path,
        args.target_file_path,
        args.num_return_sequences,
        args.all_sample_num,
        args.batch_size,
    )