# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import multiprocessing
import os
import time

from tqdm import tqdm
from tqdm.contrib import tzip

from paddlenlp.metrics import BLEU
from paddlenlp.transformers import BasicTokenizer


# yapf: disable
def parse_args():
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--true_file_path', type=str, default=None, help='the source JSON file path')
    parser.add_argument('--generate_file_path', type=str, default=None, help='the target JSON file path')
    parser.add_argument('--num_return_sequences', type=int, default=3, help='the number of returned sequences for each input sample; it should be no larger than num_beams')
    parser.add_argument('--all_sample_num', type=int, default=None, help='the number of valid samples')
    parser.add_argument('--bleu_n_size', type=int, default=4, help='the BLEU n-gram size')
    parser.add_argument('--bleu_threshold', type=float, default=0.3, help='the BLEU threshold')
    parser.add_argument("--do_log_file", action="store_true", help="whether to write analysis log files")
    parser.add_argument('--log_dir', type=str, default=None, help='the log directory')
    parser.add_argument("--do_multiprocessing", action="store_true", help="whether to use multiprocessing")
    parser.add_argument("--do_map_async", action="store_true", help="whether to use map_async instead of apply_async for multiprocessing")
    args = parser.parse_args()
    return args
# yapf: enable

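# Example invocation (the script name and file paths below are placeholders;
# substitute your own):
#
#     python calc_coverage_rate.py \
#         --true_file_path data/dev.json \
#         --generate_file_path data/generated.json \
#         --num_return_sequences 3 \
#         --all_sample_num 1000 \
#         --do_multiprocessing --do_map_async
#
# Both input files are expected to hold one JSON object per line with a
# "question" field (a string, or a list whose first element is used).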


def calc_bleu_n(preds, targets, n_size=4):
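    """Compute a corpus-level BLEU-n score over paired prediction/target strings.

    Both sides are tokenized with BasicTokenizer, and each prediction is scored
    against a single reference.
    """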
    assert len(preds) == len(targets), (
        "The length of preds should be equal to the length of "
        "targets. But received {} and {}.".format(len(preds), len(targets))
    )
    bleu = BLEU(n_size=n_size)
    tokenizer = BasicTokenizer()

    for pred, target in zip(preds, targets):
        pred_tokens = tokenizer.tokenize(pred)
        target_tokens = tokenizer.tokenize(target)

        # BLEU expects a list of references per prediction; here there is one.
        bleu.add_inst(pred_tokens, [target_tokens])
    return bleu.score()


def worker_apply_async(true_question, generate_question_group, bleu_n_size, bleu_threshold, i):
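    """Score one sample; meant to be submitted via pool.apply_async.

    Returns (True, (generated, true, index)) for the first generated question
    whose BLEU against the true question clears the threshold, otherwise
    (False, (first generated question, true question)).
    """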
    first_positive_pair = None
    for generate_question in generate_question_group:
        bleu_score = calc_bleu_n([generate_question], [true_question], bleu_n_size)
        if bleu_score > bleu_threshold:
            first_positive_pair = (generate_question, true_question, i)
            break  # keep the first pair that clears the threshold
    if first_positive_pair:
        return (True, first_positive_pair)
    else:
        return (False, (generate_question_group[0], true_question))


def worker_map_async(args):
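    """Same scoring as worker_apply_async, but takes a single packed tuple of
    arguments so it can be used with pool.map_async.
    """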
    true_question, generate_question_group, bleu_n_size, bleu_threshold, i = args
    first_positive_pair = None
    for generate_question in generate_question_group:
        bleu_score = calc_bleu_n([generate_question], [true_question], bleu_n_size)
        if bleu_score > bleu_threshold:
            first_positive_pair = (generate_question, true_question, i)
            break  # keep the first pair that clears the threshold
    if first_positive_pair:
        return (True, first_positive_pair)
    else:
        return (False, (generate_question_group[0], true_question))


def coverage_rate(
    true_file_path,
    generate_file_path,
    bleu_n_size,
    bleu_threshold,
    num_return_sequences,
    all_sample_num=None,
    is_log_file=False,
    log_dir=None,
    is_multiprocessing=True,
    is_map_async=True,
):
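    """Return the fraction of true questions covered by the generated questions.

    Each true question is compared against its group of num_return_sequences
    consecutive generated questions; a sample counts as positive when at least
    one generated question scores above bleu_threshold in BLEU-n.
    """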
    true_questions = []
    with open(true_file_path, "r", encoding="utf-8") as rf:
        for i, json_line in enumerate(tqdm(rf.readlines())):
            # all_sample_num may be None, in which case the whole file is read
            if all_sample_num is not None and i >= all_sample_num:
                break
            line_dict = json.loads(json_line)
            true_questions.append(
                line_dict["question"][0] if isinstance(line_dict["question"], list) else line_dict["question"]
            )

    generate_question_groups = []
    with open(generate_file_path, "r", encoding="utf-8") as rf:
        group = []
        for i, json_line in enumerate(tqdm(rf.readlines())):
            if all_sample_num is not None and i >= all_sample_num * num_return_sequences:
                break
            line_dict = json.loads(json_line)
            group.append(
                line_dict["question"][0] if isinstance(line_dict["question"], list) else line_dict["question"]
            )
            if (i + 1) % num_return_sequences == 0:
                generate_question_groups.append(group)
                group = []
    print("true_questions", len(true_questions))
    print("generate_question_groups", len(generate_question_groups))
    positive = []
    negative = []
    if is_multiprocessing:
        pool = multiprocessing.Pool(processes=30)
        pool_results = []
        if is_map_async:
            map_async_inputs = []
    i = 0
    bleu_cal_time_start = time.time()
    # replace empty generations with a placeholder so tokenization and BLEU
    # never see an empty string
    generate_question_groups = [
        [
            generate_question if generate_question.strip() != "" else "none"
            for generate_question in generate_question_group
        ]
        for generate_question_group in generate_question_groups
    ]
    for true_question, generate_question_group in tzip(true_questions, generate_question_groups):
        if is_multiprocessing:
            if is_map_async:
                map_async_inputs.append((true_question, generate_question_group, bleu_n_size, bleu_threshold, i))
            else:
                pool_results.append(
                    pool.apply_async(
                        worker_apply_async,
                        args=(true_question, generate_question_group, bleu_n_size, bleu_threshold, i),
                    )
                )
        else:
            first_positive_pair = None
            # fall back to the first candidate so best_pair is never None
            best_pair, best_score = (generate_question_group[0], true_question), 0
            for generate_question in generate_question_group:
                try:
                    bleu_score = calc_bleu_n([generate_question], [true_question], bleu_n_size)
                except BaseException:
                    print("generate_question", generate_question)
                    print("true_question", true_question)
                    continue  # bleu_score is undefined for this pair; skip it
                if bleu_score > best_score:
                    best_pair = (generate_question, true_question)
                    best_score = bleu_score  # keep the score in sync with best_pair
                if bleu_score > bleu_threshold:
                    first_positive_pair = (generate_question, true_question)
            if first_positive_pair:
                positive.append((best_pair[0], best_pair[1], best_score))
            else:
                negative.append((best_pair[0], best_pair[1], best_score))
        i += 1
    if is_multiprocessing:
        if is_map_async:
            pool_results = pool.map_async(worker_map_async, map_async_inputs)
            pool.close()
            pool.join()
            for result in pool_results.get():
                is_positive, pair = result
                if is_positive:
                    positive.append(pair)
                else:
                    negative.append(pair)
        else:
            pool.close()
            pool.join()
            for result in pool_results:
                is_positive, pair = result.get()
                if is_positive:
                    positive.append(pair)
                else:
                    negative.append(pair)

    bleu_cal_time_end = time.time()
    print("bleu_cal_time_spend:", bleu_cal_time_end - bleu_cal_time_start)
    if is_log_file and log_dir:
        with open(os.path.join(log_dir, "positive_pair.txt"), "w", encoding="utf-8") as wf:
            for pair in positive:
                wf.write(
                    pair[0] + "\t" + pair[1] + "\n"
                    if len(pair) == 2
                    else pair[0] + "\t" + pair[1] + "\t" + str(pair[2]) + "\n"
                )
        with open(os.path.join(log_dir, "negative_pair.txt"), "w", encoding="utf-8") as wf:
            for pair in negative:
                wf.write(
                    pair[0] + "\t" + pair[1] + "\n"
                    if len(pair) == 2
                    else pair[0] + "\t" + pair[1] + "\t" + str(pair[2]) + "\n"
                )
    if all_sample_num is not None:
        assert len(positive) + len(negative) == all_sample_num, (
            "the number of positive pairs {} plus the number of negative pairs {} "
            "should be equal to all_sample_num {}".format(len(positive), len(negative), all_sample_num)
        )
    return len(positive) / (len(positive) + len(negative))


if __name__ == "__main__":
    args = parse_args()
    rate = coverage_rate(
        true_file_path=args.true_file_path,
        generate_file_path=args.generate_file_path,
        bleu_n_size=args.bleu_n_size,
        bleu_threshold=args.bleu_threshold,
        num_return_sequences=args.num_return_sequences,
        all_sample_num=args.all_sample_num,
        is_log_file=args.do_log_file,
        log_dir=args.log_dir,
        is_multiprocessing=args.do_multiprocessing,
        is_map_async=args.do_map_async,
    )
    print("coverage rate is", rate)