google-research

Форк
0
/
analyze_ngram_metrics.py 
217 строк · 6.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
r"""Compares two sets of measurements that involve n-gram LM quality metrics.

For every pair of baseline and test reports, the entropies read from the
reports are compared using a parametric Welch-Satterthwaite t-test and the
non-parametric Mann-Whitney U and Brunner-Munzel tests.

Example:
--------

1. For individual report files (the test results are printed to stdout):

  REPORT_DIR=...
  python analyze_ngram_metrics.py \
    --baseline_metrics_tsv_file ${REPORT_DIR}/baseline_report.tsv \
    --test_metrics_tsv_file ${REPORT_DIR}/test_report.tsv

2. For directories containing multiple reports for multiple languages. The
   reports for the selected language are paired by n-gram order, and the
   results are written as an `&`-separated table with one row per order:

  REPORT_DIR=...
  python analyze_ngram_metrics.py \
    --baseline_metrics_dir ${REPORT_DIR}/baselines/ \
    --test_metrics_dir ${REPORT_DIR}/rewrites/ \
    --language ckb \
    --output_tex_table_file /tmp/comparison_table.tex

Dependencies:
-------------
  absl
  numpy
  pandas
  scipy
  statsmodels
"""
45

46
from typing import Sequence
47

48
import logging
49
import os
50
import pathlib
51

52
from absl import app
53
from absl import flags
54

55
import numpy as np
56
import pandas as pd
57
import scipy
58

59
import utils
60
import stat_utils
61

62
flags.DEFINE_string(
    "baseline_metrics_tsv_file", "",
    "An input text file in tab-separated (TSV) format containing the base "
    "metrics.")

flags.DEFINE_string(
    "test_metrics_tsv_file", "",
    "An input text file in tab-separated (TSV) format containing the test "
    "metrics.")

flags.DEFINE_string(
    "baseline_metrics_dir", "",
    "Directory containing results files in tab-separated (TSV) format for the "
    "baselines.")

flags.DEFINE_string(
    "test_metrics_dir", "",
    "Directory containing results files in tab-separated (TSV) format for the "
    "tested configurations.")

flags.DEFINE_string(
    "language", "",
    "Language code to filter the files by when processing directories with "
    "multiple files.")

flags.DEFINE_string(
    "output_tex_table_file", "",
    "Output file containing all the metrics for a single language as a "
    "`&`-separated table.")

flags.DEFINE_integer(
    "float_precision", 3,
    "Floating point precision.")

# NOTE(review): `_process_one_pair` also reads `FLAGS.alternative_hypothesis`,
# which is not defined in this file. Presumably it is registered by one of the
# imported modules (e.g. `stat_utils`). Confirm this; otherwise accessing the
# flag fails at runtime.
FLAGS = flags.FLAGS
97

98

99
def _process_dir(directory):
  """Collects the metrics files for `--language` found under `directory`.

  The directory tree is searched recursively for `*.tsv` files whose base
  name starts with the value of `--language`.

  Args:
    directory: Root directory to search.

  Returns:
    A list of `(ngram_order, file_path)` tuples sorted by n-gram order.
  """
  matches = [
      (utils.ngram_order_from_filename(tsv_path.name), str(tsv_path))
      for tsv_path in pathlib.Path(directory).rglob("*.tsv")
      if tsv_path.name.startswith(FLAGS.language)
  ]
  return sorted(matches, key=lambda entry: entry[0])
111

112

113
def _process_dirs():
  """Compares all per-order metrics files from the baseline and test dirs.

  Files are collected with `_process_dir` for both directories. They are then
  paired positionally; both lists are sorted by n-gram order.

  Returns:
    A `pd.DataFrame` with one row per n-gram order holding the mean delta,
    the t-test confidence interval, statistic and p-value, plus the
    Mann-Whitney U and Brunner-Munzel statistics and p-values.

  Raises:
    ValueError: If the directories yield a different number of files, if no
      files match `--language`, or if paired files have different orders.
  """
  base_files = _process_dir(FLAGS.baseline_metrics_dir)
  test_files = _process_dir(FLAGS.test_metrics_dir)
  if len(base_files) != len(test_files):
    raise ValueError("Mismatching number of metrics files!")
  if not base_files:
    # Without this check an empty, header-only table would be written
    # silently (e.g. when `--language` matches no file names).
    raise ValueError(
        f"No metrics files found for language `{FLAGS.language}`!")

  rows = []
  for (base_order, base_file), (test_order, test_file) in zip(
      base_files, test_files):
    if base_order != test_order:
      raise ValueError("Mismatching n-gram orders!")
    ws_stats, mw_stats, bm_stats = _process_one_pair(base_file, test_file)
    # Key insertion order defines the column order of the output table.
    rows.append({
        "order": base_order,
        "mean_deltas": ws_stats.mean,
        "mean_delta_%": ws_stats.mean_percent,
        "ws_ci_low": ws_stats.confidence_interval[0],
        "ws_ci_high": ws_stats.confidence_interval[1],
        "ws_t_stat": ws_stats.t_statistic,
        "ws_p_val": ws_stats.p_value,
        "mw_t_stat": mw_stats.statistic,
        "mw_p_val": mw_stats.pvalue,
        "bm_t_stat": bm_stats.statistic,
        "bm_p_value": bm_stats.pvalue,
    })
  return pd.DataFrame(rows)
