# OpenBackdoor — evaluation metrics for classification and poison-sample detection.
from typing import Optional, Sequence

from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)

from .log import logger
4
def classification_metrics(preds: Sequence[int],
                           labels: Sequence[int],
                           metric: Optional[str] = "micro-f1",
                           ) -> float:
    """Score predicted labels against gold labels for a classification task.

    Args:
        preds (Sequence[int]): predicted label ids, one per example.
        labels (Sequence[int]): gold label ids, one per example.
        metric (str, optional): which score to compute; one of
            'micro-f1', 'macro-f1', 'accuracy', 'precision', 'recall'.
            Defaults to "micro-f1".

    Returns:
        score (float): the requested evaluation score.

    Raises:
        ValueError: if `metric` is not one of the supported names.
    """
    # Dispatch table: each entry maps a metric name to a callable
    # taking (gold_labels, predictions).
    scorers = {
        "micro-f1": lambda y, p: f1_score(y, p, average='micro'),
        "macro-f1": lambda y, p: f1_score(y, p, average='macro'),
        "accuracy": accuracy_score,
        "precision": precision_score,
        "recall": recall_score,
    }
    if metric not in scorers:
        raise ValueError("'{}' is not a valid evaluation type".format(metric))
    return scorers[metric](labels, preds)
33
def detection_metrics(preds: Sequence[int],
                      labels: Sequence[int],
                      metric: Optional[str] = "precision",
                      ) -> float:
    """Evaluate poison-sample detection quality.

    Labels and predictions are binary: 1 marks a poison sample, 0 a clean
    one (poison count is computed as ``sum(labels)``).

    Args:
        preds (Sequence[int]): predicted labels (1 = poison, 0 = clean).
        labels (Sequence[int]): gold labels (1 = poison, 0 = clean).
        metric (str, optional): one of 'precision', 'recall',
            'FRR' (fraction of clean samples flagged as poison) or
            'FAR' (fraction of poison samples passed as clean).
            Defaults to "precision".

    Returns:
        score (float): the requested evaluation score.

    Raises:
        ValueError: if `metric` is not a supported evaluation type.
    """
    total_num = len(labels)
    poison_num = sum(labels)
    logger.info("Evaluating poison data detection: {} poison samples, {} clean samples".format(poison_num, total_num - poison_num))
    # Pin the label set so cm is always 2x2; without labels=[0, 1] a batch
    # containing only one class yields a 1x1 matrix and cm[0, 1] / cm[1, 0]
    # below would raise IndexError.
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    logger.info(cm)
    if metric == "precision":
        score = precision_score(labels, preds)
    elif metric == "recall":
        score = recall_score(labels, preds)
    elif metric == "FRR":
        # clean samples wrongly flagged as poison / all clean samples
        score = cm[0, 1] / (cm[0, 1] + cm[0, 0])
    elif metric == "FAR":
        # poison samples that slipped through / all poison samples
        score = cm[1, 0] / (cm[1, 1] + cm[1, 0])
    else:
        raise ValueError("'{}' is not a valid evaluation type".format(metric))
    return score
54