paddlenlp
233 строки · 9.1 Кб
1# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import argparse16import json17import multiprocessing18import os19import time20
21from tqdm import tqdm22from tqdm.contrib import tzip23
24from paddlenlp.metrics import BLEU25from paddlenlp.transformers import BasicTokenizer26
27
# yapf: disable
def parse_args():
    """Build and parse the command-line arguments for this evaluation script.

    Returns:
        argparse.Namespace: the parsed command-line options.
    """
    parser = argparse.ArgumentParser(__doc__)
    add = parser.add_argument
    add('--true_file_path', type=str, default=None, help='the source json file path')
    add('--generate_file_path', type=str, default=None, help='the target json file path')
    add('--num_return_sequences', type=int, default=3, help='the number of return sequences for each input sample, it should be less than num_beams')
    add('--all_sample_num', type=int, default=None, help='the number of valid sample')
    add('--bleu_n_size', type=int, default=4, help='the bleu n size')
    add('--bleu_threshold', type=float, default=0.3, help='the bleu threshold')
    add("--do_log_file", action="store_true", help="is log analysis file")
    add('--log_dir', type=str, default=None, help='the log dir')
    add("--do_multiprocessing", action="store_true", help="is do multiprocessing")
    add("--do_map_async", action="store_true", help="is use map_async or apply_async when do multiprocessing")
    return parser.parse_args()
# yapf: enable
44
45
def calc_bleu_n(preds, targets, n_size=4):
    """Compute a corpus-level BLEU-n score for aligned prediction/reference pairs.

    Args:
        preds (list[str]): predicted sentences.
        targets (list[str]): reference sentences, one per prediction.
        n_size (int): maximum n-gram order for BLEU.

    Returns:
        float: the accumulated BLEU score over all pairs.
    """
    assert len(preds) == len(targets), (
        "The length of pred_responses should be equal to the length of "
        "target_responses. But received {} and {}.".format(len(preds), len(targets))
    )
    scorer = BLEU(n_size=n_size)
    basic_tokenizer = BasicTokenizer()
    # Each reference is wrapped in a list because BLEU accepts multiple
    # references per instance; here there is exactly one.
    for hypothesis, reference in zip(preds, targets):
        scorer.add_inst(basic_tokenizer.tokenize(hypothesis), [basic_tokenizer.tokenize(reference)])
    return scorer.score()
