# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Official FRMT evaluation script."""
17from __future__ import annotations
18
19import collections
20from collections.abc import Collection, Mapping, Sequence
21import json
22from typing import Any, Optional
23
24from absl import app
25from absl import flags
26from absl import logging
27import attrs
28import bleurt.score as bleurt_lib
29from etils import epath
30import pandas
31
32from frmt import evaluation
33
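# Note (editorial): LengthBatchingBleurtScorer groups inputs of similar length
# into batches, which typically speeds up BLEURT scoring over naive batching.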
BleurtScorer = bleurt_lib.LengthBatchingBleurtScorer
Metrics = evaluation.Metrics

# ==============================================================================
# Flags
# ==============================================================================


PREDICTION_FILES = flags.DEFINE_list(
    'prediction_files',
    default=None,
    help=(
        'Comma-separated paths to the model prediction files, in .txt or '
        '.tsv format. Each model output should be on its own line, aligned '
        'to the reference file. In the case of .tsv files, the English '
        'source should be in the first column and the translation in the '
        'second; this allows the program to check that the predictions are '
        'aligned to the references.'
    ),
    required=True,
)

DATASET_DIR = flags.DEFINE_string(
    'dataset_dir',
    default='./frmt/dataset',
    help='Path to the FRMT reference directory.',
)

SPLIT = flags.DEFINE_enum(
    'split',
    default='dev',
    enum_values=['dev', 'test'],
    help='Which data split (dev or test) to evaluate on.',
)

LANGUAGE = flags.DEFINE_enum(
    'language',
    default=None,
    enum_values=[
        'pt',
        'pt-BR',
        'pt-PT',
        'zh',
        'zh-CN',
        'zh-TW',
        'zh-TW_Simplified',
    ],
    help=(
        'Which language to evaluate on. If the region code is unspecified, '
        'evaluates against all regions and scripts for the provided language.'
    ),
    required=True,
)

BUCKET = flags.DEFINE_enum(
    'bucket',
    default=None,
    enum_values=['lexical', 'entity', 'random'],
    help='Which bucket to evaluate.',
    required=True,
)

EVAL_METRICS = flags.DEFINE_multi_enum_class(
    'metric',
    default='bleu',
    enum_class=evaluation.MetricType,
    help='Which evaluation metrics to compute. Case-insensitive.',
)

BLEURT_CHECKPOINT_DIR = flags.DEFINE_string(
    'bleurt_checkpoint_dir',
    default='./bleurt/checkpoints',
    help=(
        'Directory where BLEURT checkpoints are stored, with original '
        'checkpoint names.'
    ),
)

OUTPUT_FILE = flags.DEFINE_string(
    'output_file',
    default=None,
    help=(
        'Where to save the results. Can be .txt, .tsv, or .json. If empty, '
        'results are printed to stdout.'
    ),
)

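# Example invocation (a sketch; file paths are illustrative and the module
# path may differ depending on how the repository is checked out):
#
#   python -m frmt.evaluate \
#     --prediction_files=/tmp/predictions_en_pt-BR.tsv \
#     --dataset_dir=./frmt/dataset \
#     --split=dev \
#     --language=pt-BR \
#     --bucket=lexical \
#     --metric=bleu \
#     --output_file=/tmp/metrics.tsv
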
# ==============================================================================
# Data structures
# ==============================================================================


@attrs.frozen(eq=True, kw_only=True, order=True, slots=True)
class FilePair:
  """Container class for a prediction/reference file pair."""

  prediction_path: epath.Path = attrs.field(converter=epath.Path)
  reference_path: epath.Path = attrs.field(converter=epath.Path)
  bucket: str

  def as_dict(self) -> Mapping[str, Any]:
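    """Returns the fields as an OrderedDict, with paths reduced to basenames.

    For example (illustrative values), a FilePair with
    prediction_path='out/preds.tsv' and bucket='lexical' maps to
    OrderedDict([('prediction_path', 'preds.tsv'), ..., ('bucket', 'lexical')]).
    """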
    d = attrs.asdict(self)
    ordered_dict = collections.OrderedDict()
    for key in self.__slots__:  # pytype: disable=attribute-error
      if key not in d:  # E.g. '__weakref__'
        continue
      value = d[key]
      ordered_dict[key] = value.name if isinstance(value, epath.Path) else value
    return ordered_dict


# ==============================================================================
# Helper functions
# ==============================================================================


def _list_file_pairs(
    prediction_files: Sequence[str],
    dataset_dir: str,
    *,
    split: str,
    bucket: str,
    language: str,
) -> list[FilePair]:
  """Gathers all the predictions/references we want to evaluate."""
  file_pairs = []
  dataset_path = epath.Path(dataset_dir)
  bucket_path = dataset_path / f'{bucket}_bucket'
  primary_language = language.split('-')[0]
  match_name = f'{primary_language}_{bucket}_{split}_en_{language}'
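  # Note: matching below uses startswith(), so a bare language code like 'pt'
  # picks up both the 'en_pt-BR' and 'en_pt-PT' reference files, while a full
  # code like 'pt-BR' matches only that region's references.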
  for prediction_file in prediction_files:
    for reference_path in bucket_path.iterdir():
      if str(reference_path.name).startswith(match_name):
        file_pairs.append(
            FilePair(
                bucket=bucket,
                prediction_path=prediction_file,
                reference_path=reference_path,
            )
        )
  return file_pairs


def _read_tsv(file_path: epath.Path) -> list[evaluation.TranslationPair]:
  """Reads a TSV file with two columns (source, translation) and no header."""
  translation_pairs = []
  with file_path.open() as f:
    # Note: the correct way to do this is with csv.DictReader, but some
    # examples have quote characters that confuse the csv parser. Since we
    # know the source never has its own tab or newline characters, basic
    # Python string manipulation is fine here, as long as the model doesn't
    # predict tabs or newlines.
    for line in f:
      line = line.strip()
      line = line.split('\t')
      if len(line) != 2:
        raise ValueError(
            f'Line {line} could not be parsed. You may need to manually '
            'replace tab or newline characters in the model output with '
            'spaces.'
        )
      source, translation = line
      translation_pairs.append(
          evaluation.TranslationPair(source=source, translation=translation)
      )
  return translation_pairs


def _read_txt(file_path: epath.Path) -> list[evaluation.TranslationPair]:
  """Reads a txt file with translations (no source) on each line."""
  translation_pairs = []
  with file_path.open() as f:
    for line in f:
      translation_pairs.append(
          evaluation.TranslationPair(source=None, translation=line.strip())
      )
  return translation_pairs


def _read_predictions_and_references(
    file_pair: FilePair,
) -> tuple[
    list[evaluation.TranslationPair], list[evaluation.TranslationPair]
]:
  """Reads in the predictions and references.

  Args:
    file_pair: The FilePair object containing the prediction and reference
      files.

  Returns:
    A tuple containing the model predictions and gold references
    (respectively) in the two files.
  """
  read_predictions_fn = {
      '.txt': _read_txt,
      '.tsv': _read_tsv,
  }.get(file_pair.prediction_path.suffix)
  if read_predictions_fn is None:
    raise ValueError(
        f'Predictions file `{file_pair.prediction_path}` has unsupported '
        f'suffix `{file_pair.prediction_path.suffix}`. Supported values '
        'are ".txt" and ".tsv".'
    )

  predictions = read_predictions_fn(file_pair.prediction_path)
  references = _read_tsv(file_pair.reference_path)
  if len(predictions) != len(references):
    prediction_sources = set(prediction.source for prediction in predictions)
    reference_sources = set(reference.source for reference in references)
    non_references = list(prediction_sources.difference(reference_sources))
    non_predictions = list(reference_sources.difference(prediction_sources))
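    # Note: when predictions come from a .txt file, every prediction source is
    # None, so these diagnostic samples are only informative for .tsv
    # predictions that include the English source column.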
    raise ValueError(
        f'{file_pair} has {len(predictions)} predictions but '
        f'{len(references)} references (should be equal). Sample of 5 '
        f'prediction sources not in references: {non_references[:5]}. Sample '
        f'of 5 reference sources not in predictions: {non_predictions[:5]}.'
    )
  return predictions, references


def _records_to_string(records: Collection[Mapping[str, Any]]) -> str:
  """Creates a human-readable string representing a sequence of records."""
  parts = []
  for record in records:
    parts.append('\n'.join(f'{k}: {v}' for k, v in record.items()))
  return '\n\n'.join(parts) + '\n'


def _write_txt(
    output_path: epath.Path, records: Collection[Mapping[str, Any]]
) -> None:
  """Writes a collection of records to text."""
  output_path.write_text(_records_to_string(records))


def _write_json(
    output_path: epath.Path, records: Collection[Mapping[str, Any]]
) -> None:
  """Writes a collection of records to json."""
  output_path.write_text(json.dumps(records))


def _write_tsv(
    output_path: epath.Path,
    records: Collection[Mapping[str, Any]],
) -> None:
  """Writes a collection of records to tsv."""
  df = pandas.DataFrame(records)
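  # to_csv with sep='\t' renders the records as TSV; index=False omits the
  # pandas row index so only the record fields are written.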
  output_path.write_text(df.to_csv(index=False, sep='\t'))


def _write_output(
    all_metrics: Mapping[FilePair, Metrics],
    output_path: Optional[epath.Path],
) -> None:
  """Writes the output to the file specified by the user."""
  records = []
  for file_pair, metrics in all_metrics.items():
    records.append(
        collections.OrderedDict(
            **file_pair.as_dict(),
            **metrics.as_dict(),
        )
    )

  if output_path is None:
    s = _records_to_string(records)
    logging.info(s)
    print(s)
    return

  write_metrics_fn = {
      '.txt': _write_txt,
      '.json': _write_json,
      '.tsv': _write_tsv,
  }.get(output_path.suffix)
  if write_metrics_fn is None:
    raise ValueError(
        f'Output path `{output_path}` has unsupported suffix '
        f'`{output_path.suffix}`. Supported values are ".txt", ".json", and '
        '".tsv".'
    )
  write_metrics_fn(output_path, records)


# ==============================================================================
# Main
# ==============================================================================


def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  output_path = (
      epath.Path(OUTPUT_FILE.value) if OUTPUT_FILE.value is not None else None
  )
  if output_path is not None:
    output_path.parent.mkdir(parents=True, exist_ok=True)

  if output_path is not None and output_path.suffix not in [
      '.txt',
      '.json',
      '.tsv',
  ]:
    raise ValueError(
        f'Output path `{output_path}` has unsupported suffix '
        f'`{output_path.suffix}`. Supported values are ".txt", ".json", and '
        '".tsv".'
    )

  # Enumerate all the prediction/reference pairs we want to evaluate.
  file_pairs = _list_file_pairs(
      PREDICTION_FILES.value,
      DATASET_DIR.value,
      split=SPLIT.value,
      bucket=BUCKET.value,
      language=LANGUAGE.value,
  )
  logging.info(
      'Running evaluation on the following prediction/reference pairs: %s',
      '\n'.join(map(str, file_pairs)),
  )

  # Run evaluation on all the input pairs.
  all_metrics: dict[FilePair, evaluation.Metrics] = collections.OrderedDict()
  if BLEURT_CHECKPOINT_DIR.value is not None:
    bleurt_scorer_cache = evaluation.BleurtScorerCache(
        BLEURT_CHECKPOINT_DIR.value
    )
  else:
    bleurt_scorer_cache = None
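  # The scorer cache is shared across all file pairs below so that each BLEURT
  # checkpoint is loaded at most once per run.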
  for file_pair in file_pairs:
    predictions, references = _read_predictions_and_references(file_pair)
    all_metrics[file_pair] = evaluation.evaluate(
        predictions=predictions,
        references=references,
        eval_metrics=EVAL_METRICS.value,
        language=LANGUAGE.value,
        bleurt_scorer_cache=bleurt_scorer_cache,
    )

  _write_output(all_metrics=all_metrics, output_path=output_path)


if __name__ == '__main__':
  app.run(main)