from typing import *
from openbackdoor.victims import Victim
from openbackdoor.data import get_dataloader, wrap_dataset
from .poisoners import load_poisoner
from openbackdoor.trainers import load_trainer
from openbackdoor.utils import evaluate_classification
from openbackdoor.defenders import Defender
from openbackdoor.utils import logger
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import os
from ..utils.evaluator import Evaluator


class Attacker(object):
    """
    The base class of all attackers. Each attacker has a poisoner and a trainer.

    Args:
        poisoner (:obj:`dict`, optional): the config of the poisoner.
        train (:obj:`dict`, optional): the config of the poison trainer.
        metrics (:obj:`List[str]`, optional): the evaluation metrics.
        sample_metrics (:obj:`List[str]`, optional): the sample-quality metrics ("ppl", "grammar", "use") to compute for poisoned samples.
    """

    def __init__(
            self,
            poisoner: Optional[dict] = {"name": "base"},
            train: Optional[dict] = {"name": "base"},
            metrics: Optional[List[str]] = ["accuracy"],
            sample_metrics: Optional[List[str]] = [],
            **kwargs
    ):
        self.metrics = metrics
        self.sample_metrics = sample_metrics
        self.poisoner_config = poisoner
        self.trainer_config = train
        self.poisoner = load_poisoner(poisoner)
        # The poison trainer receives the merged poisoner and trainer configs,
        # with the poisoner name passed along as "poison_method".
        self.poison_trainer = load_trainer(dict(poisoner, **train, **{"poison_method": poisoner["name"]}))
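
    # Illustrative constructor sketch (not part of the original file). The dicts
    # mirror the defaults above; "batch_size" is needed because eval() below reads
    # self.trainer_config["batch_size"], while the "badnets" poisoner name and the
    # "epochs" key are assumptions about the wider toolkit's config schema rather
    # than anything defined in this file:
    #
    #   attacker = Attacker(
    #       poisoner={"name": "badnets"},
    #       train={"name": "base", "epochs": 5, "batch_size": 32},
    #       metrics=["accuracy"],
    #   )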

    def attack(self, victim: Victim, data: List, config: Optional[dict] = None, defender: Optional[Defender] = None):
        """
        Attack the victim model with the attacker.

        Args:
            victim (:obj:`Victim`): the victim to attack.
            data (:obj:`List`): the dataset to attack.
            config (:obj:`dict`, optional): the attack config (unused in the base class).
            defender (:obj:`Defender`, optional): the defender.

        Returns:
            :obj:`Victim`: the attacked model.
        """
        poison_dataset = self.poison(victim, data, "train")

        if defender is not None and defender.pre is True:
            # pre-training defense: let the defender correct the poisoned training split before training
            poison_dataset["train"] = defender.correct(poison_data=poison_dataset["train"])

        backdoored_model = self.train(victim, poison_dataset)
        return backdoored_model
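
    # Illustrative call sketch (not part of the original file): a defender whose
    # ``pre`` attribute is True is applied to the poisoned training split inside
    # attack() before training; ``victim``, ``dataset``, and ``defender`` are
    # assumed to be constructed elsewhere with the toolkit's loaders:
    #
    #   backdoored_model = attacker.attack(victim, dataset, defender=defender)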

    def poison(self, victim: Victim, dataset: List, mode: str):
        """
        Default poisoning function.

        Args:
            victim (:obj:`Victim`): the victim to attack.
            dataset (:obj:`List`): the dataset to poison.
            mode (:obj:`str`): the poisoning mode, e.g. "train", "eval", or "detect".

        Returns:
            :obj:`List`: the poisoned dataset.
        """
        return self.poisoner(dataset, mode)
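
    # Illustrative sketch (not part of the original file) of the splits a poisoner
    # is expected to return, inferred from the keys read elsewhere in this class
    # (attack() uses "train"; eval() uses "test-clean" and "test-poison"):
    #
    #   poison_dataset = self.poison(victim, dataset, "train")
    #   poison_dataset["train"]        # mixed clean and poisoned training samples
    #
    #   poison_dataset = self.poison(victim, dataset, "eval")
    #   poison_dataset["test-clean"]   # untouched test set, for clean accuracy (CACC)
    #   poison_dataset["test-poison"]  # triggered test set, for attack success rate (ASR)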

    def train(self, victim: Victim, dataset: List):
        """
        Use ``poison_trainer`` to attack the victim model.
        The default behavior is normal training on the poisoned dataset.

        Args:
            victim (:obj:`Victim`): the victim to attack.
            dataset (:obj:`List`): the dataset to train on.

        Returns:
            :obj:`Victim`: the attacked model.
        """
        return self.poison_trainer.train(victim, dataset, self.metrics)

    def eval(self, victim: Victim, dataset: List, defender: Optional[Defender] = None):
        """
        Default evaluation function (ASR and CACC) for the attacker.

        Args:
            victim (:obj:`Victim`): the victim to evaluate.
            dataset (:obj:`List`): the dataset to evaluate on.
            defender (:obj:`Defender`, optional): the defender.

        Returns:
            :obj:`dict`: the evaluation results.
        """
        poison_dataset = self.poison(victim, dataset, "eval")
        if defender is not None and defender.pre is False:
            if defender.correction:
                # post-training defense with correction: the defender rewrites suspicious test samples
                poison_dataset["test-clean"] = defender.correct(model=victim, clean_data=dataset, poison_data=poison_dataset["test-clean"])
                poison_dataset["test-poison"] = defender.correct(model=victim, clean_data=dataset, poison_data=poison_dataset["test-poison"])
            else:
                # post-training defense with detection: samples flagged as poisoned (pred == 1)
                # are relabeled to an extra "rejected" class (index = num_classes), so they are
                # scored as rejected rather than as successful attacks
                detect_poison_dataset = self.poison(victim, dataset, "detect")
                detection_score, preds = defender.eval_detect(model=victim, clean_data=dataset, poison_data=detect_poison_dataset)

                clean_length = len(poison_dataset["test-clean"])
                num_classes = len(set([data[1] for data in poison_dataset["test-clean"]]))
                preds_clean, preds_poison = preds[:clean_length], preds[clean_length:]
                poison_dataset["test-clean"] = [(data[0], num_classes, 0) if pred == 1 else (data[0], data[1], 0) for pred, data in zip(preds_clean, poison_dataset["test-clean"])]
                poison_dataset["test-poison"] = [(data[0], num_classes, 0) if pred == 1 else (data[0], data[1], 0) for pred, data in zip(preds_poison, poison_dataset["test-poison"])]

        poison_dataloader = wrap_dataset(poison_dataset, self.trainer_config["batch_size"])

        results = evaluate_classification(victim, poison_dataloader, self.metrics)

        # per-sample quality metrics (PPL, grammar errors, USE) for the poisoned test set
        sample_metrics = self.eval_poison_sample(victim, dataset, self.sample_metrics)

        return dict(results[0], **sample_metrics)
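
    # Illustrative call sketch (not part of the original file): without a defender,
    # eval() reports clean accuracy and attack success rate on the two test splits;
    # with a post-training defender (``pre`` is False) the defense is applied to the
    # test data first. The returned dict merges the classification results with the
    # sample metrics ("ppl", "grammar", "use") when sample_metrics were requested:
    #
    #   results = attacker.eval(backdoored_model, target_dataset)
    #   results = attacker.eval(backdoored_model, target_dataset, defender=defender)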

    def eval_poison_sample(self, victim: Victim, dataset: List, eval_metrics=[]):
        """
        Evaluation function for the poisoned samples (PPL, grammar errors, and USE similarity).

        Args:
            victim (:obj:`Victim`): the victim to attack.
            dataset (:obj:`List`): the dataset to attack.
            eval_metrics (:obj:`List`): the sample-level metrics to compute.

        Returns:
            :obj:`dict`: the sample-level metrics of the poisoned samples.
        """
        evaluator = Evaluator()
        sample_metrics = {"ppl": np.nan, "grammar": np.nan, "use": np.nan}

        # compare non-target clean test samples against the poisoned test samples
        poison_dataset = self.poison(victim, dataset, "eval")
        clean_test = self.poisoner.get_non_target(poison_dataset["test-clean"])
        poison_test = poison_dataset["test-poison"]

        for metric in eval_metrics:
            if metric not in ["ppl", "grammar", "use"]:
                logger.info("  Invalid eval metric '{}', skipping  ".format(metric))
                continue
            measure = 0
            if metric == "ppl":
                measure = evaluator.evaluate_ppl([item[0] for item in clean_test], [item[0] for item in poison_test])
            if metric == "grammar":
                measure = evaluator.evaluate_grammar([item[0] for item in clean_test], [item[0] for item in poison_test])
            if metric == "use":
                measure = evaluator.evaluate_use([item[0] for item in clean_test], [item[0] for item in poison_test])
            logger.info("  Eval metric: {} = {}".format(metric, measure))
            sample_metrics[metric] = measure

        return sample_metrics
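

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the original file), runnable e.g. via
    # ``python -m openbackdoor.attackers.attacker``. PLMVictim and load_dataset are
    # assumed to be the toolkit's standard loaders; the "badnets" poisoner name and
    # the "epochs" key are assumptions about the config schema, while "batch_size"
    # is required because eval() above reads self.trainer_config["batch_size"].
    from openbackdoor.victims import PLMVictim
    from openbackdoor.data import load_dataset

    victim = PLMVictim(model="bert", path="bert-base-uncased")  # BERT victim classifier
    dataset = load_dataset(name="sst-2")                        # SST-2 splits

    attacker = Attacker(
        poisoner={"name": "badnets"},
        train={"name": "base", "epochs": 2, "batch_size": 32},
        metrics=["accuracy"],
    )
    backdoored_model = attacker.attack(victim, dataset)         # poison the data, then train
    results = attacker.eval(backdoored_model, dataset)          # CACC / ASR
    print(results)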
