OpenBackdoor

Форк
0
53 строки · 2.1 Кб
1
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, confusion_matrix
2
from typing import *
3
from .log import logger
4

5
def classification_metrics(preds: Sequence[int],
                           labels: Sequence[int],
                           metric: Optional[str] = "micro-f1",
                          ) -> float:
    """Score predicted labels against gold labels for a classification task.

    Args:
        preds (Sequence[int]): predicted label ids, one per example
        labels (Sequence[int]): gold label ids, one per example
        metric (str, optional): name of the scoring function to apply; one of
            'micro-f1', 'macro-f1', 'accuracy', 'precision', 'recall'.
            Defaults to "micro-f1".

    Returns:
        score (float): the requested evaluation score

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    # Dispatch table of lazily-evaluated scorers; only the requested one runs.
    scorers = {
        "micro-f1": lambda: f1_score(labels, preds, average='micro'),
        "macro-f1": lambda: f1_score(labels, preds, average='macro'),
        "accuracy": lambda: accuracy_score(labels, preds),
        "precision": lambda: precision_score(labels, preds),
        "recall": lambda: recall_score(labels, preds),
    }
    if metric not in scorers:
        raise ValueError("'{}' is not a valid evaluation type".format(metric))
    return scorers[metric]()
33

34
def detection_metrics(preds: Sequence[int],
                      labels: Sequence[int],
                      metric: Optional[str] = "precision",
                      ) -> float:
    """Evaluate binary poison-sample detection (label 1 = poison, 0 = clean).

    Args:
        preds (Sequence[int]): predicted labels, one per example
        labels (Sequence[int]): gold labels, one per example
        metric (str, optional): one of 'precision', 'recall',
            'FRR' (false rejection rate: clean samples flagged as poison),
            'FAR' (false acceptance rate: poison samples passed as clean).
            Defaults to "precision".

    Returns:
        score (float): the requested evaluation score

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    total_num = len(labels)
    poison_num = sum(labels)
    logger.info("Evaluating poison data detection: {} poison samples, {} clean samples".format(poison_num, total_num-poison_num))
    # Pin the class order so cm is always 2x2: without `labels=[0, 1]`,
    # an eval set containing only one class yields a 1x1 matrix and the
    # FRR/FAR indexing below raises IndexError.
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    logger.info(cm)
    if metric == "precision":
        score = precision_score(labels, preds)
    elif metric == "recall":
        score = recall_score(labels, preds)
    elif metric == "FRR":
        # FRR = FP / (FP + TN): share of clean samples wrongly rejected.
        score = cm[0,1] / (cm[0,1] + cm[0,0])
    elif metric == "FAR":
        # FAR = FN / (FN + TP): share of poison samples wrongly accepted.
        score = cm[1,0] / (cm[1,1] + cm[1,0])
    else:
        raise ValueError("'{}' is not a valid evaluation type".format(metric))
    return score
54

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.