google-research
217 строк · 6.9 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
r"""Compares two sets of measurements of n-gram LM quality metrics.

Example:
--------

1. For individual report files:

REPORT_DIR=...
python analyze_ngram_metrics.py \
--baseline_metrics_tsv_file ${REPORT_DIR}/baseline_report.tsv \
--test_metrics_tsv_file ${REPORT_DIR}/test_report.tsv

2. For directories containing multiple reports for multiple languages (the
   language to compare is selected with --language):

REPORT_DIR=...
python analyze_ngram_metrics.py \
--baseline_metrics_dir ${REPORT_DIR}/baselines/ \
--test_metrics_dir ${REPORT_DIR}/rewrites/ \
--language ckb \
--output_tex_table_file /tmp/comparison_table.tex

Dependencies:
-------------
absl
numpy
pandas
scipy
statsmodels
"""
45
46from typing import Sequence47
48import logging49import os50import pathlib51
52from absl import app53from absl import flags54
55import numpy as np56import pandas as pd57import scipy58
59import utils60import stat_utils61
# Single-file mode: one baseline report compared against one test report.
flags.DEFINE_string(
    "baseline_metrics_tsv_file", "",
    "An input text file in tab-separated (TSV) format containing the base "
    "metrics.")

flags.DEFINE_string(
    "test_metrics_tsv_file", "",
    "An input text file in tab-separated (TSV) format containing the test "
    "metrics.")

# Directory mode: per-order reports for one language, paired by n-gram order.
flags.DEFINE_string(
    "baseline_metrics_dir", "",
    "Directory containing results files in tab-separated (TSV) format for the "
    "baselines.")

flags.DEFINE_string(
    "test_metrics_dir", "",
    "Directory containing results files in tab-separated (TSV) format for the "
    "tested configurations.")

flags.DEFINE_string(
    "language", "",
    "Language code to filter the files by when processing directories with "
    "multiple files.")

flags.DEFINE_string(
    "output_tex_table_file", "",
    "Output file containing all the metrics for a single language as a "
    "`&`-separated table.")

flags.DEFINE_integer(
    "float_precision", 3,
    "Floating point precision.")

FLAGS = flags.FLAGS
98
def _process_dir(directory):
  """Finds the metrics files for the selected language under `directory`.

  Every `*.tsv` file found recursively whose base name starts with the value
  of `--language` is kept, together with the n-gram order that
  `utils.ngram_order_from_filename` derives from that base name.

  Args:
    directory: Root directory that is searched recursively.

  Returns:
    A list of `(order, file_path)` tuples, sorted by n-gram order.
  """
  candidates = [
      (utils.ngram_order_from_filename(tsv_path.name), str(tsv_path))
      for tsv_path in pathlib.Path(directory).rglob("*.tsv")
      if tsv_path.name.startswith(FLAGS.language)
  ]
  # Stable sort on the order only, so ties keep their discovery order.
  return sorted(candidates, key=lambda candidate: candidate[0])
112
def _process_dirs():
  """Compares per-order metrics files from the baseline and test directories.

  The files are discovered with `_process_dir` (filtered by `--language` and
  sorted by n-gram order) and paired up positionally; each pair must share the
  same n-gram order.

  Returns:
    A `pd.DataFrame` with one row per n-gram order. It holds the `mean`,
    `mean_percent`, `confidence_interval`, `t_statistic` and `p_value` fields
    of the t-test result (`ws_*` columns) and the `statistic`/`pvalue` of the
    Mann-Whitney U (`mw_*`) and Brunner-Munzel (`bm_*`) results.

  Raises:
    ValueError: If the two directories contain a different number of metrics
      files, if no metrics files are found at all, or if paired files have
      different n-gram orders.
  """
  base_files = _process_dir(FLAGS.baseline_metrics_dir)
  test_files = _process_dir(FLAGS.test_metrics_dir)
  if len(base_files) != len(test_files):
    raise ValueError("Mismatching number of metrics files!")
  if not base_files:
    # Without this check an empty table is written silently, e.g. when
    # --language matches no file name.
    raise ValueError(
        f'No metrics files found for language "{FLAGS.language}"!')

  rows = []
  for (base_order, base_file), (test_order, test_file) in zip(
      base_files, test_files):
    if base_order != test_order:
      raise ValueError("Mismatching n-gram orders!")
    ws_stats, mw_stats, bm_stats = _process_one_pair(base_file, test_file)
    rows.append({
        "order": base_order,
        "mean_deltas": ws_stats.mean,
        "mean_delta_%": ws_stats.mean_percent,
        "ws_ci_low": ws_stats.confidence_interval[0],
        "ws_ci_high": ws_stats.confidence_interval[1],
        "ws_t_stat": ws_stats.t_statistic,
        "ws_p_val": ws_stats.p_value,
        "mw_t_stat": mw_stats.statistic,
        "mw_p_val": mw_stats.pvalue,
        "bm_t_stat": bm_stats.statistic,
        # Name kept as-is (not "bm_p_val") so the output table is unchanged.
        "bm_p_value": bm_stats.pvalue,
    })
  # Columns follow the key order of the row dicts above.
  return pd.DataFrame(rows)
