OpenBackdoor
185 строк · 7.8 Кб
import os
import random
from collections import defaultdict
from typing import *

import pandas as pd
import torch
import torch.nn as nn

from openbackdoor.utils import logger


class Poisoner(object):
    r"""
    Basic poisoner.

    Args:
        name (:obj:`str`, optional): name of the poisoner. Default to "Base".
        target_label (:obj:`int`, optional): the target label. Default to 0.
        poison_rate (:obj:`float`, optional): the poison rate. Default to 0.1.
        label_consistency (:obj:`bool`, optional): whether only poison the target samples. Default to `False`.
        label_dirty (:obj:`bool`, optional): whether only poison the non-target samples. Default to `False`.
        load (:obj:`bool`, optional): whether to load the poisoned data. Default to `False`.
        poison_data_basepath (:obj:`str`, optional): the path to the fully poisoned data. Default to `None`.
        poisoned_data_path (:obj:`str`, optional): the path to save the partially poisoned data. Default to `None`.

    Samples are treated as indexable records whose element ``[1]`` is the label
    (see :meth:`get_non_target` and :meth:`poison_part`). The base :meth:`poison`
    returns its input unchanged. When a path is ``None`` nothing is saved to, or
    loaded from, that location.
    """

    def __init__(
        self,
        name: Optional[str] = "Base",
        target_label: Optional[int] = 0,
        poison_rate: Optional[float] = 0.1,
        label_consistency: Optional[bool] = False,
        label_dirty: Optional[bool] = False,
        load: Optional[bool] = False,
        poison_data_basepath: Optional[str] = None,
        poisoned_data_path: Optional[str] = None,
        **kwargs
    ):
        # Extra keyword arguments are accepted but not used by this class; report
        # them through the project logger instead of printing to stdout every time.
        if kwargs:
            logger.info(f"Unused poisoner arguments: {kwargs}")
        self.name = name

        self.target_label = target_label
        self.poison_rate = poison_rate
        self.label_consistency = label_consistency
        self.label_dirty = label_dirty
        self.load = load
        self.poison_data_basepath = poison_data_basepath
        self.poisoned_data_path = poisoned_data_path

        # label_consistency takes precedence if both flags are set.
        if label_consistency:
            self.poison_setting = 'clean'
        elif label_dirty:
            self.poison_setting = 'dirty'
        else:
            self.poison_setting = 'mix'

    def __call__(self, data: Dict, mode: str):
        """
        Poison the data.
        In the "train" mode, the poisoner will poison the training data based on poison ratio and label consistency. Return the mixed training data.
        In the "eval" mode, the poisoner will poison the evaluation data. Return the clean and poisoned evaluation data.
        In the "detect" mode, the poisoner will poison the evaluation data. Return the mixed evaluation data.

        Previously saved CSVs are only reused when ``self.load`` is true and the
        corresponding directory is configured (not ``None``) and contains the file.
        Any other ``mode`` value yields an empty result.

        Args:
            data (:obj:`Dict`): the data to be poisoned ("train"/"dev" splits for "train" mode, "test" otherwise).
            mode (:obj:`str`): the mode of poisoning. Can be "train", "eval" or "detect".

        Returns:
            :obj:`Dict`: the poisoned data.
        """
        poisoned_data = defaultdict(list)

        if mode == "train":
            if self.load and self._split_exists(self.poisoned_data_path, "train-poison"):
                # Reuse a previously saved partially-poisoned training set as-is.
                poisoned_data["train"] = self.load_poison_data(self.poisoned_data_path, "train-poison")
            else:
                # Training data is poisoned in full (all labels); poison_part then
                # picks which samples actually get swapped for their poisoned version.
                poison_train_data = self._load_or_poison(data["train"], "train", non_target_only=False)
                poisoned_data["train"] = self.poison_part(data["train"], poison_train_data)
                self.save_data(poisoned_data["train"], self.poisoned_data_path, "train-poison")

            poisoned_data["dev-clean"] = data["dev"]
            poisoned_data["dev-poison"] = self._load_or_poison(data["dev"], "dev")

        elif mode == "eval":
            poisoned_data["test-clean"] = data["test"]
            poisoned_data["test-poison"] = self._load_or_poison(data["test"], "test")

        elif mode == "detect":
            if self.load and self._split_exists(self.poison_data_basepath, "test-detect"):
                poisoned_data["test-detect"] = self.load_poison_data(self.poison_data_basepath, "test-detect")
            else:
                poison_test_data = self._load_or_poison(data["test"], "test")
                # The full clean test set followed by the poisoned non-target samples.
                poisoned_data["test-detect"] = data["test"] + poison_test_data
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")

        return poisoned_data

    @staticmethod
    def _split_exists(path: Optional[str], split: str) -> bool:
        """Return True if ``<path>/<split>.csv`` exists; False when ``path`` is ``None``."""
        return path is not None and os.path.exists(os.path.join(path, f"{split}.csv"))

    def _load_or_poison(self, clean_data: List, split: str, non_target_only: bool = True):
        """
        Return the fully poisoned version of one split.

        If loading is enabled and ``<poison_data_basepath>/<split>-poison.csv`` exists,
        it is loaded. Otherwise the data is poisoned with :meth:`poison` (restricted to
        non-target samples when ``non_target_only``), and both the clean split and the
        poisoned result are saved under ``poison_data_basepath``.
        """
        poison_split = f"{split}-poison"
        if self.load and self._split_exists(self.poison_data_basepath, poison_split):
            return self.load_poison_data(self.poison_data_basepath, poison_split)

        to_poison = self.get_non_target(clean_data) if non_target_only else clean_data
        poison_data = self.poison(to_poison)
        self.save_data(clean_data, self.poison_data_basepath, f"{split}-clean")
        self.save_data(poison_data, self.poison_data_basepath, poison_split)
        return poison_data

    def get_non_target(self, data):
        """
        Get data of non-target label, i.e. the samples whose element ``[1]`` differs
        from ``self.target_label``.
        """
        return [d for d in data if d[1] != self.target_label]

    def poison_part(self, clean_data: List, poison_data: List):
        """
        Replace a random ``poison_rate`` fraction of ``clean_data`` with the
        corresponding entries of ``poison_data``.

        Candidate positions depend on the setting: only target-label samples
        (``label_consistency``), only non-target samples (``label_dirty``), or all
        samples. If there are fewer candidates than requested, a warning is logged
        and all candidates are used.

        Args:
            clean_data (:obj:`List`): the clean samples.
            poison_data (:obj:`List`): poisoned samples, assumed index-aligned with ``clean_data``.

        Returns:
            :obj:`List`: the untouched clean samples followed by the selected poisoned samples.
        """
        poison_num = int(self.poison_rate * len(clean_data))

        if self.label_consistency:
            candidate_pos = [i for i, d in enumerate(clean_data) if d[1] == self.target_label]
        elif self.label_dirty:
            candidate_pos = [i for i, d in enumerate(clean_data) if d[1] != self.target_label]
        else:
            candidate_pos = list(range(len(clean_data)))

        if len(candidate_pos) < poison_num:
            logger.warning("Not enough data for clean label attack.")
            poison_num = len(candidate_pos)
        random.shuffle(candidate_pos)

        # A set keeps the membership tests below O(1) instead of scanning a list
        # for every sample.
        poisoned_pos = set(candidate_pos[:poison_num])
        clean = [d for i, d in enumerate(clean_data) if i not in poisoned_pos]
        poisoned = [d for i, d in enumerate(poison_data) if i in poisoned_pos]

        return clean + poisoned

    def poison(self, data: List):
        """
        Poison all the data. The base implementation returns ``data`` unchanged.

        Args:
            data (:obj:`List`): the data to be poisoned.

        Returns:
            :obj:`List`: the poisoned data.
        """
        return data

    def load_poison_data(self, path, split):
        """
        Load ``<path>/<split>.csv`` as written by :meth:`save_data`.

        Column 0 is the index written by ``DataFrame.to_csv``; each row is returned
        as a 3-tuple of the following three columns. Returns ``None`` when ``path``
        is ``None``.
        """
        if path is None:
            return None
        data = pd.read_csv(os.path.join(path, f'{split}.csv')).values
        return [(d[1], d[2], d[3]) for d in data]

    def save_data(self, dataset, path, split):
        """Write ``dataset`` to ``<path>/<split>.csv``, creating ``path`` if needed; no-op when ``path`` is ``None``."""
        if path is None:
            return
        os.makedirs(path, exist_ok=True)
        pd.DataFrame(dataset).to_csv(os.path.join(path, f'{split}.csv'))