61
62def worker_apply_async(true_question, generate_question_group, bleu_n_size, bleu_threshold, i):63first_positive_pair = None64for generate_question in generate_question_group:65bleu_score = calc_bleu_n([generate_question], [true_question], bleu_n_size)66if bleu_score > bleu_threshold:67first_positive_pair = (generate_question, true_question, i)68if first_positive_pair:69return (True, first_positive_pair)70else:71return (False, (generate_question_group[0], true_question))72
73
74def worker_map_async(args):75true_question, generate_question_group, bleu_n_size, bleu_threshold, i = args76first_positive_pair = None77for generate_question in generate_question_group:78bleu_score = calc_bleu_n([generate_question], [true_question], bleu_n_size)79if bleu_score > bleu_threshold:80first_positive_pair = (generate_question, true_question, i)81if first_positive_pair:82return (True, first_positive_pair)83else:84return (False, (generate_question_group[0], true_question))85
86
def coverage_rate(
    true_file_path,
    generate_file_path,
    bleu_n_size,
    bleu_threshold,
    num_return_sequences,
    all_sample_num=None,
    is_log_file=False,
    log_dir=None,
    is_multiprocessing=True,
    is_map_async=True,
):
    """Compute the fraction of true questions "covered" by generated questions.

    A true question counts as covered when at least one of its
    ``num_return_sequences`` generated candidates scores above
    ``bleu_threshold`` under BLEU-``bleu_n_size``.

    Args:
        true_file_path (str): json-lines file; each line is a dict whose
            "question" entry is a str (or a list of str, first item used).
        generate_file_path (str): json-lines file of generated questions,
            ``num_return_sequences`` consecutive lines per true question.
        bleu_n_size (int): maximum n-gram order for BLEU.
        bleu_threshold (float): minimum BLEU score to count a match.
        num_return_sequences (int): number of candidates per sample.
        all_sample_num (int, optional): number of samples to evaluate;
            ``None`` evaluates every sample in ``true_file_path``.
        is_log_file (bool): if True (and ``log_dir`` set), dump the
            positive/negative pairs to text files for inspection.
        log_dir (str, optional): directory for the analysis files.
        is_multiprocessing (bool): score with a 30-process pool instead of serially.
        is_map_async (bool): with multiprocessing, use ``map_async`` rather
            than per-sample ``apply_async``.

    Returns:
        float: covered samples / total samples.
    """
    # The advertised default ``all_sample_num=None`` used to crash on the
    # ``i >= all_sample_num`` comparison; treat None as "no limit".
    sample_limit = float("inf") if all_sample_num is None else all_sample_num

    true_questions = []
    with open(true_file_path, "r", encoding="utf-8") as rf:
        for i, json_line in enumerate(tqdm(rf.readlines())):
            if i >= sample_limit:
                break
            line_dict = json.loads(json_line)
            true_questions.append(
                line_dict["question"][0] if isinstance(line_dict["question"], list) else line_dict["question"]
            )

    generate_question_groups = []
    with open(generate_file_path, "r", encoding="utf-8") as rf:
        group = []
        for i, json_line in enumerate(tqdm(rf.readlines())):
            if i >= sample_limit * num_return_sequences:
                break
            line_dict = json.loads(json_line)
            group.append(
                line_dict["question"][0] if isinstance(line_dict["question"], list) else line_dict["question"]
            )
            # Every ``num_return_sequences`` consecutive lines form one group.
            if (i + 1) % num_return_sequences == 0:
                generate_question_groups.append(group)
                group = []
    print("true_questions", len(true_questions))
    print("generate_question_groups", len(generate_question_groups))
    positive = []
    negative = []
    if is_multiprocessing:
        pool = multiprocessing.Pool(processes=30)
        pool_results = []
        if is_map_async:
            map_async_inputs = []
    i = 0
    bleu_cal_time_start = time.time()
    # Replace blank candidates with a placeholder so tokenization/BLEU never
    # sees an empty string.
    generate_question_groups = [
        [
            generate_question if generate_question.strip() != "" else "none"
            for generate_question in generate_question_group
        ]
        for generate_question_group in generate_question_groups
    ]
    for true_question, generate_question_group in tzip(true_questions, generate_question_groups):
        if is_multiprocessing:
            if is_map_async:
                map_async_inputs.append((true_question, generate_question_group, bleu_n_size, bleu_threshold, i))
            else:
                pool_results.append(
                    pool.apply_async(
                        worker_apply_async,
                        args=(true_question, generate_question_group, bleu_n_size, bleu_threshold, i),
                    )
                )
        else:
            first_positive_pair = None
            # Seed with the first candidate so the append below never
            # dereferences ``None`` when every BLEU score is 0.
            best_pair, best_score = (generate_question_group[0], true_question), 0
            for generate_question in generate_question_group:
                try:
                    bleu_score = calc_bleu_n([generate_question], [true_question], bleu_n_size)
                except Exception:
                    # Was ``except BaseException`` falling through with a stale
                    # (or undefined) ``bleu_score``; keep the best-effort log
                    # but skip the bad pair.
                    print("generate_question", generate_question)
                    print("true_question", true_question)
                    continue
                if bleu_score > best_score:
                    # Track the score as well: it used to stay 0 forever, so
                    # the logged "best" score was always 0 and ``best_pair``
                    # was merely the last candidate with a non-zero score.
                    best_pair, best_score = (generate_question, true_question), bleu_score
                if bleu_score > bleu_threshold:
                    first_positive_pair = (generate_question, true_question)
            if first_positive_pair:
                positive.append((best_pair[0], best_pair[1], best_score))
            else:
                negative.append((best_pair[0], best_pair[1], best_score))
        i += 1
    if is_multiprocessing:
        if is_map_async:
            pool_results = pool.map_async(worker_map_async, map_async_inputs)
            pool.close()
            pool.join()
            for result in pool_results.get():
                is_positive, pair = result
                if is_positive:
                    positive.append(pair)
                else:
                    negative.append(pair)
        else:
            pool.close()
            pool.join()
            for result in pool_results:
                is_positive, pair = result.get()
                if is_positive:
                    positive.append(pair)
                else:
                    negative.append(pair)

    bleu_cal_time_end = time.time()
    print("bleu_cal_time_spend:", bleu_cal_time_end - bleu_cal_time_start)
    if is_log_file and log_dir:
        with open(os.path.join(log_dir, "positive_pair.txt"), "w", encoding="utf-8") as wf:
            for pair in positive:
                wf.write(
                    pair[0] + "\t" + pair[1] + "\n"
                    if len(pair) == 2
                    else pair[0] + "\t" + pair[1] + str(pair[2]) + "\n"
                )
        with open(os.path.join(log_dir, "negative_pair.txt"), "w", encoding="utf-8") as wf:
            for pair in negative:
                wf.write(
                    pair[0] + "\t" + pair[1] + "\n"
                    if len(pair) == 2
                    else pair[0] + "\t" + pair[1] + str(pair[2]) + "\n"
                )
    # With ``all_sample_num=None`` there is no fixed target, so check against
    # the number of samples actually read instead.
    expected_total = len(true_questions) if all_sample_num is None else all_sample_num
    assert len(positive) + len(negative) == expected_total, (
        "the number of positive pairs "
        + str(len(positive))
        + " plus the number of negative pairs "
        + str(len(negative))
        + " should be equal to all_sample_num"
        + str(expected_total)
    )
    return len(positive) / (len(positive) + len(negative))
218
219if __name__ == "__main__":220args = parse_args()221rate = coverage_rate(222true_file_path=args.true_file_path,223generate_file_path=args.generate_file_path,224bleu_n_size=args.bleu_n_size,225bleu_threshold=args.bleu_threshold,226num_return_sequences=args.num_return_sequences,227all_sample_num=args.all_sample_num,228is_log_file=args.do_log_file,229log_dir=args.log_dir,230is_multiprocessing=args.do_multiprocessing,231is_map_async=args.do_map_async,232)233print("coverage rate is", rate)234