163
def _process_one_pair(baseline_file, test_file):
  """Compares two metrics files with one parametric and two rank-based tests.

  Args:
    baseline_file: Path to the baseline metrics TSV file.
    test_file: Path to the test metrics TSV file.

  Returns:
    A tuple `(ws_stats, mw_stats, bm_stats)`. It contains the result of
    `stat_utils.ParameterStats.MeanDifference` (Welch-Satterthwaite t-test),
    then the `scipy.stats` Mann-Whitney U result, then the Brunner-Munzel
    result. Each result is also printed to stdout.
  """
  base_ent = utils.read_entropies(baseline_file)
  test_ent = utils.read_entropies(test_file)

  # Parametric analysis: Welch-Satterthwaite t-test, assuming that baseline
  # and test entropies are normally distributed.
  ws_stats = stat_utils.ParameterStats.MeanDifference(base_ent, test_ent)
  print(f"t-test: {ws_stats}")

  # Non-parametric analysis: Mann-Whitney U and Brunner-Munzel tests. Unlike
  # the Wilcoxon-Mann-Whitney U test, Brunner-Munzel does not require the two
  # groups to have equal variances.
  #
  # NOTE(review): --alternative_hypothesis is not defined in this file;
  # presumably one of the imported modules registers it — confirm.
  alternative = FLAGS.alternative_hypothesis
  # scipy states a one-sided alternative for its first sample relative to its
  # second. For any one-sided alternative, the test sample goes first, so that
  # e.g. "less" means "test entropies are stochastically smaller than the
  # baseline ones". Swapping for "less" only would make "less" and "greater"
  # test the very same hypothesis. For one-sided alternatives the reported
  # statistics therefore refer to the test sample.
  if alternative in ("less", "greater"):
    first, second = test_ent, base_ent
  else:
    first, second = base_ent, test_ent

  mw_stats = scipy.stats.mannwhitneyu(first, second, alternative=alternative)
  print(f"Mann-Whitney U: {mw_stats}")
  bm_stats = scipy.stats.brunnermunzel(first, second, alternative=alternative)
  print(f"Brunner-Munzel: {bm_stats}")
  return ws_stats, mw_stats, bm_stats
191
def main(argv):
  """Compares either a single pair of metrics files or two directories.

  In single-file mode the test results are only printed. In directory mode a
  per-order summary table is additionally written to --output_tex_table_file.

  Args:
    argv: Command-line arguments left after flag parsing; only the program
      name is expected.

  Raises:
    app.UsageError: On extra positional arguments, missing inputs, or a mix of
      file and directory flags (e.g. a baseline file with a test directory).
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  if not FLAGS.baseline_metrics_tsv_file and not FLAGS.baseline_metrics_dir:
    raise app.UsageError("Specify --baseline_metrics_tsv_file [FILE] or "
                         "--baseline_metrics_dir [DIR]!!")
  if not FLAGS.test_metrics_tsv_file and not FLAGS.test_metrics_dir:
    raise app.UsageError("Specify --test_metrics_tsv_file [FILE] or "
                         "--test_metrics_dir [DIR]!")

  if FLAGS.baseline_metrics_tsv_file and FLAGS.test_metrics_tsv_file:
    # Read the metrics from single files.
    _process_one_pair(FLAGS.baseline_metrics_tsv_file,
                      FLAGS.test_metrics_tsv_file)
  elif FLAGS.baseline_metrics_dir and FLAGS.test_metrics_dir:
    if not FLAGS.language:
      raise app.UsageError("Specify --language [CODE]!")
    if not FLAGS.output_tex_table_file:
      raise app.UsageError("Specify --output_tex_table_file [FILE]!")
    df = _process_dirs().round(FLAGS.float_precision)
    logging.info("Saving the table to %s ...", FLAGS.output_tex_table_file)
    df.to_csv(FLAGS.output_tex_table_file, sep="&", index=False)
  else:
    # Both sides were given, but one as a file and the other as a directory;
    # previously this fell through and the program silently did nothing.
    raise app.UsageError(
        "Specify either --baseline_metrics_tsv_file and "
        "--test_metrics_tsv_file, or --baseline_metrics_dir and "
        "--test_metrics_dir!")
215
if "__main__" == __name__:
  # absl parses the command-line flags first, then invokes main() with the
  # remaining positional arguments.
  app.run(main)