import re

import numpy as np
from datasets import load_metric
from rouge import Rouge
from sklearn.metrics import f1_score, matthews_corrcoef

from src.utils import qa_utils


class MetricDict:
    """Registry mapping a metric name to its scoring function."""

    def __init__(self):
        self.functions = {}

    def add(self, key):
        # Decorator: register `func` under `key` and return it unchanged
        # (scaffold reconstructed from the truncated source).
        def decorator(func):
            self.functions[key] = func
            return func

        return decorator

    def __getitem__(self, __name: str):
        return self.functions[__name]


metric_dict = MetricDict()
@metric_dict.add("rouge")
def rouge(preds, labels, return_list=False):
    r = Rouge()
    r1s, r2s, rls = [], [], []
    for i in range(len(labels)):
        # The rouge package treats "\n" as a sentence separator; append one
        # when missing (bodies reconstructed from the truncated source).
        if "\n" not in preds[i]:
            preds[i] = preds[i] + "\n"
        if "\n" not in labels[i]:
            labels[i] = labels[i] + "\n"
        scores = r.get_scores(preds[i], labels[i])[0]
        r1s.append(scores["rouge-1"]["f"])
        r2s.append(scores["rouge-2"]["f"])
        rls.append(scores["rouge-l"]["f"])
    if return_list:
        return r1s, r2s, rls
    r1 = sum(r1s) / len(r1s)
    r2 = sum(r2s) / len(r2s)
    rl = sum(rls) / len(rls)
    return r1, r2, rl
@metric_dict.add("squad")
def squad(labels, preds, return_list=False):
    """Computes SQuAD metrics, maximizing over answers per question.

    Args:
        labels: list of lists of strings (gold answers per question)
        preds: list of strings

    Returns:
        em and f1 scores across all labels and predictions
    """
    labels = [[qa_utils.normalize_squad(t) for t in u] for u in labels]
    preds = [qa_utils.normalize_squad(p) for p in preds]
    if return_list:
        em, f1 = qa_utils.qa_metrics(labels, preds, return_list=True)
        return em, f1
    em, f1 = qa_utils.qa_metrics(labels, preds)
    return em, f1
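# Illustrative call (not from the original source): with two gold answers for
# one question, e.g. squad(labels=[["in 1905", "1905"]], preds=["1905"]), the
# per-question score is the maximum over the gold answers, so the normalized
# exact match "1905" yields full EM/F1 for that question.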
@metric_dict.add("trivia_qa")
def trivia_qa(labels, preds, return_list=False):
    """Computes TriviaQA metrics, maximizing over answers per question.

    Args:
        labels: list of lists of strings (gold answers per question)
        preds: list of strings

    Returns:
        em and f1 scores across all labels and predictions
    """
    labels = [[qa_utils.normalize_trivia_qa(t) for t in u] for u in labels]
    preds = [qa_utils.normalize_trivia_qa(p) for p in preds]
    if return_list:
        em, f1 = qa_utils.qa_metrics(labels, preds, return_list=True)
        return em, f1
    em, f1 = qa_utils.qa_metrics(labels, preds)
    return em, f1
@metric_dict.add("simple_accuracy")
def simple_accuracy(preds, labels, return_list=False):
    if isinstance(preds[0], str):
        labels = [label.strip() for label in labels]
        preds = [pred.strip() for pred in preds]
    res = [int(preds[i] == labels[i]) for i in range(len(preds))]
    if return_list:
        return res
    acc = sum(res) / len(res)
    return acc
@metric_dict.add("pubmed_qa_acc")
def pubmed_qa_acc(preds, labels, return_list=False):
    # PubMedQA answers are yes/no/maybe; grade the span that follows the
    # last "the answer is" in each prediction against the gold label
    # (loop body partially reconstructed from the truncated source).
    pattern = r"([.\s]*)(the answer is)(.*)"
    regex = re.compile(pattern, re.IGNORECASE)
    res_list = []
    for i, pred in enumerate(preds):
        label = labels[i]
        acc = 0
        if len(regex.findall(pred)) > 0:
            answer = regex.findall(pred)[-1][-1].lower()
            if "yes" in answer:
                acc = 1 if label == "yes" else 0
            elif "no" in answer:
                acc = 1 if label == "no" else 0
            elif "maybe" in answer:
                acc = 1 if label == "maybe" else 0
        res_list.append(acc)
    if return_list:
        return res_list
    return sum(res_list) / len(res_list)
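# Illustrative behavior (not from the original source): for the prediction
# "Long answer text. The answer is yes, it helped.", regex.findall returns one
# 3-tuple whose last group is " yes, it helped."; that tail contains "yes",
# so a gold label of "yes" scores 1.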
def acc_and_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    f1 = f1_score(y_true=labels, y_pred=preds)
    return acc, f1, (acc + f1) / 2


def acc_and_matthews_corrcoef(preds, labels):
    acc = simple_accuracy(preds, labels)
    mcc = matthews_corrcoef(y_true=labels, y_pred=preds)
    return acc, mcc
def compute_bleu(preds, labels):
    # Character-level BLEU: each text is split into its characters, and each
    # example gets a single reference.
    bleu = load_metric("bleu")
    predictions = [[ch for ch in text] for text in preds]
    references = [[[ch for ch in label]] for label in labels]
    return bleu.compute(predictions=predictions, references=references)
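# Usage sketch (assumed, not from the original module): the dict returned by
# the Hugging Face "bleu" metric carries a "bleu" value in [0, 1], which
# compute_metrics() below rescales by 100. With the default 4-gram order,
# identical strings of at least four characters score 1.0, e.g.
#   compute_bleu(preds=["abcdef"], labels=["abcdef"])["bleu"]  # -> 1.0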
def compute_metrics(metric, labels, preds):
    assert len(preds) == len(labels)
    if metric == "simple_accuracy":
        return {"acc": simple_accuracy(preds, labels) * 100}
    elif metric == "rouge":
        r1, r2, rl = rouge(preds, labels)
        return {"r1": r1 * 100, "r2": r2 * 100, "rl": rl * 100}
    elif metric == "acc_and_f1":
        acc, f1, acc_f1 = acc_and_f1(preds, labels)
        return {"acc": acc * 100, "f1": f1 * 100, "acc_and_f1": acc_f1 * 100}
    elif metric == "acc_and_matthews_corrcoef":
        acc, mcc = acc_and_matthews_corrcoef(preds, labels)
        return {"acc": acc * 100, "mcc": mcc * 100}
    elif metric == "f1":
        f1 = f1_score(y_true=labels, y_pred=preds)
        return {"f1": f1 * 100}
    elif metric == "squad":
        em, f1 = squad(labels=labels, preds=preds)
        return {"em": em, "f1": f1}
    elif metric == "trivia_qa":
        em, f1 = trivia_qa(labels=labels, preds=preds)
        return {"em": em, "f1": f1}
    elif metric == "bleu":
        bleu = compute_bleu(preds=preds, labels=labels)
        return {"bleu": bleu["bleu"] * 100}
    elif metric == "pubmed_qa_acc":
        acc = pubmed_qa_acc(preds=preds, labels=labels)
        return {"pubmed_qa_acc": acc * 100}
    else:
        raise ValueError(f"Unsupported metric: {metric}")
def compute_scores(metric, data):
    preds = [entry["pred"] for entry in data]
    labels = [entry["label"] for entry in data]
    if not isinstance(preds[0], str):
        preds = np.array(preds)
        labels = np.array(labels)
    scores = compute_metrics(metric, labels=labels, preds=preds)
    return scores
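

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original module):
    # score a toy batch through the compute_scores() entry point.
    demo = [
        {"pred": "yes", "label": "yes"},
        {"pred": "no", "label": "yes"},
    ]
    print(compute_scores("simple_accuracy", demo))  # {'acc': 50.0}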