162

163

164
def _process_one_pair(baseline_file, test_file):
  """Compares baseline and test metrics using three statistical tests.

  Args:
    baseline_file: Path to the baseline metrics file in TSV format.
    test_file: Path to the test metrics file in TSV format.

  Returns:
    A tuple `(ws_stats, mw_stats, bm_stats)`. The first element is the result
    of `stat_utils.ParameterStats.MeanDifference` (Welch-Satterthwaite
    t-test); the other two are the results of `scipy.stats.mannwhitneyu` and
    `scipy.stats.brunnermunzel`. Each result is also printed.
  """
  base_ent = utils.read_entropies(baseline_file)
  test_ent = utils.read_entropies(test_file)

  # Parametric analysis: Welch-Satterthwaite t-test, assuming that baseline
  # and test entropies are normally distributed.
  ws_stats = stat_utils.ParameterStats.MeanDifference(base_ent, test_ent)
  print(f"t-test: {ws_stats}")

  # Non-parametric analysis. SciPy's `alternative` describes the first sample
  # relative to the second. For one-sided alternatives the test sample
  # therefore goes first, so the hypothesis reads "test is <less|greater>
  # than baseline". Both one-sided cases must be reordered; reordering only
  # for "less" makes "greater" (baseline > test) the same hypothesis as
  # "less" (test < baseline). The two-sided p-value does not depend on the
  # order, so (baseline, test) is kept in that case.
  # NOTE(review): the direction follows the existing reorder for "less";
  # confirm it matches the sign convention of `MeanDifference`.
  alternative = FLAGS.alternative_hypothesis
  if alternative in ("less", "greater"):
    first, second = test_ent, base_ent
  else:
    first, second = base_ent, test_ent

  mw_stats = scipy.stats.mannwhitneyu(first, second, alternative=alternative)
  print(f"Mann-Whitney U: {mw_stats}")

  # Unlike the Wilcoxon-Mann-Whitney U test, Brunner-Munzel does not require
  # the assumption of equal variances of the two groups.
  bm_stats = scipy.stats.brunnermunzel(first, second, alternative=alternative)
  print(f"Brunner-Munzel: {bm_stats}")
  return ws_stats, mw_stats, bm_stats
190

191

192
def main(argv):
  """Runs the comparison selected by the command-line flags.

  Two modes are supported:
    * `--baseline_metrics_tsv_file` + `--test_metrics_tsv_file`: compares a
      single pair of reports; the test results are only printed.
    * `--baseline_metrics_dir` + `--test_metrics_dir`: compares all reports
      for `--language` and writes them, rounded to `--float_precision`, as an
      `&`-separated table to `--output_tex_table_file`.
  If both modes are fully specified, the single-file mode takes precedence.

  Args:
    argv: Positional command-line arguments left over after flag parsing.

  Raises:
    app.UsageError: If positional arguments are given, a required flag is
      missing, or a single file is combined with a directory.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  if not FLAGS.baseline_metrics_tsv_file and not FLAGS.baseline_metrics_dir:
    raise app.UsageError("Specify --baseline_metrics_tsv_file [FILE] or "
                         "--baseline_metrics_dir [DIR]!")
  if not FLAGS.test_metrics_tsv_file and not FLAGS.test_metrics_dir:
    raise app.UsageError("Specify --test_metrics_tsv_file [FILE] or "
                         "--test_metrics_dir [DIR]!")

  if FLAGS.baseline_metrics_tsv_file and FLAGS.test_metrics_tsv_file:
    # Read the metrics from single files.
    _process_one_pair(FLAGS.baseline_metrics_tsv_file,
                      FLAGS.test_metrics_tsv_file)
  elif FLAGS.baseline_metrics_dir and FLAGS.test_metrics_dir:
    if not FLAGS.language:
      raise app.UsageError("Specify --language [CODE]!")
    if not FLAGS.output_tex_table_file:
      raise app.UsageError("Specify --output_tex_table_file [FILE]!")
    df = _process_dirs().round(FLAGS.float_precision)
    logging.info("Saving the table to %s ...", FLAGS.output_tex_table_file)
    df.to_csv(FLAGS.output_tex_table_file, sep="&", index=False)
  else:
    # A single file combined with a directory cannot be compared; fail
    # loudly instead of exiting without doing anything.
    raise app.UsageError(
        "Specify either both --baseline_metrics_tsv_file and "
        "--test_metrics_tsv_file, or both --baseline_metrics_dir and "
        "--test_metrics_dir!")
214

215

216
if __name__ == "__main__":
  # absl parses the command-line flags, then calls `main` with the remaining
  # positional arguments.
  app.run(main)
218

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.