OpenBackdoor
from typing import *
from openbackdoor.victims import Victim
from openbackdoor.data import get_dataloader, wrap_dataset
from .poisoners import load_poisoner
from openbackdoor.trainers import load_trainer
from openbackdoor.utils import evaluate_classification
from openbackdoor.defenders import Defender
from openbackdoor.utils import logger
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import os
from ..utils.evaluator import Evaluator


class Attacker(object):
    """
    The base class of all attackers. Each attacker has a poisoner and a trainer.

    Args:
        poisoner (:obj:`dict`, optional): the config of the poisoner.
        train (:obj:`dict`, optional): the config of the poison trainer.
        metrics (`List[str]`, optional): the metrics to evaluate.
        sample_metrics (`List[str]`, optional): the quality metrics to compute on poisoned samples.
    """

    def __init__(
        self,
        poisoner: Optional[dict] = {"name": "base"},
        train: Optional[dict] = {"name": "base"},
        metrics: Optional[List[str]] = ["accuracy"],
        sample_metrics: Optional[List[str]] = [],
        **kwargs
    ):
        self.metrics = metrics
        self.sample_metrics = sample_metrics
        self.poisoner_config = poisoner
        self.trainer_config = train
        self.poisoner = load_poisoner(poisoner)
        # The trainer receives the merged poisoner and trainer configs, plus the
        # poisoner's name as "poison_method" (train keys override poisoner keys).
        self.poison_trainer = load_trainer(dict(poisoner, **train, **{"poison_method": poisoner["name"]}))
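
    # A hedged usage sketch (assumptions: the "badnets" poisoner name and the
    # extra config keys follow OpenBackdoor's config-dict conventions; this
    # module itself only requires "name", plus "batch_size" at eval time):
    #
    #   attacker = Attacker(
    #       poisoner={"name": "badnets"},               # hypothetical poisoner config
    #       train={"name": "base", "batch_size": 32},   # "batch_size" is read in eval()
    #       metrics=["accuracy"],
    #   )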

    def attack(self, victim: Victim, data: List, config: Optional[dict] = None, defender: Optional[Defender] = None):
        """
        Attack the victim model.

        Args:
            victim (:obj:`Victim`): the victim to attack.
            data (:obj:`List`): the dataset to poison and train on.
            config (:obj:`dict`, optional): the attack config.
            defender (:obj:`Defender`, optional): the defender.

        Returns:
            :obj:`Victim`: the backdoored model.
        """
        poison_dataset = self.poison(victim, data, "train")

        if defender is not None and defender.pre is True:
            # pre-training defense: let the defender correct the poisoned
            # training data before the victim is tuned on it
            poison_dataset["train"] = defender.correct(poison_data=poison_dataset['train'])

        backdoored_model = self.train(victim, poison_dataset)
        return backdoored_model
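
    # Illustrative call order (a sketch: `victim` stands for a loaded Victim
    # model and `dataset` for the split dataset this file expects but does not
    # construct itself):
    #
    #   backdoored_model = attacker.attack(victim, dataset, defender=defender)
    #   results = attacker.eval(backdoored_model, dataset, defender=defender)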

    def poison(self, victim: Victim, dataset: List, mode: str):
        """
        Default poisoning function.

        Args:
            victim (:obj:`Victim`): the victim to attack.
            dataset (:obj:`List`): the dataset to poison.
            mode (:obj:`str`): the poisoning mode, e.g. "train", "eval", or "detect".

        Returns:
            :obj:`List`: the poisoned dataset.
        """
        return self.poisoner(dataset, mode)
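
    # The returned structure depends on ``mode``. Inferred from how this file
    # consumes it (an assumption, not an exhaustive contract): "train" mode
    # yields a dict with a "train" split, while "eval" mode yields "test-clean"
    # and "test-poison" splits of (text, label, poison_label) tuples.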

    def train(self, victim: Victim, dataset: List):
        """
        Use ``poison_trainer`` to attack the victim model.
        The default behavior is normal training on the (poisoned) dataset.

        Args:
            victim (:obj:`Victim`): the victim to attack.
            dataset (:obj:`List`): the dataset to train on.

        Returns:
            :obj:`Victim`: the attacked model.
        """
        return self.poison_trainer.train(victim, dataset, self.metrics)

    def eval(self, victim: Victim, dataset: List, defender: Optional[Defender] = None):
        """
        Default evaluation function (ASR and CACC) for the attacker.

        Args:
            victim (:obj:`Victim`): the victim to attack.
            dataset (:obj:`List`): the dataset to evaluate on.
            defender (:obj:`Defender`, optional): the defender.

        Returns:
            :obj:`dict`: the evaluation results.
        """
        poison_dataset = self.poison(victim, dataset, "eval")
        if defender is not None and defender.pre is False:
            if defender.correction:
                poison_dataset["test-clean"] = defender.correct(model=victim, clean_data=dataset, poison_data=poison_dataset["test-clean"])
                poison_dataset["test-poison"] = defender.correct(model=victim, clean_data=dataset, poison_data=poison_dataset["test-poison"])
            else:
                # post-training defense: detect poisoned samples instead of correcting them
                detect_poison_dataset = self.poison(victim, dataset, "detect")
                detection_score, preds = defender.eval_detect(model=victim, clean_data=dataset, poison_data=detect_poison_dataset)

                clean_length = len(poison_dataset["test-clean"])
                num_classes = len(set([data[1] for data in poison_dataset["test-clean"]]))
                preds_clean, preds_poison = preds[:clean_length], preds[clean_length:]
                # Samples the defender flags as poisoned (pred == 1) are relabeled
                # to the dummy class ``num_classes``, which the victim can never
                # predict, so detected samples count as classification failures.
                poison_dataset["test-clean"] = [(data[0], num_classes, 0) if pred == 1 else (data[0], data[1], 0) for pred, data in zip(preds_clean, poison_dataset["test-clean"])]
                poison_dataset["test-poison"] = [(data[0], num_classes, 0) if pred == 1 else (data[0], data[1], 0) for pred, data in zip(preds_poison, poison_dataset["test-poison"])]

        poison_dataloader = wrap_dataset(poison_dataset, self.trainer_config["batch_size"])

        results = evaluate_classification(victim, poison_dataloader, self.metrics)

        sample_metrics = self.eval_poison_sample(victim, dataset, self.sample_metrics)

        return dict(results[0], **sample_metrics)
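
    # Hedged sketch of eval()'s return value: the split keys below mirror the
    # names used above, but the exact structure comes from
    # evaluate_classification, so treat this shape as an assumption:
    #
    #   {"test-clean": {"accuracy": ...}, "test-poison": {"accuracy": ...},
    #    "ppl": ..., "grammar": ..., "use": ...}
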
    def eval_poison_sample(self, victim: Victim, dataset: List, eval_metrics=[]):
        """
        Evaluation function for the poisoned samples (PPL, grammar errors, and USE).

        Args:
            victim (:obj:`Victim`): the victim to attack.
            dataset (:obj:`List`): the dataset to evaluate on.
            eval_metrics (:obj:`List`): the quality metrics for the poisoned samples.

        Returns:
            :obj:`dict`: the sample-quality metrics of the poisoned samples.
        """
        evaluator = Evaluator()
        sample_metrics = {"ppl": np.nan, "grammar": np.nan, "use": np.nan}

        poison_dataset = self.poison(victim, dataset, "eval")
        # compare poisoned test samples against clean samples outside the target class
        clean_test = self.poisoner.get_non_target(poison_dataset["test-clean"])
        poison_test = poison_dataset["test-poison"]

        clean_texts = [item[0] for item in clean_test]
        poison_texts = [item[0] for item in poison_test]
        for metric in eval_metrics:
            if metric not in ['ppl', 'grammar', 'use']:
                logger.info(" Invalid eval metric: {}, skipping".format(metric))
                continue
            if metric == 'ppl':
                measure = evaluator.evaluate_ppl(clean_texts, poison_texts)
            elif metric == 'grammar':
                measure = evaluator.evaluate_grammar(clean_texts, poison_texts)
            elif metric == 'use':
                measure = evaluator.evaluate_use(clean_texts, poison_texts)
            logger.info(" Eval Metric: {} = {}".format(metric, measure))
            sample_metrics[metric] = measure

        return sample_metrics
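

# End-to-end sketch of the intended flow. Hedged: ``load_victim`` and
# ``load_dataset`` are assumed helpers from the surrounding OpenBackdoor
# package, and the config keys are illustrative, not prescribed by this file:
#
#   victim = load_victim(config["victim"])
#   dataset = load_dataset(**config["poison_dataset"])
#   attacker = Attacker(**config["attacker"])
#   victim = attacker.attack(victim, dataset)
#   results = attacker.eval(victim, dataset)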