# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Parallel BLEU score calculation.

This version of the BLEU calculation is derived from the MLPerf transformer
reference. It tries to match the SacreBLEU metric reasonably well but is not
identical.

Refs:
tokenizer at:
https://github.com/tensorflow/models/blob/master/official/transformer/utils/tokenizer.py
original preprocessing tokenizer:
https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
original t2t code:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py

Usage:
refs = '''food bar brown cow
blee bloo dog sat
or please take me out
'''
hyps = '''foo bar brown cow
blee bloo dog sit
please do take me out
'''
bleu_local(refs.split("\n"), hyps.split("\n"))  # 39.65
"""

import collections
import math
import re
import sys
import unicodedata

import numpy as np
import six


class UnicodeRegex(object):
  """Ad-hoc hack to recognize all punctuation and symbols."""

  def __init__(self):
    punctuation = self.property_chars("P")
    self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
    self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
    self.symbol_re = re.compile("([" + self.property_chars("S") + "])")

  def property_chars(self, prefix):
    return "".join(
        six.unichr(x)
        for x in range(sys.maxunicode)
        if unicodedata.category(six.unichr(x)).startswith(prefix))


uregex = UnicodeRegex()


def bleu_tokenize(string):
  r"""Tokenize a string following the official BLEU implementation.

  See
  https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
  In our case, the input string is expected to be just one line
  and no HTML entity de-escaping is needed.
  So we just tokenize on punctuation and symbols,
  except when punctuation is preceded and followed by a digit
  (e.g. a comma/dot as a thousands/decimal separator).

  Note that a number (e.g. a year) followed by a dot at the end of a sentence
  is NOT tokenized, i.e. the dot stays with the number, because
  `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a
  space after each sentence). However, this error is already in the
  original mteval-v14.pl and we want to be consistent with it.

  Args:
    string: the input string

  Returns:
    a list of tokens
  """
  string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
  string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
  string = uregex.symbol_re.sub(r" \1 ", string)
  return string.split()
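

# Example: punctuation between digits stays attached, everything else is
# split off:
#   bleu_tokenize("Hello, world! It cost 1,000.")
#   => ['Hello', ',', 'world', '!', 'It', 'cost', '1,000.']
# The trailing dot stays with "1,000." per the mteval-v14.pl quirk noted in
# the docstring.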


def _get_ngrams(segment, max_order):
  """Extracts all n-grams up to a given maximum order from an input segment.

  Args:
    segment: text segment from which n-grams will be extracted.
    max_order: maximum length in tokens of the n-grams returned by this method.

  Returns:
    The Counter containing all n-grams up to max_order in segment
    with a count of how many times each n-gram occurred.
  """
  ngram_counts = collections.Counter()
  for order in range(1, max_order + 1):
    for i in range(0, len(segment) - order + 1):
      ngram = tuple(segment[i:i + order])
      ngram_counts[ngram] += 1
  return ngram_counts
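

# Example:
#   _get_ngrams(["a", "b", "a"], max_order=2)
#   => Counter({('a',): 2, ('b',): 1, ('a', 'b'): 1, ('b', 'a'): 1})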


def compute_bleu_matches(reference_corpus, translation_corpus, max_order=4):
  """Computes BLEU match stats of translations against one or more references.

  Args:
    reference_corpus: list of references for each translation. Each reference
      should be tokenized into a list of tokens.
    translation_corpus: list of translations to score. Each translation should
      be tokenized into a list of tokens.
    max_order: Maximum n-gram order to use when computing BLEU score.

  Returns:
    Aggregated n-gram stats for BLEU calculation.
  """
  reference_length = 0
  translation_length = 0

  matches_by_order = [0] * max_order
  possible_matches_by_order = [0] * max_order

  for (references, translations) in zip(reference_corpus, translation_corpus):
    reference_length += len(references)
    translation_length += len(translations)
    ref_ngram_counts = _get_ngrams(references, max_order)
    translation_ngram_counts = _get_ngrams(translations, max_order)

    overlap = dict((ngram, min(count, translation_ngram_counts[ngram]))
                   for ngram, count in ref_ngram_counts.items())

    for ngram in overlap:
      matches_by_order[len(ngram) - 1] += overlap[ngram]
    for ngram in translation_ngram_counts:
      possible_matches_by_order[len(ngram) - 1] += (
          translation_ngram_counts[ngram])

  return (np.array(matches_by_order),
          np.array(possible_matches_by_order),
          np.array(reference_length),
          np.array(translation_length))
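

# The stats come back as plain count arrays so that, in a parallel or sharded
# setting, per-shard stats can be summed elementwise before computing the
# final score. A sketch, assuming `shard_stats` is a (hypothetical) list of
# tuples returned by compute_bleu_matches on each shard:
#   totals = [np.sum(np.stack(s), axis=0) for s in zip(*shard_stats)]
#   score = complete_bleu(*totals)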


def bleu_partial(ref_lines, hyp_lines, case_sensitive=False):
  """Compute n-gram statistics for two lists of references and translations."""
  if len(ref_lines) != len(hyp_lines):
    raise ValueError("Reference and translation lists have different "
                     "numbers of lines.")
  if not case_sensitive:
    ref_lines = [x.lower() for x in ref_lines]
    hyp_lines = [x.lower() for x in hyp_lines]
  ref_tokens = [bleu_tokenize(x) for x in ref_lines]
  hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
  return compute_bleu_matches(ref_tokens, hyp_tokens)
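

# Example: bleu_partial is the "map" half and complete_bleu (below) the
# "reduce" half of the computation:
#   stats = bleu_partial(["food bar brown cow"], ["foo bar brown cow"])
#   score = complete_bleu(*stats)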


def complete_bleu(matches_by_order,
                  possible_matches_by_order,
                  reference_length,
                  translation_length,
                  max_order=4,
                  use_bp=True):
  """Compute BLEU score from aggregated n-gram statistics."""
  precisions = [0] * max_order
  smooth = 1.0
  geo_mean = 0.0
  for i in range(0, max_order):
    if possible_matches_by_order[i] > 0:
      if matches_by_order[i] > 0:
        precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
      else:
        smooth *= 2
        precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
    else:
      precisions[i] = 0.0

  if max(precisions) > 0:
    p_log_sum = sum(math.log(p) for p in precisions if p)
    geo_mean = math.exp(p_log_sum / max_order)

  if use_bp:
    if not reference_length:
      bp = 1.0
    else:
      ratio = translation_length / reference_length
      if ratio <= 0.0:
        bp = 0.0
      elif ratio >= 1.0:
        bp = 1.0
      else:
        bp = math.exp(1 - 1. / ratio)
  else:
    # Without the brevity penalty the score is just the geometric mean of the
    # n-gram precisions.
    bp = 1.0
  bleu = geo_mean * bp
  return float(bleu) * 100.0
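

# Example: the brevity penalty only kicks in when the translation is shorter
# than the reference. A translation half as long as its reference gets
#   bp = math.exp(1 - 1. / 0.5)  # = exp(-1) ~= 0.368
# which scales down the geometric mean of the n-gram precisions accordingly.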


def bleu_local(ref_lines, hyp_lines, case_sensitive=False):
  """Compute BLEU for two lists of reference and hypothesis translations."""
  stats = bleu_partial(ref_lines, hyp_lines, case_sensitive=case_sensitive)
  # complete_bleu already scales the score to the 0-100 range, so no further
  # scaling is applied here.
  return complete_bleu(*stats)
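

# Minimal smoke test, reusing the example from the module docstring.
if __name__ == "__main__":
  refs = '''food bar brown cow
blee bloo dog sat
or please take me out
'''
  hyps = '''foo bar brown cow
blee bloo dog sit
please do take me out
'''
  print(bleu_local(refs.split("\n"), hyps.split("\n")))  # Expected: ~39.65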