from sklearn.metrics import f1_score, matthews_corrcoef
import numpy as np
from rouge import Rouge
from src.utils import qa_utils
from datasets import load_metric
import re


class App:
    """Minimal registry mapping metric names to their scoring functions."""

    def __init__(self):
        self.functions = {}

    def add(self, key):
        # Decorator factory: registers `func` under `key` and returns it unchanged.
        def adder(func):
            self.functions[key] = func
            return func

        return adder

    def __getitem__(self, __name: str):
        return self.functions[__name]


metric_dict = App()
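
# Registry usage sketch (illustrative; keys must match the @metric_dict.add names below):
#     scorer = metric_dict["simple_accuracy"]
#     scorer(["a", "b"], ["a", "c"], return_list=True)  # -> [1, 0]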
@metric_dict.add("rouge")
27
def rouge(preds, labels, return_list=False):
28
    # https://github.com/pltrdy/rouge
29
    r1s, r2s, rls = [], [], []
30
    r = Rouge()
31
    for i in range(len(labels)):
32
        if "\n" not in preds[i]:
33
            preds[i] += "\n"  # to ensure rouge metrics
34
        if "\n" not in labels[i]:
35
            labels[i] += "\n"
36
        scores = r.get_scores(preds[i], labels[i])[0]
37
        r1s.append(scores["rouge-1"]["f"])
38
        r2s.append(scores["rouge-2"]["f"])
39
        rls.append(scores["rouge-l"]["f"])
40
    if return_list:  # used for scoring data
41
        return r1s
42
    r1 = sum(r1s) / len(r1s)
43
    r2 = sum(r2s) / len(r2s)
44
    rl = sum(rls) / len(rls)
45
    return r1, r2, rl
46
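
# Call sketch for rouge() (strings are illustrative). Note that rouge() appends a
# trailing newline to preds/labels in place before scoring:
#     r1, r2, rl = rouge(["the cat sat on the mat"], ["the cat sat"])
#     per_example = rouge(["the cat sat on the mat"], ["the cat sat"], return_list=True)  # ROUGE-1 F1 per example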


@metric_dict.add("squad")
def squad(labels, preds, return_list=False):
    """Computes SQuAD metrics, maximizing over answers per question.

    Args:
        labels: list of lists of strings (multiple gold answers per question)
        preds: list of strings
    Returns:
        (em, f1) averaged across all labels and predictions, or the per-example
        F1 list when return_list=True.
    """
    labels = [[qa_utils.normalize_squad(t) for t in u] for u in labels]
    preds = [qa_utils.normalize_squad(p) for p in preds]
    if return_list:  # used for scoring data
        em, f1 = qa_utils.qa_metrics(labels, preds, return_list=True)
        return f1
    em, f1 = qa_utils.qa_metrics(labels, preds)
    return em, f1
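
# Call sketch for squad() (answers are illustrative); each entry of `labels` holds
# all acceptable gold answers for one question:
#     em, f1 = squad(labels=[["Paris", "paris, France"]], preds=["Paris"])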


@metric_dict.add("trivia_qa")
def trivia_qa(labels, preds, return_list=False):
    """Computes TriviaQA metrics, maximizing over answers per question.

    Args:
        labels: list of lists of strings (multiple gold answers per question)
        preds: list of strings
    Returns:
        (em, f1) averaged across all labels and predictions, or the per-example
        F1 list when return_list=True.
    """
    labels = [[qa_utils.normalize_trivia_qa(t) for t in u] for u in labels]
    preds = [qa_utils.normalize_trivia_qa(p) for p in preds]
    if return_list:  # used for scoring data
        em, f1 = qa_utils.qa_metrics(labels, preds, return_list=True)
        return f1
    em, f1 = qa_utils.qa_metrics(labels, preds)
    return em, f1


@metric_dict.add("simple_accuracy")
def simple_accuracy(preds, labels, return_list=False):
    if isinstance(preds[0], str):
        labels = [label.strip() for label in labels]
        preds = [pred.strip() for pred in preds]
    res = [int(preds[i] == labels[i]) for i in range(len(preds))]
    if return_list:
        return res
    acc = sum(res) / len(res)
    return acc


@metric_dict.add("pubmed_qa_acc")
def pubmed_qa_acc(preds, labels, return_list=False):
    # Extract the text after the final "the answer is" and compare it with the label.
    pattern = r"([.\s]*)(the answer is)(.*)"
    regex = re.compile(pattern, re.IGNORECASE)

    res_list = []
    for i, pred in enumerate(preds):
        label = labels[i]
        matches = regex.findall(pred)
        if len(matches) > 0:
            answer = matches[-1][-1].lower()
            if "yes" in answer:
                acc = 1 if label == "yes" else 0
            elif "no" in answer:
                acc = 1 if label == "no" else 0
            elif "maybe" in answer:
                acc = 1 if label == "maybe" else 0
            else:
                acc = 0
        else:
            answer = None
            acc = 0
        res_list.append(acc)
    if return_list:
        return res_list
    return sum(res_list) / len(res_list)
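
# Call sketch for pubmed_qa_acc() (prediction text is illustrative); the score is the
# fraction of predictions whose extracted "the answer is ..." span matches the label:
#     pubmed_qa_acc(["The trial shows a benefit. The answer is yes."], ["yes"])  # -> 1.0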


def acc_and_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    f1 = f1_score(y_true=labels, y_pred=preds)
    return acc, f1, (acc + f1) / 2


def acc_and_matthews_corrcoef(preds, labels):
    acc = simple_accuracy(preds, labels)
    mcc = matthews_corrcoef(y_true=labels, y_pred=preds)
    return acc, mcc


def compute_bleu(preds, labels):
    BLEU = load_metric("bleu")
    # Character-level BLEU: each prediction/reference is split into a list of characters.
    predictions = [[ch for ch in text] for text in preds]
    references = [[[ch for ch in label]] for label in labels]
    return BLEU.compute(predictions=predictions, references=references)
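
# Call sketch for compute_bleu() (strings are illustrative); BLEU is computed over
# character n-grams, and the returned dict follows the `datasets` "bleu" metric:
#     result = compute_bleu(preds=["abcdef"], labels=["abcdef"])
#     result["bleu"]  # corpus-level BLEU score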


def compute_metrics(metric, labels, preds):
    assert len(preds) == len(labels)
    if metric == "simple_accuracy":
        return {"acc": simple_accuracy(preds, labels) * 100}
    elif metric == "rouge":
        r1, r2, rl = rouge(preds, labels)
        return {"r1": r1 * 100, "r2": r2 * 100, "rl": rl * 100}
    elif metric == "acc_and_f1":
        acc, f1, acc_f1 = acc_and_f1(preds, labels)
        return {"acc": acc * 100, "f1": f1 * 100, "acc_and_f1": acc_f1 * 100}
    elif metric == "acc_and_matthews_corrcoef":
        acc, mcc = acc_and_matthews_corrcoef(preds, labels)
        return {"acc": acc * 100, "mcc": mcc * 100}
    elif metric == "f1":
        f1 = f1_score(y_true=labels, y_pred=preds)
        return {"f1": f1 * 100}
    elif metric == "squad":
        em, f1 = squad(labels=labels, preds=preds)
        return {"em": em, "f1": f1}
    elif metric == "trivia_qa":
        em, f1 = trivia_qa(labels=labels, preds=preds)
        return {"em": em, "f1": f1}
    elif metric == "bleu":
        bleu = compute_bleu(preds=preds, labels=labels)
        return {"bleu": bleu["bleu"] * 100}
    elif metric == "pubmed_qa_acc":
        acc = pubmed_qa_acc(preds=preds, labels=labels)
        return {"pubmed_qa_acc": acc * 100}
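
# Dispatch sketch for compute_metrics() (inputs are illustrative); most metrics are
# rescaled to 0-100, while "squad" and "trivia_qa" return qa_utils values unscaled:
#     compute_metrics("simple_accuracy", labels=["yes", "no"], preds=["yes", "yes"])  # -> {"acc": 50.0}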


def compute_scores(metric, data):
    preds = [entry["pred"] for entry in data]
    labels = [entry["label"] for entry in data]
    if not isinstance(preds[0], str):
        preds = np.array(preds)
        labels = np.array(labels)
    scores = compute_metrics(metric, labels=labels, preds=preds)
    return scores
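
if __name__ == "__main__":
    # Smoke-test sketch (records are illustrative, not from the original repository);
    # compute_scores() expects a list of dicts with "pred" and "label" keys.
    data = [{"pred": "yes", "label": "yes"}, {"pred": "no", "label": "yes"}]
    print(compute_scores("simple_accuracy", data))  # expected: {"acc": 50.0}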
