# Copyright 2022 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XTREME-S benchmark metric."""

from typing import List

from packaging import version
from sklearn.metrics import f1_score

import datasets
from datasets.config import PY_VERSION


if PY_VERSION < version.parse("3.8"):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata
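
# `importlib.metadata` entered the standard library in Python 3.8; on older
# interpreters the `importlib_metadata` backport provides the same API.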


# TODO(Patrick/Anton)
_CITATION = """\
"""

_DESCRIPTION = """\
    XTREME-S is a benchmark to evaluate universal cross-lingual speech representations in many languages.
    XTREME-S covers four task families: speech recognition, classification, speech-to-text translation and retrieval.
"""

_KWARGS_DESCRIPTION = """
Compute the XTREME-S evaluation metric associated with each XTREME-S dataset.
Args:
    predictions: list of predictions to score.
        Each translation should be tokenized into a list of tokens.
    references: list of lists of references for each translation.
        Each reference should be tokenized into a list of tokens.
    bleu_kwargs: optional dict of keywords to be passed when computing 'bleu'.
        Keywords can be one of 'smooth_method', 'smooth_value', 'force', 'lowercase',
        'tokenize', 'use_effective_order'.
    wer_kwargs: optional dict of keywords to be passed when computing 'wer' and 'cer'.
        Keywords include 'concatenate_texts'.
Returns: depending on the XTREME-S task, one or several of:
    "accuracy": Accuracy - for 'fleurs-lang_id', 'minds14'
    "f1": F1 score - for 'minds14'
    "wer": Word error rate - for 'mls', 'fleurs-asr', 'voxpopuli', 'babel'
    "cer": Character error rate - for 'mls', 'fleurs-asr', 'voxpopuli', 'babel'
    "bleu": BLEU score according to the `sacrebleu` metric - for 'covost2'
Examples:

    >>> xtreme_s_metric = datasets.load_metric('xtreme_s', 'mls')  # 'mls', 'voxpopuli', 'fleurs-asr' or 'babel'
    >>> references = ["it is sunny here", "paper and pen are essentials"]
    >>> predictions = ["it's sunny", "paper pen are essential"]
    >>> results = xtreme_s_metric.compute(predictions=predictions, references=references)
    >>> print({k: round(v, 2) for k, v in results.items()})
    {'wer': 0.56, 'cer': 0.27}

    >>> xtreme_s_metric = datasets.load_metric('xtreme_s', 'covost2')
    >>> references = ["bonjour paris", "il est necessaire de faire du sport de temps en temp"]
    >>> predictions = ["bonjour paris", "il est important de faire du sport souvent"]
    >>> results = xtreme_s_metric.compute(predictions=predictions, references=references)
    >>> print({k: round(v, 2) for k, v in results.items()})
    {'bleu': 31.65}

    >>> xtreme_s_metric = datasets.load_metric('xtreme_s', 'fleurs-lang_id')
    >>> references = [0, 1, 0, 0, 1]
    >>> predictions = [0, 1, 1, 0, 0]
    >>> results = xtreme_s_metric.compute(predictions=predictions, references=references)
    >>> print({k: round(v, 2) for k, v in results.items()})
    {'accuracy': 0.6}

    >>> xtreme_s_metric = datasets.load_metric('xtreme_s', 'minds14')
    >>> references = [0, 1, 0, 0, 1]
    >>> predictions = [0, 1, 1, 0, 0]
    >>> results = xtreme_s_metric.compute(predictions=predictions, references=references)
    >>> print({k: round(v, 2) for k, v in results.items()})
    {'f1': 0.58, 'accuracy': 0.6}
"""

_CONFIG_NAMES = ["fleurs-asr", "mls", "voxpopuli", "babel", "covost2", "fleurs-lang_id", "minds14"]
SENTENCE_DELIMITER = ""

try:
    from jiwer import transforms as tr

    _jiwer_available = True
except ImportError:
    _jiwer_available = False

if _jiwer_available and version.parse(importlib_metadata.version("jiwer")) < version.parse("2.3.0"):

    class SentencesToListOfCharacters(tr.AbstractTransform):
        def __init__(self, sentence_delimiter: str = " "):
            self.sentence_delimiter = sentence_delimiter

        def process_string(self, s: str):
            return list(s)

        def process_list(self, inp: List[str]):
            chars = []
            for sent_idx, sentence in enumerate(inp):
                chars.extend(self.process_string(sentence))
                if self.sentence_delimiter is not None and self.sentence_delimiter != "" and sent_idx < len(inp) - 1:
                    chars.append(self.sentence_delimiter)
            return chars

    cer_transform = tr.Compose(
        [tr.RemoveMultipleSpaces(), tr.Strip(), SentencesToListOfCharacters(SENTENCE_DELIMITER)]
    )
elif _jiwer_available:
    cer_transform = tr.Compose(
        [
            tr.RemoveMultipleSpaces(),
            tr.Strip(),
            tr.ReduceToSingleSentence(SENTENCE_DELIMITER),
            tr.ReduceToListOfListOfChars(),
        ]
    )
else:
    cer_transform = None
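
# The CER transform reduces each sentence to a list of characters so that jiwer's
# edit-distance machinery scores at the character level. A rough sketch of the
# effect (illustrative only; exact output types differ across jiwer versions):
#
#   cer_transform("hello  world")
#   # -> ["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"]
#
# i.e. repeated spaces are collapsed, the string is stripped, and what remains is
# split into characters, with sentences joined by SENTENCE_DELIMITER ("" here).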


def simple_accuracy(preds, labels):
    return float((preds == labels).mean())


def f1_and_simple_accuracy(preds, labels):
    return {
        "f1": float(f1_score(y_true=labels, y_pred=preds, average="macro")),
        "accuracy": simple_accuracy(preds, labels),
    }
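
# Quick sanity check (a sketch, not part of the metric API; numpy inputs are
# assumed, which `format="numpy"` in `_info` provides when going through
# `compute`). The numbers match the 'minds14' docstring example:
#
#   >>> import numpy as np
#   >>> simple_accuracy(np.array([0, 1, 1, 0, 0]), np.array([0, 1, 0, 0, 1]))
#   0.6
#   >>> f1_and_simple_accuracy(np.array([0, 1, 1, 0, 0]), np.array([0, 1, 0, 0, 1]))
#   {'f1': 0.58..., 'accuracy': 0.6}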


def bleu(
    preds,
    labels,
    smooth_method="exp",
    smooth_value=None,
    force=False,
    lowercase=False,
    tokenize=None,
    use_effective_order=False,
):
    # xtreme-s can only have one label
    labels = [[label] for label in labels]
    preds = list(preds)
    try:
        import sacrebleu as scb
    except ImportError:
        raise ValueError(
            "sacrebleu has to be installed in order to apply the bleu metric for covost2. "
            "You can install it via `pip install sacrebleu`."
        )

    if version.parse(scb.__version__) < version.parse("1.4.12"):
        raise ImportWarning(
            "To use `sacrebleu`, the module `sacrebleu>=1.4.12` is required, and the current version of `sacrebleu` doesn't match this condition.\n"
            'You can install it with `pip install "sacrebleu>=1.4.12"`.'
        )

    references_per_prediction = len(labels[0])
    if any(len(refs) != references_per_prediction for refs in labels):
        raise ValueError("Sacrebleu requires the same number of references for each prediction")
    # sacrebleu expects references grouped per reference stream, not per prediction
    transformed_references = [[refs[i] for refs in labels] for i in range(references_per_prediction)]
    output = scb.corpus_bleu(
        preds,
        transformed_references,
        smooth_method=smooth_method,
        smooth_value=smooth_value,
        force=force,
        lowercase=lowercase,
        use_effective_order=use_effective_order,
        **({"tokenize": tokenize} if tokenize else {}),
    )
    return {"bleu": output.score}
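
# Direct usage sketch (assumes `sacrebleu>=1.4.12` is installed). sacrebleu
# reports BLEU on a 0-100 scale, and very short segments need
# `use_effective_order=True` so that empty higher-order n-gram counts don't
# zero out the score:
#
#   >>> bleu(preds=["bonjour paris"], labels=["bonjour paris"],
#   ...      use_effective_order=True)  # doctest: +SKIP
#   {'bleu': 100.0}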


def wer_and_cer(preds, labels, concatenate_texts, config_name):
    try:
        from jiwer import compute_measures
    except ImportError:
        raise ValueError(
            f"jiwer has to be installed in order to apply the wer metric for {config_name}. "
            "You can install it via `pip install jiwer`."
        )

    if concatenate_texts:
        # Pass the full lists to jiwer at once so error counts are pooled over
        # the whole corpus before dividing.
        wer = compute_measures(labels, preds)["wer"]

        # With character-level transforms, jiwer still reports the score under
        # the "wer" key; it is effectively a CER here.
        cer = compute_measures(labels, preds, truth_transform=cer_transform, hypothesis_transform=cer_transform)["wer"]
        return {"wer": wer, "cer": cer}
    else:

        def compute_score(preds, labels, score_type="wer"):
            # Pool raw edit counts across utterances and divide once at the end,
            # rather than averaging per-utterance rates.
            incorrect = 0
            total = 0
            for prediction, reference in zip(preds, labels):
                if score_type == "wer":
                    measures = compute_measures(reference, prediction)
                elif score_type == "cer":
                    measures = compute_measures(
                        reference, prediction, truth_transform=cer_transform, hypothesis_transform=cer_transform
                    )
                incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
                total += measures["substitutions"] + measures["deletions"] + measures["hits"]
            return incorrect / total

        return {"wer": compute_score(preds, labels, "wer"), "cer": compute_score(preds, labels, "cer")}
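
# Worked example for the per-utterance path, matching the docstring numbers:
# "it is sunny here" vs. "it's sunny" costs 3 edits over 4 reference words
# (e.g. 1 substitution + 2 deletions), and "paper and pen are essentials" vs.
# "paper pen are essential" costs 2 edits over 5 words. Pooled:
# (3 + 2) / (4 + 5) = 5/9 ~= 0.56, the WER quoted in the examples above.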


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class XtremeS(datasets.Metric):
    def _info(self):
        if self.config_name not in _CONFIG_NAMES:
            raise KeyError(f"You should supply a configuration name selected in {_CONFIG_NAMES}")

        pred_type = "int64" if self.config_name in ["fleurs-lang_id", "minds14"] else "string"

        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {"predictions": datasets.Value(pred_type), "references": datasets.Value(pred_type)}
            ),
            codebase_urls=[],
            reference_urls=[],
            format="numpy",
        )

    def _compute(self, predictions, references, bleu_kwargs=None, wer_kwargs=None):
        bleu_kwargs = bleu_kwargs if bleu_kwargs is not None else {}
        wer_kwargs = wer_kwargs if wer_kwargs is not None else {}

        if self.config_name == "fleurs-lang_id":
            return {"accuracy": simple_accuracy(predictions, references)}
        elif self.config_name == "minds14":
            return f1_and_simple_accuracy(predictions, references)
        elif self.config_name == "covost2":
            smooth_method = bleu_kwargs.pop("smooth_method", "exp")
            smooth_value = bleu_kwargs.pop("smooth_value", None)
            force = bleu_kwargs.pop("force", False)
            lowercase = bleu_kwargs.pop("lowercase", False)
            tokenize = bleu_kwargs.pop("tokenize", None)
            use_effective_order = bleu_kwargs.pop("use_effective_order", False)
            return bleu(
                preds=predictions,
                labels=references,
                smooth_method=smooth_method,
                smooth_value=smooth_value,
                force=force,
                lowercase=lowercase,
                tokenize=tokenize,
                use_effective_order=use_effective_order,
            )
        elif self.config_name in ["fleurs-asr", "mls", "voxpopuli", "babel"]:
            concatenate_texts = wer_kwargs.pop("concatenate_texts", False)
            return wer_and_cer(predictions, references, concatenate_texts, self.config_name)
        else:
            raise KeyError(f"You should supply a configuration name selected in {_CONFIG_NAMES}")
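
# The `wer_kwargs` path is not exercised by the docstring examples; a hedged
# sketch (`preds`/`refs` are placeholder lists of strings):
#
#   >>> metric = datasets.load_metric("xtreme_s", "mls")  # doctest: +SKIP
#   >>> metric.compute(predictions=preds, references=refs,
#   ...                wer_kwargs={"concatenate_texts": True})  # doctest: +SKIP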