# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for computing metrics on the FRMT dataset."""
from __future__ import annotations

import collections
from collections.abc import Collection, MutableMapping, Sequence
import enum
from typing import Optional, Type

import attrs
import bleurt.score as bleurt_lib
from etils import epath
import sacrebleu

BleurtScorer = bleurt_lib.LengthBatchingBleurtScorer


@attrs.frozen(eq=True, kw_only=True)
class TranslationPair:
  """Container class for a source/translation pair.

  Attributes:
    source: The (English) source sentence. May be `None` if this information is
      not tracked.
    translation: The gold or model translation. Unlike `source`, it is
      required.
  """

  source: Optional[str]
  translation: str
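
# A minimal illustration (hypothetical strings); `kw_only=True` means both
# fields must be passed by keyword:
#
#   pair = TranslationPair(source='Hello!', translation='Olá!')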


@attrs.define()
class BleurtScorerCache:
  """Container class for a cache of Bleurt models."""

  bleurt_checkpoint_path: epath.Path = attrs.field(converter=epath.Path)
  _cache: MutableMapping[epath.Path, BleurtScorer] = attrs.field(factory=dict)

  def __getitem__(self, bleurt_name: str) -> BleurtScorer:
    # Lazily load each BLEURT checkpoint on first access and memoize it.
    bleurt_checkpoint = self.bleurt_checkpoint_path / bleurt_name
    if bleurt_checkpoint in self._cache:
      return self._cache[bleurt_checkpoint]
    else:
      bleurt_scorer = BleurtScorer(str(bleurt_checkpoint))
      self._cache[bleurt_checkpoint] = bleurt_scorer
      return bleurt_scorer
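
# Sketch of intended use (the checkpoint directory is an assumption; it must
# contain subdirectories such as `BLEURT-20`):
#
#   cache = BleurtScorerCache('/path/to/bleurt_checkpoints')
#   scorer = cache['BLEURT-20']  # Loaded from disk on first access.
#   scorer = cache['BLEURT-20']  # Returned from the in-memory cache.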


@attrs.define(eq=True, kw_only=True, order=True, slots=True)
class Metrics:
  """Container class for the computed evaluation metrics."""

  bleu: Optional[float] = None
  chrf: Optional[float] = None
  bleurt: Optional[float] = None
  bleurt_d12: Optional[float] = None
  bleurt_d6: Optional[float] = None
  bleurt_d3: Optional[float] = None

  def as_dict(self):
    d = attrs.asdict(self)
    ordered_dict = collections.OrderedDict()
    for metric_name in self.__slots__:  # pytype: disable=attribute-error
      if metric_name not in d:  # E.g. '__weakref__'
        continue
      if d[metric_name] is not None:
        ordered_dict[metric_name] = d[metric_name]
    return ordered_dict
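
# For example (hypothetical values), Metrics(bleu=0.32, bleurt=0.71).as_dict()
# returns OrderedDict([('bleu', 0.32), ('bleurt', 0.71)]): unset metrics are
# omitted, and keys follow the field declaration order above.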


class MetricType(enum.Enum):
  """Supported metric types."""

  UNDEFINED = 0
  BLEU = 1
  CHRF = 2
  BLEURT = 3
  BLEURT_D12 = 4
  BLEURT_D6 = 5
  BLEURT_D3 = 6

  @staticmethod
  def _validate_predictions_and_references(
      predictions: Sequence[TranslationPair],
      references: Sequence[TranslationPair],
  ) -> None:
    """Ensures that the predictions and references look okay.

    Args:
      predictions: A sequence of TranslationPair objects containing model
        translations.
      references: A sequence of TranslationPair objects containing gold
        references.

    Raises:
      ValueError: If the list of predictions or references is empty.
      ValueError: If the predictions and references have different lengths.
      ValueError: If the predictions and references are misaligned (requires
        the predictions to be in .tsv format).
    """
    if not predictions:
      raise ValueError('List of predictions is empty.')

    if not references:
      raise ValueError('List of references is empty.')

    if len(predictions) != len(references):
      raise ValueError(
          f'Number of predictions ({len(predictions)}) != '
          f'number of references ({len(references)})'
      )

    for i, (prediction, reference) in enumerate(zip(predictions, references)):
      if (
          prediction.source is not None
          and prediction.source != reference.source
      ):
        raise ValueError(
            f'Predictions and references are misaligned at index {i}.'
            f'\nPrediction: {prediction}\nReference: {reference}'
        )
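
  # For instance, if predictions[3].source differs from references[3].source,
  # the loop above raises a ValueError naming index 3; this typically points
  # to shuffled or truncated prediction files.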

  @classmethod
  def _compute_bleu(
      cls,
      *,
      predictions: Sequence[TranslationPair],
      references: Sequence[TranslationPair],
      language: str,
  ) -> float:
    """Computes the BLEU score for a file pair."""
    cls._validate_predictions_and_references(predictions, references)

    # Chinese text is not space-delimited, so use sacrebleu's dedicated
    # Chinese tokenizer; all other languages use the default tokenizer.
    if language.startswith('zh'):
      tokenizer = 'zh'
    else:
      tokenizer = sacrebleu.DEFAULT_TOKENIZER

    # sacrebleu reports BLEU on a 0-100 scale; rescale to [0, 1].
    return (
        sacrebleu.corpus_bleu(
            [prediction.translation for prediction in predictions],
            [[reference.translation for reference in references]],
            tokenize=tokenizer,
        ).score
        / 100
    )

  @classmethod
  def _compute_chrf(
      cls,
      *,
      predictions: Sequence[TranslationPair],
      references: Sequence[TranslationPair],
  ) -> float:
    """Computes the chrF score for predictions and references."""
    cls._validate_predictions_and_references(predictions, references)

    return sacrebleu.corpus_chrf(
        [prediction.translation for prediction in predictions],
        [reference.translation for reference in references],
    )

  @classmethod
  def _compute_bleurt(
      cls,
      *,
      predictions: Sequence[TranslationPair],
      references: Sequence[TranslationPair],
      bleurt_scorer: BleurtScorer,
  ) -> float:
    """Computes the BLEURT score for predictions and references."""
    cls._validate_predictions_and_references(predictions, references)

    # BLEURT returns one score per (candidate, reference) pair; report the
    # corpus-level mean.
    bleurt_scores = bleurt_scorer.score(
        candidates=[prediction.translation for prediction in predictions],
        references=[reference.translation for reference in references],
    )
    return sum(bleurt_scores) / len(bleurt_scores)

  def compute(
      self,
      *,
      predictions: Sequence[TranslationPair],
      references: Sequence[TranslationPair],
      language: Optional[str] = None,
      bleurt_scorer_cache: Optional[BleurtScorerCache] = None,
  ) -> float:
    """Computes the metric on predictions and references."""

    if self is MetricType.UNDEFINED:
      raise ValueError('Cannot compute UNDEFINED metric.')
    if self is MetricType.BLEU and language is None:
      raise ValueError(
          '`language` keyword must be non-None when computing BLEU.'
      )
    elif self.name.startswith('BLEURT') and bleurt_scorer_cache is None:
      raise ValueError(
          '`bleurt_scorer_cache` must be non-None when computing BLEURT.'
      )

    if self is MetricType.BLEU:
      return self._compute_bleu(
          predictions=predictions, references=references, language=language
      )
    elif self is MetricType.CHRF:
      return self._compute_chrf(predictions=predictions, references=references)
    elif self is MetricType.BLEURT:
      return self._compute_bleurt(
          predictions=predictions,
          references=references,
          bleurt_scorer=bleurt_scorer_cache['BLEURT-20'],
      )
    elif self is MetricType.BLEURT_D12:
      return self._compute_bleurt(
          predictions=predictions,
          references=references,
          bleurt_scorer=bleurt_scorer_cache['BLEURT-20-D12'],
      )
    elif self is MetricType.BLEURT_D6:
      return self._compute_bleurt(
          predictions=predictions,
          references=references,
          bleurt_scorer=bleurt_scorer_cache['BLEURT-20-D6'],
      )
    elif self is MetricType.BLEURT_D3:
      return self._compute_bleurt(
          predictions=predictions,
          references=references,
          bleurt_scorer=bleurt_scorer_cache['BLEURT-20-D3'],
      )
    else:
      raise ValueError(f'Cannot compute {self} metric.')
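
# Illustrative call (variable names are hypothetical):
#
#   score = MetricType.BLEU.compute(
#       predictions=model_outputs,   # Sequence[TranslationPair]
#       references=gold_references,  # Sequence[TranslationPair]
#       language='pt-BR',
#   )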


def evaluate(
    *,
    predictions: Sequence[TranslationPair],
    references: Sequence[TranslationPair],
    eval_metrics: Collection[MetricType],
    language: str,
    bleurt_scorer_cache: BleurtScorerCache,
) -> Metrics:
  """Runs the specified evaluation metrics."""
  metrics = Metrics()
  for eval_metric in eval_metrics:
    value = eval_metric.compute(
        predictions=predictions,
        references=references,
        language=language,
        bleurt_scorer_cache=bleurt_scorer_cache,
    )
    # MetricType names mirror the Metrics field names, so the lowercased enum
    # name addresses the matching attribute.
    setattr(metrics, eval_metric.name.lower(), value)

  return metrics
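
# End-to-end sketch (paths and variable names are hypothetical; the BLEURT
# checkpoints must already exist under the given directory):
#
#   cache = BleurtScorerCache('/path/to/bleurt_checkpoints')
#   metrics = evaluate(
#       predictions=model_outputs,
#       references=gold_references,
#       eval_metrics=[MetricType.BLEU, MetricType.BLEURT],
#       language='pt-BR',
#       bleurt_scorer_cache=cache,
#   )
#   metrics.as_dict()  # OrderedDict([('bleu', ...), ('bleurt', ...)])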