CSS-LM

Форк
0
85 строк · 2.9 Кб
1
# coding=utf-8
2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16

17
try:
18
    from scipy.stats import pearsonr, spearmanr
19
    from sklearn.metrics import matthews_corrcoef, f1_score
20

21
    _has_sklearn = True
22
except (AttributeError, ImportError):
23
    _has_sklearn = False
24

25

26
def is_sklearn_available():
27
    return _has_sklearn
28

29

30
if _has_sklearn:
31

32
    def simple_accuracy(preds, labels):
33
        return (preds == labels).mean()
34

35
    def acc_and_f1(preds, labels):
36
        acc = simple_accuracy(preds, labels)
37
        f1 = f1_score(y_true=labels, y_pred=preds)
38
        return {
39
            "acc": acc,
40
            "f1": f1,
41
            "acc_and_f1": (acc + f1) / 2,
42
        }
43

44
    def pearson_and_spearman(preds, labels):
45
        pearson_corr = pearsonr(preds, labels)[0]
46
        spearman_corr = spearmanr(preds, labels)[0]
47
        return {
48
            "pearson": pearson_corr,
49
            "spearmanr": spearman_corr,
50
            "corr": (pearson_corr + spearman_corr) / 2,
51
        }
52

53
    def glue_compute_metrics(task_name, preds, labels):
54
        assert len(preds) == len(labels)
55
        if task_name == "cola":
56
            return {"mcc": matthews_corrcoef(labels, preds)}
57
        elif task_name == "sst-2":
58
            return {"acc": simple_accuracy(preds, labels)}
59
        elif task_name == "mrpc":
60
            return acc_and_f1(preds, labels)
61
        elif task_name == "sts-b":
62
            return pearson_and_spearman(preds, labels)
63
        elif task_name == "qqp":
64
            return acc_and_f1(preds, labels)
65
        elif task_name == "mnli":
66
            return {"mnli/acc": simple_accuracy(preds, labels)}
67
        elif task_name == "mnli-mm":
68
            return {"mnli-mm/acc": simple_accuracy(preds, labels)}
69
        elif task_name == "qnli":
70
            return {"acc": simple_accuracy(preds, labels)}
71
        elif task_name == "rte":
72
            return {"acc": simple_accuracy(preds, labels)}
73
        elif task_name == "wnli":
74
            return {"acc": simple_accuracy(preds, labels)}
75
        elif task_name == "hans":
76
            return {"acc": simple_accuracy(preds, labels)}
77
        else:
78
            raise KeyError(task_name)
79

80
    def xnli_compute_metrics(task_name, preds, labels):
81
        assert len(preds) == len(labels)
82
        if task_name == "xnli":
83
            return {"acc": simple_accuracy(preds, labels)}
84
        else:
85
            raise KeyError(task_name)
86

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.