# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Parallel BLEU score calculation.

This version of BLEU calculation is derived from the MLPerf transformer
reference. It tries to match the SacreBLEU metric reasonably well, but is
not identical.

Refs:
    tokenizer at:
    https://github.com/tensorflow/models/blob/master/official/transformer/utils/tokenizer.py
    original preprocessing tokenizer:
    https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
    original t2t code:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py

Usage:
    refs = '''food bar brown cow
    blee bloo dog sat
    or please take me out
    '''
    hyps = '''foo bar brown cow
    blee bloo dog sit
    please do take me out
    '''
    bleu_local(refs.split("\n"), hyps.split("\n"))  # 39.65
"""

import collections
import math
import re
import sys
import unicodedata
import numpy as np
import six


class UnicodeRegex(object):
  """Ad-hoc hack to recognize all punctuation and symbols."""

  def __init__(self):
    punctuation = self.property_chars("P")
    self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
    self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
    self.symbol_re = re.compile("([" + self.property_chars("S") + "])")

  def property_chars(self, prefix):
    return "".join(
        six.unichr(x)
        for x in range(sys.maxunicode)
        if unicodedata.category(six.unichr(x)).startswith(prefix))


uregex = UnicodeRegex()
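
# Building each character class above scans the full Unicode range
# (sys.maxunicode code points), so constructing `uregex` costs a second or so,
# but it happens only once at import time and the compiled patterns are
# reused for every tokenization call.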


def bleu_tokenize(string):
  r"""Tokenize a string following the official BLEU implementation.

  See
  https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983

  In our case, the input string is expected to be just one line
  and no HTML entities de-escaping is needed.
  So we just tokenize on punctuation and symbols,
  except when a punctuation mark is both preceded and followed by a digit
  (e.g. a comma/dot used as a thousands/decimal separator).

  Note that a number (e.g. a year) followed by a dot at the end of a sentence
  is NOT tokenized, i.e. the dot stays with the number, because
  `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a
  space after each sentence). However, this error is already in the
  original mteval-v14.pl and we want to be consistent with it.

  Args:
    string: the input string

  Returns:
    a list of tokens
  """
  string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
  string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
  string = uregex.symbol_re.sub(r" \1 ", string)
  return string.split()
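
# Illustrative example (not from the original file), traced through the
# regexes above: punctuation next to letters is split off, digit-internal
# separators are kept, and a sentence-final dot after a number stays attached.
#
#   bleu_tokenize("Hello, world! It's 1,000.5.")
#   # -> ['Hello', ',', 'world', '!', 'It', "'", 's', '1,000.5.']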


def _get_ngrams(segment, max_order):
  """Extracts all n-grams up to a given maximum order from an input segment.

  Args:
    segment: text segment from which n-grams will be extracted.
    max_order: maximum length in tokens of the n-grams returned by this method.

  Returns:
    The Counter containing all n-grams up to max_order in segment
    with a count of how many times each n-gram occurred.
  """
  ngram_counts = collections.Counter()
  for order in range(1, max_order + 1):
    for i in range(0, len(segment) - order + 1):
      ngram = tuple(segment[i:i + order])
      ngram_counts[ngram] += 1
  return ngram_counts
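
# Illustrative example (not from the original file):
#
#   _get_ngrams(["the", "cat", "the"], max_order=2)
#   # -> Counter({('the',): 2, ('cat',): 1,
#   #             ('the', 'cat'): 1, ('cat', 'the'): 1})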


def compute_bleu_matches(reference_corpus, translation_corpus, max_order=4):
  """Computes BLEU match stats of translations against one or more references.

  Args:
    reference_corpus: list of references for each translation. Each reference
      should be tokenized into a list of tokens.
    translation_corpus: list of translations to score. Each translation should
      be tokenized into a list of tokens.
    max_order: Maximum n-gram order to use when computing BLEU score.

  Returns:
    Aggregated n-gram stats for BLEU calculation.
  """
  reference_length = 0
  translation_length = 0

  matches_by_order = [0] * max_order
  possible_matches_by_order = [0] * max_order

  for (references, translations) in zip(reference_corpus, translation_corpus):
    reference_length += len(references)
    translation_length += len(translations)
    ref_ngram_counts = _get_ngrams(references, max_order)
    translation_ngram_counts = _get_ngrams(translations, max_order)

    overlap = dict((ngram, min(count, translation_ngram_counts[ngram]))
                   for ngram, count in ref_ngram_counts.items())

    for ngram in overlap:
      matches_by_order[len(ngram) - 1] += overlap[ngram]
    for ngram in translation_ngram_counts:
      possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
          ngram]

  return (np.array(matches_by_order),
          np.array(possible_matches_by_order),
          np.array(reference_length),
          np.array(translation_length))
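
# Illustrative example (not from the original file): stats for one pair.
#
#   m, p, ref_len, hyp_len = compute_bleu_matches(
#       [["the", "cat", "sat"]], [["the", "cat"]])
#   # m -> [2, 1, 0, 0]   (matched 1-grams, 2-grams, ...)
#   # p -> [2, 1, 0, 0]   (possible matches from the translation)
#   # ref_len -> 3, hyp_len -> 2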


def bleu_partial(ref_lines, hyp_lines, case_sensitive=False):
  """Compute n-gram statistics for two lists of references and translations."""
  if len(ref_lines) != len(hyp_lines):
    raise ValueError("Reference and translation lists have different "
                     "numbers of lines.")
  if not case_sensitive:
    ref_lines = [x.lower() for x in ref_lines]
    hyp_lines = [x.lower() for x in hyp_lines]
  ref_tokens = [bleu_tokenize(x) for x in ref_lines]
  hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
  return compute_bleu_matches(ref_tokens, hyp_tokens)
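
# A sketch of the intended parallel use (an assumption based on the
# partial/complete split, not code from the original file): the returned
# stats are plain arrays, so partial stats from disjoint shards of a corpus
# can be summed elementwise and scored once at the end.
#
#   stats_a = bleu_partial(refs[:100], hyps[:100])
#   stats_b = bleu_partial(refs[100:], hyps[100:])
#   merged = [a + b for a, b in zip(stats_a, stats_b)]
#   score = complete_bleu(*merged) * 100  # equals bleu_local(refs, hyps)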


def complete_bleu(matches_by_order,
                  possible_matches_by_order,
                  reference_length,
                  translation_length,
                  max_order=4,
                  use_bp=True):
  """Compute BLEU score (on a 0-1 scale) from aggregated n-gram statistics."""
  precisions = [0] * max_order
  smooth = 1.0
  geo_mean = 0.0
  for i in range(0, max_order):
    if possible_matches_by_order[i] > 0:
      if matches_by_order[i] > 0:
        precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
      else:
        # Smooth orders with no matches, as in the t2t reference.
        smooth *= 2
        precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
    else:
      precisions[i] = 0.0

  if max(precisions) > 0:
    p_log_sum = sum(math.log(p) for p in precisions if p)
    geo_mean = math.exp(p_log_sum / max_order)

  if use_bp:
    if not reference_length:
      bp = 1.0
    else:
      ratio = translation_length / reference_length
      if ratio <= 0.0:
        bp = 0.0
      elif ratio >= 1.0:
        bp = 1.0
      else:
        bp = math.exp(1 - 1. / ratio)
  else:
    # Without the brevity penalty, the geometric mean is used directly.
    bp = 1.0
  bleu = geo_mean * bp
  return float(bleu)
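
# For reference: the computation above is BLEU = bp * exp((1/N) * sum_n
# log p_n), with brevity penalty bp = 1 if c >= r else exp(1 - r/c), where c
# and r are the total translation and reference lengths, N = max_order, and
# p_n are the modified n-gram precisions (smoothed when a count is zero).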


def bleu_local(ref_lines, hyp_lines, case_sensitive=False):
  """Compute BLEU for two lists of reference and hypothesis translations."""
  stats = bleu_partial(ref_lines, hyp_lines, case_sensitive=case_sensitive)
  return complete_bleu(*stats) * 100