# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLUE benchmark metric."""

from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef

import datasets

_CITATION = """\
@inproceedings{wang2019glue,
  title={{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding},
  author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.},
  note={In the Proceedings of ICLR.},
  year={2019}
}
"""

_DESCRIPTION = """\
GLUE, the General Language Understanding Evaluation benchmark
(https://gluebenchmark.com/) is a collection of resources for training,
evaluating, and analyzing natural language understanding systems.
"""

_KWARGS_DESCRIPTION = """
Compute the GLUE evaluation metric associated with each GLUE dataset.
Args:
    predictions: list of predictions to score.
        Each prediction is an integer class label (or a float score for "stsb").
    references: list of reference labels, one per prediction.
        Each reference is an integer class label (or a float score for "stsb").
Returns: depending on the GLUE subset, one or several of:
    "accuracy": Accuracy
    "f1": F1 score
    "pearson": Pearson Correlation
    "spearmanr": Spearman Correlation
    "matthews_correlation": Matthews Correlation
Examples:

    >>> glue_metric = datasets.load_metric('glue', 'sst2')  # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0}

    >>> glue_metric = datasets.load_metric('glue', 'mrpc')  # 'mrpc' or 'qqp'
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0, 'f1': 1.0}

    >>> glue_metric = datasets.load_metric('glue', 'stsb')
    >>> references = [0., 1., 2., 3., 4., 5.]
    >>> predictions = [0., 1., 2., 3., 4., 5.]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print({"pearson": round(results["pearson"], 2), "spearmanr": round(results["spearmanr"], 2)})
    {'pearson': 1.0, 'spearmanr': 1.0}

    >>> glue_metric = datasets.load_metric('glue', 'cola')
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'matthews_correlation': 1.0}
"""


def simple_accuracy(preds, labels):
    """Fraction of predictions that exactly match the reference labels."""
    return float((preds == labels).mean())
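# `preds == labels` compares elementwise because inputs arrive as numpy arrays
# (the metric declares format="numpy" below). An illustrative call, assuming
# numpy is imported as np:
#   simple_accuracy(np.array([0, 1, 1]), np.array([0, 1, 0]))  # -> 0.666...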


def acc_and_f1(preds, labels):
    """Accuracy together with the binary F1 score."""
    acc = simple_accuracy(preds, labels)
    f1 = float(f1_score(y_true=labels, y_pred=preds))
    return {
        "accuracy": acc,
        "f1": f1,
    }
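# `f1_score` is called with its default average="binary", which matches the
# two-class labels of the subsets routed here (mrpc and qqp). An illustrative
# call, assuming numpy is imported as np:
#   acc_and_f1(np.array([0, 1, 1]), np.array([0, 1, 0]))
#   # -> {'accuracy': 0.666..., 'f1': 0.666...}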


def pearson_and_spearman(preds, labels):
    """Pearson and Spearman correlations between predictions and labels."""
    pearson_corr = float(pearsonr(preds, labels)[0])
    spearman_corr = float(spearmanr(preds, labels)[0])
    return {
        "pearson": pearson_corr,
        "spearmanr": spearman_corr,
    }
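# Both scipy functions return a (statistic, p-value) pair; indexing with [0]
# keeps only the correlation coefficient. An illustrative call, assuming numpy
# is imported as np:
#   pearson_and_spearman(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.5]))
#   # -> {'pearson': 0.99..., 'spearmanr': 1.0}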


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Glue(datasets.Metric):
    def _info(self):
        if self.config_name not in [
            "sst2",
            "mnli",
            "mnli_mismatched",
            "mnli_matched",
            "cola",
            "stsb",
            "mrpc",
            "qqp",
            "qnli",
            "rte",
            "wnli",
            "hans",
        ]:
            raise KeyError(
                "You should supply a configuration name selected in "
                '["sst2", "mnli", "mnli_mismatched", "mnli_matched", '
                '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]'
            )
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # stsb is a regression task scored with continuous similarity
            # values; every other subset uses integer class labels.
            features=datasets.Features(
                {
                    "predictions": datasets.Value("int64" if self.config_name != "stsb" else "float32"),
                    "references": datasets.Value("int64" if self.config_name != "stsb" else "float32"),
                }
            ),
            codebase_urls=[],
            reference_urls=[],
            format="numpy",
        )

    def _compute(self, predictions, references):
        if self.config_name == "cola":
            return {"matthews_correlation": matthews_corrcoef(references, predictions)}
        elif self.config_name == "stsb":
            return pearson_and_spearman(predictions, references)
        elif self.config_name in ["mrpc", "qqp"]:
            return acc_and_f1(predictions, references)
        elif self.config_name in ["sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]:
            return {"accuracy": simple_accuracy(predictions, references)}
        else:
            raise KeyError(
                "You should supply a configuration name selected in "
                '["sst2", "mnli", "mnli_mismatched", "mnli_matched", '
                '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]'
            )
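

# A minimal usage sketch mirroring the doctest examples above. It assumes this
# file is resolvable by `datasets.load_metric` under the name "glue" (as in the
# datasets metrics registry); it is not part of the metric's public API.
if __name__ == "__main__":
    glue_metric = datasets.load_metric("glue", "mrpc")
    results = glue_metric.compute(predictions=[0, 1], references=[0, 1])
    print(results)  # {'accuracy': 1.0, 'f1': 1.0}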