# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Official FRMT evaluation script."""
from __future__ import annotations

import collections
from collections.abc import Collection, Mapping, Sequence
import json
from typing import Any, Optional

from absl import app
from absl import flags
from absl import logging
import attrs
import bleurt.score as bleurt_lib
from etils import epath
import pandas

from frmt import evaluation

BleurtScorer = bleurt_lib.LengthBatchingBleurtScorer
Metrics = evaluation.Metrics

# ==============================================================================
# Flags
# ==============================================================================


PREDICTION_FILES = flags.DEFINE_list(
    'prediction_files',
    default=None,
    help=(
        'Comma-separated paths to the model prediction files. Each file '
        'should be in .txt or .tsv format, with each model output on its own '
        'line, aligned to the reference file. In the case of .tsv files, the '
        'English source should be in the first column and the translation in '
        'the second; this lets the program check that the predictions are '
        'aligned to the references.'
    ),
    required=True,
)
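
# An illustrative .tsv prediction line (tab-separated, no header); the sentence
# pair below is made up purely for illustration:
#   Where is the bus stop?<TAB>Onde fica o ponto de ônibus?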

DATASET_DIR = flags.DEFINE_string(
    'dataset_dir',
    default='./frmt/dataset',
    help='Path to the FRMT reference directory.',
)

SPLIT = flags.DEFINE_enum(
    'split',
    default='dev',
    enum_values=['dev', 'test'],
    help='Which data split (dev or test) to evaluate on.',
)

LANGUAGE = flags.DEFINE_enum(
    'language',
    default=None,
    enum_values=[
        'pt',
        'pt-BR',
        'pt-PT',
        'zh',
        'zh-CN',
        'zh-TW',
        'zh-TW_Simplified',
    ],
    help=(
        'Which language to evaluate on. If region code is unspecified, '
        'evaluates against regions and scripts for the provided language.'
    ),
    required=True,
)

BUCKET = flags.DEFINE_enum(
    'bucket',
    default=None,
    enum_values=['lexical', 'entity', 'random'],
    help='Which bucket to evaluate.',
    required=True,
)

EVAL_METRICS = flags.DEFINE_multi_enum_class(
    'metric',
    default='bleu',
    enum_class=evaluation.MetricType,
    help='Which evaluation metrics to compute. Case-insensitive.',
)

BLEURT_CHECKPOINT_DIR = flags.DEFINE_string(
    'bleurt_checkpoint_dir',
    default='./bleurt/checkpoints',
    help=(
        'Directory where BLEURT checkpoints are stored, with original '
        'checkpoint names.'
    ),
)

OUTPUT_FILE = flags.DEFINE_string(
    'output_file',
    default=None,
    help=(
        'Where to save the results; can be .txt, .tsv, or .json. If empty, '
        'results are printed to stdout.'
    ),
)
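
# Example invocation (illustrative only; the module path and file names below
# are assumptions, not defined by this script):
#
#   python -m frmt.evaluate \
#     --prediction_files=/tmp/predictions_pt-BR.tsv \
#     --dataset_dir=./frmt/dataset \
#     --split=dev \
#     --language=pt-BR \
#     --bucket=lexical \
#     --metric=bleu \
#     --output_file=/tmp/metrics.tsv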

# ==============================================================================
# Data structures
# ==============================================================================


@attrs.frozen(eq=True, kw_only=True, order=True, slots=True)
class FilePair:
  """Container class for a prediction/reference file pair."""

  prediction_path: epath.Path = attrs.field(converter=epath.Path)
  reference_path: epath.Path = attrs.field(converter=epath.Path)
  bucket: str

  def as_dict(self):
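    # Returns an OrderedDict of the fields in declaration order, with each
    # path reduced to its basename, e.g. (illustrative values):
    #   OrderedDict([('prediction_path', 'predictions_pt-BR.tsv'),
    #                ('reference_path', 'pt_lexical_dev_en_pt-BR.tsv'),
    #                ('bucket', 'lexical')])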
    d = attrs.asdict(self)
    ordered_dict = collections.OrderedDict()
    for key in self.__slots__:  # pytype: disable=attribute-error
      if key not in d:  # E.g. '__weakref__'
        continue
      value = d[key]
      ordered_dict[key] = value.name if isinstance(value, epath.Path) else value
    return ordered_dict


# ==============================================================================
# Helper functions
# ==============================================================================


def _list_file_pairs(
    prediction_files: Collection[str],
    dataset_dir: str,
    *,
    split: str,
    bucket: str,
    language: str,
) -> list[FilePair]:
  """Gathers all the predictions/references we want to evaluate."""
  file_pairs = []
  dataset_path = epath.Path(dataset_dir)
  bucket_path = dataset_path / f'{bucket}_bucket'
  primary_language = language.split('-')[0]
  match_name = f'{primary_language}_{bucket}_{split}_en_{language}'
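  # For example (assuming the standard FRMT file naming), --language=pt-BR
  # --bucket=lexical --split=dev yields match_name 'pt_lexical_dev_en_pt-BR',
  # which picks up reference files such as 'pt_lexical_dev_en_pt-BR.tsv' under
  # 'lexical_bucket/'.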
  for prediction_file in prediction_files:
    for reference_path in bucket_path.iterdir():
      if str(reference_path.name).startswith(match_name):
        file_pairs.append(
            FilePair(
                bucket=bucket,
                prediction_path=prediction_file,
                reference_path=reference_path,
            )
        )
  return file_pairs


def _read_tsv(file_path: epath.Path) -> list[evaluation.TranslationPair]:
  """Reads a tsv with two columns (source, translation) and no header."""
  translation_pairs = []
  with file_path.open() as f:
    # Note: the correct way to do this is with csv.DictReader, but some examples
    # have quote characters that confuse the csv parser. Since we know the
    # source never has its own tab or newline characters, basic Python string
    # manipulation is fine here, as long as the model doesn't predict tabs or
    # newlines.
    for line in f:
      line = line.strip()
      line = line.split('\t')
      if len(line) != 2:
        raise ValueError(
            f'Line {line} could not be parsed. You may need to manually '
            'replace tab or newline characters in the model output with '
            'spaces.'
        )
      source, translation = line
      translation_pairs.append(
          evaluation.TranslationPair(source=source, translation=translation)
      )
  return translation_pairs


def _read_txt(file_path: epath.Path) -> list[evaluation.TranslationPair]:
  """Reads a txt file with translations (no source) on each line."""
  translation_pairs = []
  with file_path.open() as f:
    for line in f:
      translation_pairs.append(
          evaluation.TranslationPair(source=None, translation=line.strip())
      )
  return translation_pairs


def _read_predictions_and_references(
    file_pair: FilePair,
) -> tuple[list[evaluation.TranslationPair], list[evaluation.TranslationPair]]:
  """Reads in the predictions and references.

  Args:
    file_pair: The FilePair object containing the prediction and reference
      files.

  Returns:
    A tuple containing the model predictions and gold references (respectively)
      in the two files.
  """
  read_predictions_fn = {
      '.txt': _read_txt,
      '.tsv': _read_tsv,
  }.get(file_pair.prediction_path.suffix)
  if read_predictions_fn is None:
    raise ValueError(
        f'Predictions file `{file_pair.prediction_path}` has unsupported '
        f'suffix `{file_pair.prediction_path.suffix}`. Supported values '
        'are ".txt" and ".tsv".'
    )

  predictions = read_predictions_fn(file_pair.prediction_path)
  references = _read_tsv(file_pair.reference_path)
  if len(predictions) != len(references):
    prediction_sources = set(prediction.source for prediction in predictions)
    reference_sources = set(reference.source for reference in references)
    non_references = list(prediction_sources.difference(reference_sources))
    non_predictions = list(reference_sources.difference(prediction_sources))
    raise ValueError(
        f'{file_pair} has {len(predictions)} predictions but {len(references)} '
        'references (should be equal). Sample of 5 prediction sources not in '
        f'references: {non_references[:5]}. Sample of 5 reference sources not '
        f'in predictions: {non_predictions[:5]}.'
    )
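  # Note: .txt predictions carry no source column (every source is None), so
  # the sample diagnostics above are mainly informative when the predictions
  # were supplied as .tsv.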
  return predictions, references


def _records_to_string(records: Sequence[Mapping[str, Any]]) -> str:
  """Creates a human-readable string representing a sequence of records."""
  parts = []
  for record in records:
    parts.append('\n'.join(f'{k}: {v}' for k, v in record.items()))
  return '\n\n'.join(parts) + '\n'


def _write_txt(
    output_path: epath.Path, records: Sequence[Mapping[str, Any]]
) -> None:
  """Writes a collection of records to text."""
  output_path.write_text(_records_to_string(records))


def _write_json(
    output_path: epath.Path, records: Sequence[Mapping[str, Any]]
) -> None:
  """Writes a collection of records to json."""
  output_path.write_text(json.dumps(records))


def _write_tsv(
    output_path: epath.Path,
    records: Sequence[Mapping[str, Any]],
) -> None:
  """Writes a collection of records to tsv."""
  df = pandas.DataFrame(records)
  output_path.write_text(df.to_csv(index=False, sep='\t'))


def _write_output(
    all_metrics: Mapping[FilePair, Metrics],
    output_path: Optional[epath.Path],
) -> None:
  """Writes the output to the file specified by the user."""
  records = []
  for file_pair, metrics in all_metrics.items():
    records.append(
        collections.OrderedDict(
            **file_pair.as_dict(),
            **metrics.as_dict(),
        )
    )
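  # Each record is one result row: the FilePair fields followed by the values
  # of the computed metrics.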

  if output_path is None:
    s = _records_to_string(records)
    logging.info(s)
    print(s)
    return

  write_metrics_fn = {
      '.txt': _write_txt,
      '.json': _write_json,
      '.tsv': _write_tsv,
  }.get(output_path.suffix)
  if write_metrics_fn is None:
    raise ValueError(
        f'Output path `{output_path}` has unsupported suffix '
        f'`{output_path.suffix}`. Supported values are ".txt", ".json", and '
        '".tsv".'
    )
  write_metrics_fn(output_path, records)


# ==============================================================================
# Main
# ==============================================================================


def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  output_path = (
      epath.Path(OUTPUT_FILE.value) if OUTPUT_FILE.value is not None else None
  )
  if output_path is not None:
    output_path.parent.mkdir(parents=True, exist_ok=True)

  if output_path is not None and output_path.suffix not in [
      '.txt',
      '.json',
      '.tsv',
  ]:
    raise ValueError(
        f'Output path `{output_path}` has unsupported suffix '
        f'`{output_path.suffix}`. Supported values are ".txt", ".json", and '
        '".tsv".'
    )

  # Enumerate all the prediction/reference pairs we want to evaluate.
  file_pairs = _list_file_pairs(
      PREDICTION_FILES.value,
      DATASET_DIR.value,
      split=SPLIT.value,
      bucket=BUCKET.value,
      language=LANGUAGE.value,
  )
  logging.info(
      'Running evaluation on the following prediction/reference pairs: %s',
      '\n'.join(map(str, file_pairs)),
  )

  # Run evaluation on all the input pairs.
  all_metrics: dict[FilePair, evaluation.Metrics] = collections.OrderedDict()
  if BLEURT_CHECKPOINT_DIR.value is not None:
    bleurt_scorer_cache = evaluation.BleurtScorerCache(
        BLEURT_CHECKPOINT_DIR.value
    )
  else:
    bleurt_scorer_cache = None
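  # Sharing one BleurtScorerCache across all file pairs below avoids reloading
  # BLEURT checkpoints for every prediction/reference pair (presumably it is
  # only consulted when a BLEURT-based metric is requested via --metric).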
  for file_pair in file_pairs:
    predictions, references = _read_predictions_and_references(file_pair)
    all_metrics[file_pair] = evaluation.evaluate(
        predictions=predictions,
        references=references,
        eval_metrics=EVAL_METRICS.value,
        language=LANGUAGE.value,
        bleurt_scorer_cache=bleurt_scorer_cache,
    )

  _write_output(all_metrics=all_metrics, output_path=output_path)


if __name__ == '__main__':
  app.run(main)