google-research

Форк
0
/
evaluation.py 
268 строк · 7.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Library for computing metrics on the FRMT dataset."""
17
from __future__ import annotations
18

19
import collections
20
from collections.abc import Collection, MutableMapping, Sequence
21
import enum
22
from typing import Optional, Type
23

24
import attrs
25
import bleurt.score as bleurt_lib
26
from etils import epath
27
import sacrebleu
28

29
# Scorer class that `BleurtScorerCache` instantiates once per BLEURT
# checkpoint path (it is constructed from the checkpoint path as a string).
BleurtScorer = bleurt_lib.LengthBatchingBleurtScorer
30

31

32
@attrs.frozen(kw_only=True, eq=True)
class TranslationPair:
  """An immutable record pairing a sentence with its translation.

  Instances are frozen and must be constructed with keyword arguments.

  Attributes:
    source: The English sentence that was translated, or `None` when the
      source is not tracked.
    translation: The translated sentence, either a gold reference or a model
      output. This field is always required.
  """

  source: Optional[str] = attrs.field()
  translation: str = attrs.field()
44

45

46
@attrs.define()
class BleurtScorerCache:
  """Lazily creates BLEURT scorers and memoizes them by checkpoint path.

  Attributes:
    bleurt_checkpoint_path: Directory under which named BLEURT checkpoints
      live; converted to an `epath.Path` on construction.
  """

  bleurt_checkpoint_path: epath.Path = attrs.field(converter=epath.Path)
  # Full checkpoint path -> scorer already built from that checkpoint.
  _cache: MutableMapping[epath.Path, BleurtScorer] = attrs.field(factory=dict)

  def __getitem__(self, bleurt_name):
    """Returns the scorer for checkpoint `bleurt_name`, building it on a miss.

    Args:
      bleurt_name: Name of a checkpoint directory below
        `bleurt_checkpoint_path`.

    Returns:
      The cached `BleurtScorer` for that checkpoint. A new one is created and
      stored only the first time a given checkpoint path is requested.
    """
    checkpoint_path = self.bleurt_checkpoint_path / bleurt_name
    if checkpoint_path in self._cache:
      return self._cache[checkpoint_path]
    # Cache miss: load the checkpoint once and remember the scorer.
    new_scorer = BleurtScorer(str(checkpoint_path))
    self._cache[checkpoint_path] = new_scorer
    return new_scorer
61

62

63
@attrs.define(eq=True, kw_only=True, order=True, slots=True)
class Metrics:
  """Holds the computed evaluation metrics.

  Every field defaults to None, meaning that metric was not computed.
  """

  bleu: Optional[float] = None
  chrf: Optional[float] = None
  bleurt: Optional[float] = None
  bleurt_d12: Optional[float] = None
  bleurt_d6: Optional[float] = None
  bleurt_d3: Optional[float] = None

  def as_dict(self):
    """Returns only the metrics that have a value.

    Returns:
      An OrderedDict from metric name to value. Names follow the order of the
      class's `__slots__`, and metrics whose value is None are omitted.
    """
    field_values = attrs.asdict(self)
    # `__slots__` may list names that are not fields (e.g. '__weakref__'), so
    # only names that `attrs.asdict` produced are kept.
    return collections.OrderedDict(
        (slot_name, field_values[slot_name])
        for slot_name in self.__slots__  # pytype: disable=attribute-error
        if slot_name in field_values and field_values[slot_name] is not None
    )
83

84

85
class MetricType(enum.Enum):
  """Supported metric types.

  Each member computes its own metric through `compute`:
    * BLEU: sacrebleu corpus BLEU score divided by 100.
    * CHRF: the `score` of sacrebleu's corpus chrF result.
    * BLEURT, BLEURT_D12, BLEURT_D6, BLEURT_D3: mean per-segment score from
      the matching BLEURT-20 checkpoint.
  UNDEFINED cannot be computed.
  """

  UNDEFINED = 0
  BLEU = 1
  CHRF = 2
  BLEURT = 3
  BLEURT_D12 = 4
  BLEURT_D6 = 5
  BLEURT_D3 = 6

  @staticmethod
  def _validate_predictions_and_references(
      predictions,
      references,
  ):
    """Ensures that the predictions and references look okay.

    Args:
      predictions: A sequence of TranslationPair objects containing model
        translations.
      references: A sequence of TranslationPair objects containing gold
        references.

    Raises:
      ValueError: If the list of predictions or references is empty.
      ValueError: If the predictions and references have different lengths.
      ValueError: If a prediction has a non-None `source` that differs from the
        `source` of the reference at the same index. Predictions whose
        `source` is None are not checked for alignment.
    """
    if not predictions:
      raise ValueError('List of predictions is empty.')

    if not references:
      raise ValueError('List of references is empty.')

    if len(predictions) != len(references):
      raise ValueError(
          f'Number of predictions ({len(predictions)}) != '
          f'number of references ({len(references)})'
      )

    for i, (prediction, reference) in enumerate(zip(predictions, references)):
      # Alignment can only be checked when the prediction carries its source.
      if (
          prediction.source is not None
          and prediction.source != reference.source
      ):
        raise ValueError(
            f'Predictions and references are misaligned at index {i}.'
            f'\nPrediction: {prediction}\nReference: {reference}'
        )

  @classmethod
  def _compute_bleu(
      cls,
      *,
      predictions,
      references,
      language,
  ):
    """Computes corpus BLEU for predictions against references.

    Args:
      predictions: Sequence of TranslationPair objects with model translations.
      references: Sequence of TranslationPair objects with gold translations.
      language: Target-language code. Codes starting with 'zh' use sacrebleu's
        'zh' tokenizer; all others use sacrebleu's default tokenizer.

    Returns:
      The sacrebleu corpus BLEU score divided by 100.
    """
    cls._validate_predictions_and_references(predictions, references)

    if language.startswith('zh'):
      tokenizer = 'zh'
    else:
      tokenizer = sacrebleu.DEFAULT_TOKENIZER

    return (
        sacrebleu.corpus_bleu(
            [prediction.translation for prediction in predictions],
            # A single reference stream aligned with the hypotheses.
            [[reference.translation for reference in references]],
            tokenize=tokenizer,
        ).score
        / 100
    )

  @classmethod
  def _compute_chrf(
      cls,
      *,
      predictions,
      references,
  ):
    """Computes corpus chrF for predictions against references.

    Args:
      predictions: Sequence of TranslationPair objects with model translations.
      references: Sequence of TranslationPair objects with gold translations.

    Returns:
      The `score` attribute (a float) of sacrebleu's corpus chrF result.
    """
    cls._validate_predictions_and_references(predictions, references)

    # `corpus_chrf` returns a score object. Extract `.score` so callers get the
    # float that `Metrics.chrf` is declared to hold, as `_compute_bleu` does.
    # NOTE(review): the flat reference list and the lack of a `/ 100` rescaling
    # assume the sacrebleu 1.x API (which the use of
    # `sacrebleu.DEFAULT_TOKENIZER` in `_compute_bleu` also suggests). In that
    # API, chrF takes a single reference stream and reports a 0-1 score.
    # sacrebleu 2.x expects a list of reference streams and reports 0-100 --
    # confirm the pinned version.
    return sacrebleu.corpus_chrf(
        [prediction.translation for prediction in predictions],
        [reference.translation for reference in references],
    ).score

  @classmethod
  def _compute_bleurt(
      cls,
      *,
      predictions,
      references,
      bleurt_scorer,
  ):
    """Computes the mean BLEURT score for predictions against references.

    Args:
      predictions: Sequence of TranslationPair objects with model translations.
      references: Sequence of TranslationPair objects with gold translations.
      bleurt_scorer: Object whose `score(candidates=..., references=...)`
        returns one score per segment.

    Returns:
      The arithmetic mean of the per-segment scores.
    """
    cls._validate_predictions_and_references(predictions, references)

    bleurt_scores = bleurt_scorer.score(
        candidates=[prediction.translation for prediction in predictions],
        references=[reference.translation for reference in references],
    )
    # Non-empty is guaranteed by the validation above.
    return sum(bleurt_scores) / len(bleurt_scores)

  def compute(
      self,
      *,
      predictions,
      references,
      language=None,
      bleurt_scorer_cache=None,
  ):
    """Computes this metric on predictions and references.

    Args:
      predictions: Sequence of TranslationPair objects with model translations.
      references: Sequence of TranslationPair objects with gold translations,
        aligned index-by-index with `predictions`.
      language: Target-language code. Required for BLEU; ignored otherwise.
      bleurt_scorer_cache: Mapping from BLEURT checkpoint name (e.g.
        'BLEURT-20') to scorer, such as a `BleurtScorerCache`. Required for
        the BLEURT variants; ignored otherwise.

    Returns:
      The metric value as a float.

    Raises:
      ValueError: If the metric is UNDEFINED, a required keyword is missing,
        or the inputs fail validation.
    """
    if self is MetricType.UNDEFINED:
      raise ValueError('Cannot compute UNDEFINED metric.')
    if self is MetricType.BLEU and language is None:
      raise ValueError(
          '`language` keyword must be non-None when computing BLEU.'
      )
    if self.name.startswith('BLEURT') and bleurt_scorer_cache is None:
      raise ValueError(
          '`bleurt_scorer_cache` must be non-None when computing BLEURT.'
      )

    if self is MetricType.BLEU:
      return self._compute_bleu(
          predictions=predictions, references=references, language=language
      )
    if self is MetricType.CHRF:
      return self._compute_chrf(predictions=predictions, references=references)

    # Each BLEURT variant is scored with its own checkpoint. This is a local
    # dict because a mapping defined in the Enum body would become a member.
    bleurt_checkpoint_names = {
        MetricType.BLEURT: 'BLEURT-20',
        MetricType.BLEURT_D12: 'BLEURT-20-D12',
        MetricType.BLEURT_D6: 'BLEURT-20-D6',
        MetricType.BLEURT_D3: 'BLEURT-20-D3',
    }
    if self in bleurt_checkpoint_names:
      return self._compute_bleurt(
          predictions=predictions,
          references=references,
          bleurt_scorer=bleurt_scorer_cache[bleurt_checkpoint_names[self]],
      )
    raise ValueError(f'Cannot compute {self} metric.')
247

248

249
def evaluate(
    *,
    predictions,
    references,
    eval_metrics,
    language,
    bleurt_scorer_cache,
):
  """Computes each requested metric and gathers the values into `Metrics`.

  Args:
    predictions: Model translations, forwarded to `MetricType.compute`.
    references: Gold translations, forwarded to `MetricType.compute`.
    eval_metrics: Iterable of `MetricType` members, computed in order.
    language: Target-language code, forwarded to every metric.
    bleurt_scorer_cache: BLEURT scorer cache, forwarded to every metric.

  Returns:
    A `Metrics` instance. For each requested metric, the field named after it
    (its enum name, lower-cased) holds the computed value. Every other field
    keeps its default.
  """
  # Every metric receives the same inputs; each uses only what it needs.
  common_kwargs = dict(
      predictions=predictions,
      references=references,
      language=language,
      bleurt_scorer_cache=bleurt_scorer_cache,
  )
  result = Metrics()
  for metric_type in eval_metrics:
    metric_value = metric_type.compute(**common_kwargs)
    setattr(result, metric_type.name.lower(), metric_value)
  return result
269

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.