1
# Copyright 2020 The HuggingFace Datasets Authors.
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
7
# http://www.apache.org/licenses/LICENSE-2.0
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""GLUE benchmark metric."""
16
from scipy.stats import pearsonr, spearmanr
17
from sklearn.metrics import f1_score, matthews_corrcoef
23
@inproceedings{wang2019glue,
24
title={{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding},
25
author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.},
26
note={In the Proceedings of ICLR.},
32
GLUE, the General Language Understanding Evaluation benchmark
33
(https://gluebenchmark.com/) is a collection of resources for training,
34
evaluating, and analyzing natural language understanding systems.
37
_KWARGS_DESCRIPTION = """
38
Compute GLUE evaluation metric associated to each GLUE dataset.
40
predictions: list of predictions to score.
41
Each prediction should be an integer class label (a float score for the 'stsb' subset).
42
references: list of ground-truth references, one per prediction.
43
Each reference should be an integer class label (a float score for the 'stsb' subset).
44
Returns: depending on the GLUE subset, one or several of:
47
"pearson": Pearson Correlation
48
"spearmanr": Spearman Correlation
49
"matthews_correlation": Matthews Correlation
52
>>> glue_metric = datasets.load_metric('glue', 'sst2') # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]
53
>>> references = [0, 1]
54
>>> predictions = [0, 1]
55
>>> results = glue_metric.compute(predictions=predictions, references=references)
59
>>> glue_metric = datasets.load_metric('glue', 'mrpc') # 'mrpc' or 'qqp'
60
>>> references = [0, 1]
61
>>> predictions = [0, 1]
62
>>> results = glue_metric.compute(predictions=predictions, references=references)
64
{'accuracy': 1.0, 'f1': 1.0}
66
>>> glue_metric = datasets.load_metric('glue', 'stsb')
67
>>> references = [0., 1., 2., 3., 4., 5.]
68
>>> predictions = [0., 1., 2., 3., 4., 5.]
69
>>> results = glue_metric.compute(predictions=predictions, references=references)
70
>>> print({"pearson": round(results["pearson"], 2), "spearmanr": round(results["spearmanr"], 2)})
71
{'pearson': 1.0, 'spearmanr': 1.0}
73
>>> glue_metric = datasets.load_metric('glue', 'cola')
74
>>> references = [0, 1]
75
>>> predictions = [0, 1]
76
>>> results = glue_metric.compute(predictions=predictions, references=references)
78
{'matthews_correlation': 1.0}
82
def simple_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    Args:
        preds: predicted labels; a numpy array or any array-like sequence.
        labels: reference labels, aligned element-wise with ``preds``.

    Returns:
        The accuracy as a Python ``float``. For empty inputs the result is
        ``nan`` (numpy emits a RuntimeWarning for the empty mean).

    Raises:
        ValueError: if ``preds`` and ``labels`` do not have the same shape.
    """
    import numpy as np  # local import: numpy is not among the visible module imports

    # Coerce to arrays so plain lists/tuples also work: with lists,
    # ``preds == labels`` is a single bool and ``.mean()`` would fail.
    preds_arr = np.asarray(preds)
    labels_arr = np.asarray(labels)
    # Without this check, older numpy versions collapse a mismatched
    # comparison to a scalar ``False``, silently yielding an accuracy of 0.0.
    if preds_arr.shape != labels_arr.shape:
        raise ValueError(
            f"preds and labels must have the same shape, got {preds_arr.shape} and {labels_arr.shape}"
        )
    return float((preds_arr == labels_arr).mean())
86
def acc_and_f1(preds, labels):
87
acc = simple_accuracy(preds, labels)
88
f1 = float(f1_score(y_true=labels, y_pred=preds))
95
def pearson_and_spearman(preds, labels):
96
pearson_corr = float(pearsonr(preds, labels)[0])
97
spearman_corr = float(spearmanr(preds, labels)[0])
99
"pearson": pearson_corr,
100
"spearmanr": spearman_corr,
104
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
105
class Glue(datasets.Metric):
107
if self.config_name not in [
122
"You should supply a configuration name selected in "
123
'["sst2", "mnli", "mnli_mismatched", "mnli_matched", '
124
'"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]'
126
return datasets.MetricInfo(
127
description=_DESCRIPTION,
129
inputs_description=_KWARGS_DESCRIPTION,
130
features=datasets.Features(
132
"predictions": datasets.Value("int64" if self.config_name != "stsb" else "float32"),
133
"references": datasets.Value("int64" if self.config_name != "stsb" else "float32"),
141
def _compute(self, predictions, references):
142
if self.config_name == "cola":
143
return {"matthews_correlation": matthews_corrcoef(references, predictions)}
144
elif self.config_name == "stsb":
145
return pearson_and_spearman(predictions, references)
146
elif self.config_name in ["mrpc", "qqp"]:
147
return acc_and_f1(predictions, references)
148
elif self.config_name in ["sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]:
149
return {"accuracy": simple_accuracy(predictions, references)}
152
"You should supply a configuration name selected in "
153
'["sst2", "mnli", "mnli_mismatched", "mnli_matched", '
154
'"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]'