# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The SuperGLUE benchmark metric."""

from sklearn.metrics import f1_score, matthews_corrcoef

import datasets

from .record_evaluation import evaluate as evaluate_record


_CITATION = """\
@article{wang2019superglue,
  title={SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems},
  author={Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R},
  journal={arXiv preprint arXiv:1905.00537},
  year={2019}
}
"""

_DESCRIPTION = """\
SuperGLUE (https://super.gluebenchmark.com/) is a new benchmark styled after
GLUE with a new set of more difficult language understanding tasks, improved
resources, and a new public leaderboard.
"""

_KWARGS_DESCRIPTION = """
Compute the SuperGLUE evaluation metric associated with each SuperGLUE dataset.
Args:
    predictions: list of predictions to score. Depending on the SuperGLUE subset:
        - for 'record': list of question-answer dictionaries with the following keys:
            - 'idx': index of the question as specified by the dataset
            - 'prediction_text': the predicted answer text
        - for 'multirc': list of question-answer dictionaries with the following keys:
            - 'idx': index of the question-answer pair as specified by the dataset
            - 'prediction': the predicted answer label
        - otherwise: list of predicted labels
    references: list of reference labels. Depending on the SuperGLUE subset:
        - for 'record': list of question-answer dictionaries with the following keys:
            - 'idx': index of the question as specified by the dataset
            - 'answers': list of possible answers
        - otherwise: list of reference labels
Returns: depending on the SuperGLUE subset:
    - for 'record':
        - 'exact_match': Exact match between answer and gold answer
        - 'f1': F1 score
    - for 'multirc':
        - 'exact_match': Exact match between answer and gold answer
        - 'f1_m': Per-question macro-F1 score
        - 'f1_a': Average F1 score over all answers
    - for 'axb':
        - 'matthews_correlation': Matthews correlation coefficient
    - for 'cb':
        - 'accuracy': Accuracy
        - 'f1': F1 score
    - for all others:
        - 'accuracy': Accuracy
Examples:

    >>> super_glue_metric = datasets.load_metric('super_glue', 'copa')  # any of ["copa", "rte", "wic", "wsc", "wsc.fixed", "boolq", "axg"]
    >>> predictions = [0, 1]
    >>> references = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0}

    >>> super_glue_metric = datasets.load_metric('super_glue', 'cb')
    >>> predictions = [0, 1]
    >>> references = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0, 'f1': 1.0}

    >>> super_glue_metric = datasets.load_metric('super_glue', 'record')
    >>> predictions = [{'idx': {'passage': 0, 'query': 0}, 'prediction_text': 'answer'}]
    >>> references = [{'idx': {'passage': 0, 'query': 0}, 'answers': ['answer', 'another_answer']}]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact_match': 1.0, 'f1': 1.0}

    >>> super_glue_metric = datasets.load_metric('super_glue', 'multirc')
    >>> predictions = [{'idx': {'answer': 0, 'paragraph': 0, 'question': 0}, 'prediction': 0}, {'idx': {'answer': 1, 'paragraph': 2, 'question': 3}, 'prediction': 1}]
    >>> references = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact_match': 1.0, 'f1_m': 1.0, 'f1_a': 1.0}

    >>> super_glue_metric = datasets.load_metric('super_glue', 'axb')
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'matthews_correlation': 1.0}
"""


def simple_accuracy(preds, labels):
    """Fraction of predictions that exactly match the reference labels."""
    return float((preds == labels).mean())


def acc_and_f1(preds, labels, f1_avg="binary"):
    """Accuracy and F1 score; `f1_avg` is forwarded to sklearn's `average` parameter."""
    acc = simple_accuracy(preds, labels)
    f1 = float(f1_score(y_true=labels, y_pred=preds, average=f1_avg))
    return {
        "accuracy": acc,
        "f1": f1,
    }


def evaluate_multirc(ids_preds, labels):
    """
    Computes F1 score and Exact Match for MultiRC predictions.
    """
    question_map = {}
    for id_pred, label in zip(ids_preds, labels):
        question_id = f'{id_pred["idx"]["paragraph"]}-{id_pred["idx"]["question"]}'
        pred = id_pred["prediction"]
        if question_id in question_map:
            question_map[question_id].append((pred, label))
        else:
            question_map[question_id] = [(pred, label)]
    f1s, ems = [], []
    for question, preds_labels in question_map.items():
        question_preds, question_labels = zip(*preds_labels)
        f1 = f1_score(y_true=question_labels, y_pred=question_preds, average="macro")
        f1s.append(f1)
        em = int(sum(pred == label for pred, label in preds_labels) == len(preds_labels))
        ems.append(em)
    f1_m = float(sum(f1s) / len(f1s))
    em = sum(ems) / len(ems)
    f1_a = float(f1_score(y_true=labels, y_pred=[id_pred["prediction"] for id_pred in ids_preds]))
    return {"exact_match": em, "f1_m": f1_m, "f1_a": f1_a}

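# Illustrative walk-through of `evaluate_multirc` (hypothetical inputs, not part of
# the upstream module): both answers below belong to paragraph 0 / question 0, so
# they are grouped under the key "0-0" and scored together.
#
#     ids_preds = [
#         {"idx": {"paragraph": 0, "question": 0, "answer": 0}, "prediction": 1},
#         {"idx": {"paragraph": 0, "question": 0, "answer": 1}, "prediction": 0},
#     ]
#     labels = [1, 0]
#     evaluate_multirc(ids_preds, labels)
#     # {'exact_match': 1.0, 'f1_m': 1.0, 'f1_a': 1.0}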

@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SuperGlue(datasets.Metric):
    def _info(self):
        if self.config_name not in [
            "boolq",
            "cb",
            "copa",
            "multirc",
            "record",
            "rte",
            "wic",
            "wsc",
            "wsc.fixed",
            "axb",
            "axg",
        ]:
            raise KeyError(
                "You should supply a configuration name selected in "
                '["boolq", "cb", "copa", "multirc", "record", "rte", "wic", "wsc", "wsc.fixed", "axb", "axg",]'
            )
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(self._get_feature_types()),
            codebase_urls=[],
            reference_urls=[],
            # 'record' and 'multirc' take nested dict inputs, so they skip the numpy format
            format="numpy" if not self.config_name == "record" and not self.config_name == "multirc" else None,
        )

    def _get_feature_types(self):
        if self.config_name == "record":
            return {
                "predictions": {
                    "idx": {
                        "passage": datasets.Value("int64"),
                        "query": datasets.Value("int64"),
                    },
                    "prediction_text": datasets.Value("string"),
                },
                "references": {
                    "idx": {
                        "passage": datasets.Value("int64"),
                        "query": datasets.Value("int64"),
                    },
                    "answers": datasets.Sequence(datasets.Value("string")),
                },
            }
        elif self.config_name == "multirc":
            return {
                "predictions": {
                    "idx": {
                        "answer": datasets.Value("int64"),
                        "paragraph": datasets.Value("int64"),
                        "question": datasets.Value("int64"),
                    },
                    "prediction": datasets.Value("int64"),
                },
                "references": datasets.Value("int64"),
            }
        else:
            return {
                "predictions": datasets.Value("int64"),
                "references": datasets.Value("int64"),
            }

    def _compute(self, predictions, references):
        if self.config_name == "axb":
            return {"matthews_correlation": matthews_corrcoef(references, predictions)}
        elif self.config_name == "cb":
            return acc_and_f1(predictions, references, f1_avg="macro")
        elif self.config_name == "record":
            # Reshape the flat reference list into the single-passage, SQuAD-style
            # structure expected by evaluate_record.
            dataset = [
                {
                    "qas": [
                        {"id": ref["idx"]["query"], "answers": [{"text": ans} for ans in ref["answers"]]}
                        for ref in references
                    ]
                }
            ]
            # Map each query index to its predicted answer text.
            predictions = {pred["idx"]["query"]: pred["prediction_text"] for pred in predictions}
            return evaluate_record(dataset, predictions)[0]
        elif self.config_name == "multirc":
            return evaluate_multirc(predictions, references)
        elif self.config_name in ["copa", "rte", "wic", "wsc", "wsc.fixed", "boolq", "axg"]:
            return {"accuracy": simple_accuracy(predictions, references)}
        else:
            raise KeyError(
                "You should supply a configuration name selected in "
                '["boolq", "cb", "copa", "multirc", "record", "rte", "wic", "wsc", "wsc.fixed", "axb", "axg",]'
            )

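# Usage sketch (not part of the metric itself; values are hypothetical): besides the
# docstring examples above, predictions can be accumulated batch by batch with the
# standard `datasets.Metric` API before calling `compute()`.
#
#     metric = datasets.load_metric("super_glue", "rte")
#     for batch_preds, batch_refs in [([0, 1], [0, 1])]:
#         metric.add_batch(predictions=batch_preds, references=batch_refs)
#     print(metric.compute())  # {'accuracy': 1.0}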