# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The SuperGLUE benchmark metric."""

from sklearn.metrics import f1_score, matthews_corrcoef

import datasets

from .record_evaluation import evaluate as evaluate_record


_CITATION = """\
@article{wang2019superglue,
  title={SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems},
  author={Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R},
  journal={arXiv preprint arXiv:1905.00537},
  year={2019}
}
"""

_DESCRIPTION = """\
SuperGLUE (https://super.gluebenchmark.com/) is a new benchmark styled after
GLUE with a new set of more difficult language understanding tasks, improved
resources, and a new public leaderboard.
"""

_KWARGS_DESCRIPTION = """
Compute the SuperGLUE evaluation metric associated with each SuperGLUE dataset.
Args:
    predictions: list of predictions to score. Depending on the SuperGLUE subset:
        - for 'record': list of question-answer dictionaries with the following keys:
            - 'idx': index of the question as specified by the dataset
            - 'prediction_text': the predicted answer text
        - for 'multirc': list of question-answer dictionaries with the following keys:
            - 'idx': index of the question-answer pair as specified by the dataset
            - 'prediction': the predicted answer label
        - otherwise: list of predicted labels
    references: list of reference labels. Depending on the SuperGLUE subset:
        - for 'record': list of question-answer dictionaries with the following keys:
            - 'idx': index of the question as specified by the dataset
            - 'answers': list of possible answers
        - otherwise: list of reference labels
Returns: depending on the SuperGLUE subset:
    - for 'record':
        - 'exact_match': Exact match between answer and gold answer
        - 'f1': F1 score
    - for 'multirc':
        - 'exact_match': Exact match between answer and gold answer
        - 'f1_m': Per-question macro-F1 score
        - 'f1_a': Average F1 score over all answers
    - for 'axb':
        - 'matthews_correlation': Matthews correlation
    - for 'cb':
        - 'accuracy': Accuracy
        - 'f1': F1 score
    - for all others:
        - 'accuracy': Accuracy
Examples:

    >>> super_glue_metric = datasets.load_metric('super_glue', 'copa')  # any of ["copa", "rte", "wic", "wsc", "wsc.fixed", "boolq", "axg"]
    >>> predictions = [0, 1]
    >>> references = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0}

    >>> super_glue_metric = datasets.load_metric('super_glue', 'cb')
    >>> predictions = [0, 1]
    >>> references = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0, 'f1': 1.0}

    >>> super_glue_metric = datasets.load_metric('super_glue', 'record')
    >>> predictions = [{'idx': {'passage': 0, 'query': 0}, 'prediction_text': 'answer'}]
    >>> references = [{'idx': {'passage': 0, 'query': 0}, 'answers': ['answer', 'another_answer']}]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact_match': 1.0, 'f1': 1.0}

    >>> super_glue_metric = datasets.load_metric('super_glue', 'multirc')
    >>> predictions = [{'idx': {'answer': 0, 'paragraph': 0, 'question': 0}, 'prediction': 0}, {'idx': {'answer': 1, 'paragraph': 2, 'question': 3}, 'prediction': 1}]
    >>> references = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact_match': 1.0, 'f1_m': 1.0, 'f1_a': 1.0}

    >>> super_glue_metric = datasets.load_metric('super_glue', 'axb')
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'matthews_correlation': 1.0}
"""


def simple_accuracy(preds, labels):
    return float((preds == labels).mean())
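# Note on simple_accuracy: for the configs that declare format="numpy" in
# _info() below, `preds` and `labels` arrive as NumPy arrays, so `==`
# compares elementwise; e.g. with `import numpy as np`,
# simple_accuracy(np.array([0, 1, 1]), np.array([0, 1, 0])) returns ~0.667.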


def acc_and_f1(preds, labels, f1_avg="binary"):
    acc = simple_accuracy(preds, labels)
    f1 = float(f1_score(y_true=labels, y_pred=preds, average=f1_avg))
    return {
        "accuracy": acc,
        "f1": f1,
    }
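
# Worked example (hypothetical label vectors over the three CB classes; for
# the 'cb' config, _compute below calls acc_and_f1 with f1_avg="macro"):
#
#     import numpy as np
#     acc_and_f1(np.array([0, 1, 2, 2]), np.array([0, 1, 2, 1]), f1_avg="macro")
#     # -> {'accuracy': 0.75, 'f1': 0.777...}  (macro mean of per-class F1s 1.0, 2/3, 2/3)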


def evaluate_multirc(ids_preds, labels):
    """
    Computes F1 score and Exact Match for MultiRC predictions.
    """
    question_map = {}
    for id_pred, label in zip(ids_preds, labels):
        question_id = f'{id_pred["idx"]["paragraph"]}-{id_pred["idx"]["question"]}'
        pred = id_pred["prediction"]
        if question_id in question_map:
            question_map[question_id].append((pred, label))
        else:
            question_map[question_id] = [(pred, label)]
    f1s, ems = [], []
    for question, preds_labels in question_map.items():
        question_preds, question_labels = zip(*preds_labels)
        f1 = f1_score(y_true=question_labels, y_pred=question_preds, average="macro")
        f1s.append(f1)
        em = int(sum(pred == label for pred, label in preds_labels) == len(preds_labels))
        ems.append(em)
    f1_m = float(sum(f1s) / len(f1s))
    em = sum(ems) / len(ems)
    f1_a = float(f1_score(y_true=labels, y_pred=[id_pred["prediction"] for id_pred in ids_preds]))
    return {"exact_match": em, "f1_m": f1_m, "f1_a": f1_a}
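
# Worked example (hypothetical inputs, shaped like the docstring examples):
# question (0, 0) has both of its answer candidates classified correctly and
# question (0, 1) has both classified incorrectly, so exact match and the
# per-question macro F1 both come out at 0.5:
#
#     ids_preds = [
#         {"idx": {"paragraph": 0, "question": 0, "answer": 0}, "prediction": 1},
#         {"idx": {"paragraph": 0, "question": 0, "answer": 1}, "prediction": 0},
#         {"idx": {"paragraph": 0, "question": 1, "answer": 0}, "prediction": 1},
#         {"idx": {"paragraph": 0, "question": 1, "answer": 1}, "prediction": 0},
#     ]
#     labels = [1, 0, 0, 1]
#     evaluate_multirc(ids_preds, labels)
#     # -> {'exact_match': 0.5, 'f1_m': 0.5, 'f1_a': 0.5}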


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SuperGlue(datasets.Metric):
    def _info(self):
        if self.config_name not in [
            "boolq",
            "cb",
            "copa",
            "multirc",
            "record",
            "rte",
            "wic",
            "wsc",
            "wsc.fixed",
            "axb",
            "axg",
        ]:
            raise KeyError(
                "You should supply a configuration name selected in "
                '["boolq", "cb", "copa", "multirc", "record", "rte", "wic", "wsc", "wsc.fixed", "axb", "axg"]'
            )
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(self._get_feature_types()),
            codebase_urls=[],
            reference_urls=[],
            format="numpy" if self.config_name not in ["record", "multirc"] else None,
        )

    def _get_feature_types(self):
        if self.config_name == "record":
            return {
                "predictions": {
                    "idx": {
                        "passage": datasets.Value("int64"),
                        "query": datasets.Value("int64"),
                    },
                    "prediction_text": datasets.Value("string"),
                },
                "references": {
                    "idx": {
                        "passage": datasets.Value("int64"),
                        "query": datasets.Value("int64"),
                    },
                    "answers": datasets.Sequence(datasets.Value("string")),
                },
            }
        elif self.config_name == "multirc":
            return {
                "predictions": {
                    "idx": {
                        "answer": datasets.Value("int64"),
                        "paragraph": datasets.Value("int64"),
                        "question": datasets.Value("int64"),
                    },
                    "prediction": datasets.Value("int64"),
                },
                "references": datasets.Value("int64"),
            }
        else:
            return {
                "predictions": datasets.Value("int64"),
                "references": datasets.Value("int64"),
            }

    def _compute(self, predictions, references):
        if self.config_name == "axb":
            return {"matthews_correlation": matthews_corrcoef(references, predictions)}
        elif self.config_name == "cb":
            return acc_and_f1(predictions, references, f1_avg="macro")
        elif self.config_name == "record":
            dataset = [
                {
                    "qas": [
                        {"id": ref["idx"]["query"], "answers": [{"text": ans} for ans in ref["answers"]]}
                        for ref in references
                    ]
                }
            ]
            predictions = {pred["idx"]["query"]: pred["prediction_text"] for pred in predictions}
            return evaluate_record(dataset, predictions)[0]
        elif self.config_name == "multirc":
            return evaluate_multirc(predictions, references)
        elif self.config_name in ["copa", "rte", "wic", "wsc", "wsc.fixed", "boolq", "axg"]:
            return {"accuracy": simple_accuracy(predictions, references)}
        else:
            raise KeyError(
                "You should supply a configuration name selected in "
                '["boolq", "cb", "copa", "multirc", "record", "rte", "wic", "wsc", "wsc.fixed", "axb", "axg"]'
            )
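

# Minimal usage sketch (assuming the `datasets` library resolves 'super_glue'
# to this script, as in the docstring examples above; `add_batch` and
# `compute` are the standard datasets.Metric entry points):
#
#     import datasets
#     metric = datasets.load_metric("super_glue", "cb")
#     metric.add_batch(predictions=[0, 1, 1], references=[0, 1, 2])
#     print(metric.compute())  # accuracy plus macro F1 for the CB subset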