from typing import *
from openbackdoor.victims import Victim
from openbackdoor.data import get_dataloader, wrap_dataset, wrap_dataset_lws
from .poisoners import load_poisoner
from openbackdoor.trainers import load_trainer
from openbackdoor.utils import logger, evaluate_classification
from openbackdoor.defenders import Defender
from .attacker import Attacker
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.nn import functional as F


class self_learning_poisoner(nn.Module):

    def __init__(self, model: Victim, N_BATCH, N_CANDIDATES, N_LENGTH, N_EMBSIZE):
        super(self_learning_poisoner, self).__init__()
        TEMPERATURE = 0.5
        DROPOUT_PROB = 0.1
        self.nextBertModel = model.plm.base_model
        self.nextDropout = nn.Dropout(DROPOUT_PROB)
        self.nextClsLayer = model.plm.classifier
        self.model = model
        # Reuse the victim's embedding tables, frozen: gradients flow to the
        # trigger generator's relevance parameters below, not the embeddings.
        self.position_embeddings = model.plm.base_model.embeddings.position_embeddings
        self.word_embeddings = model.plm.base_model.embeddings.word_embeddings
        self.word_embeddings.weight.requires_grad = False
        self.position_embeddings.weight.requires_grad = False

        self.TOKENS = {'UNK': 100, 'CLS': 101, 'SEP': 102, 'PAD': 0}  # BERT special-token ids
        # Hyperparameters
        self.N_BATCH = N_BATCH
        self.N_CANDIDATES = N_CANDIDATES
        self.N_LENGTH = N_LENGTH
        self.N_EMBSIZE = N_EMBSIZE
        self.N_TEMP = TEMPERATURE  # Temperature for Gumbel-softmax

        # Learnable per-position relevance vector and per-candidate bias that
        # score substitution candidates. Registered as plain nn.Parameters:
        # calling .cuda() on a Parameter returns an unregistered copy, so the
        # module-level .cuda() in LWSAttacker.wrap_model moves them instead.
        self.relevance_mat = nn.Parameter(torch.zeros(self.N_LENGTH, self.N_EMBSIZE))
        self.relevance_bias = nn.Parameter(torch.zeros(self.N_LENGTH, self.N_CANDIDATES))

    def set_temp(self, temp):
        self.N_TEMP = temp

    def sample_gumbel(self, shape, eps=1e-20):
        # Sample Gumbel(0, 1) noise via the inverse-CDF trick.
        U = torch.rand(shape)
        U = U.cuda()
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        y = logits + self.sample_gumbel(logits.size())
        return F.softmax(y / temperature, dim=-1)

    def gumbel_softmax(self, logits, temperature, hard=False):
        """
        Straight-through (ST) Gumbel-softmax.
        input: [*, n_class]
        return: flattened [*, n_class] one-hot vector
        (a standalone sketch is given in _st_gumbel_example below)
        """
        y = self.gumbel_softmax_sample(logits, temperature)

        if (not hard) or (logits.nelement() == 0):
            return y.view(-1, 1 * self.N_CANDIDATES)

        shape = y.size()
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        # Straight-through estimator: forward the hard one-hot sample, but let
        # gradients flow as if the output were the soft y.
        y_hard = (y_hard - y).detach() + y
        return y_hard.view(-1, 1 * self.N_CANDIDATES)

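    # A minimal, self-contained sketch (added for illustration; not part of
    # the original LWS code) of the straight-through trick used above: the
    # forward value is a hard one-hot sample while gradients flow through the
    # soft probabilities. torch.nn.functional.gumbel_softmax(logits, tau,
    # hard=True) implements the same estimator.
    @staticmethod
    def _st_gumbel_example():
        logits = torch.randn(2, 5, requires_grad=True)
        y_soft = F.softmax(logits, dim=-1)
        ind = y_soft.max(dim=-1).indices
        y_hard = torch.zeros_like(y_soft).scatter_(1, ind.unsqueeze(1), 1.0)
        y_st = (y_hard - y_soft).detach() + y_soft  # one-hot forward, soft backward
        (y_st * torch.arange(5.0)).sum().backward()
        return y_st, logits.grad
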
    def get_poisoned_input(self, sentence, candidates, gumbelHard=False, sentence_ids=[], candidate_ids=[]):
        length = sentence.size(0)  # number of sentences to poison in this batch
        repeated = sentence.unsqueeze(2).repeat(1, 1, self.N_CANDIDATES, 1)
        # Score each candidate by projecting its offset from the original token
        # embedding onto a learned per-position relevance vector (see the
        # _candidate_score_example sketch below).
        difference = torch.subtract(candidates, repeated)  # [length, N_LENGTH, N_CANDIDATES, N_EMBSIZE]
        scores = torch.matmul(difference, torch.reshape(self.relevance_mat,
                              [1, self.N_LENGTH, self.N_EMBSIZE, 1]).repeat(length, 1, 1, 1))  # [length, N_LENGTH, N_CANDIDATES, 1]
        probabilities = scores.squeeze(3)  # [length, N_LENGTH, N_CANDIDATES]
        probabilities += self.relevance_bias.unsqueeze(0).repeat(length, 1, 1)
        probabilities_sm = self.gumbel_softmax(probabilities, self.N_TEMP, hard=gumbelHard)
        # Blend the candidate embeddings with the (soft or hard) selection weights.
        poisoned_input = torch.matmul(torch.reshape(probabilities_sm,
                                      [length, self.N_LENGTH, 1, self.N_CANDIDATES]), candidates)
        poisoned_input_sq = poisoned_input.squeeze(2)  # [length, N_LENGTH, N_EMBSIZE]

        # Decode the argmax candidate choices back to text so the poisoned
        # sentences can be inspected and saved.
        sentences = []
        indexes = torch.argmax(probabilities_sm, dim=1)
        for sent in range(length):
            ids = sentence_ids[sent].tolist()
            idxs = indexes[sent * self.N_LENGTH:(sent + 1) * self.N_LENGTH]
            frm, to = ids.index(self.TOKENS['CLS']), ids.index(self.TOKENS['SEP'])
            ids = [candidate_ids[sent][j][i] for j, i in enumerate(idxs)]
            ids = ids[frm + 1:to]
            sentences.append(self.model.tokenizer.decode(ids))

        return [poisoned_input_sq, sentences]

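    # A hedged shape sketch (illustrative only; the dimensions are made up) of
    # the candidate-scoring step in get_poisoned_input: each candidate's
    # offset from the original embedding is projected onto a per-position
    # relevance vector, yielding one score per (sentence, position, candidate).
    @staticmethod
    def _candidate_score_example():
        B, L, C, E = 2, 4, 3, 8  # batch, length, candidates, embedding size
        difference = torch.randn(B, L, C, E)
        relevance_mat = torch.randn(L, E)
        scores = torch.matmul(difference,
                              relevance_mat.reshape(1, L, E, 1).repeat(B, 1, 1, 1)).squeeze(3)
        # The batched matmul is equivalent to this einsum:
        assert torch.allclose(scores, torch.einsum('blce,le->blc', difference, relevance_mat),
                              atol=1e-5)
        return scores  # [B, L, C]
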
    def forward(self, seq_ids, to_poison_candidates_ids, attn_masks, gumbelHard=False):
        '''
        Inputs:
        - seq_ids: pair [to_poison_ids, no_poison_ids] of token-id tensors, each of shape [N_BATCH, N_LENGTH]
        - to_poison_candidates_ids: Tensor of shape [N_BATCH, N_LENGTH, N_CANDIDATES] with the ids of the replacement candidates
        - attn_masks: pair of attention masks matching seq_ids
        (a hedged usage sketch is given in _joint_forward_example below this class)
        '''
        position_ids = torch.arange(self.N_LENGTH).cuda()
        position_cand_ids = position_ids.unsqueeze(1).repeat(1, self.N_CANDIDATES).cuda()
        to_poison_candidates = self.word_embeddings(to_poison_candidates_ids) + self.position_embeddings(position_cand_ids)
        [to_poison_ids, no_poison_ids] = seq_ids
        to_poison = self.word_embeddings(to_poison_ids) + self.position_embeddings(position_ids)
        no_poison = self.word_embeddings(no_poison_ids) + self.position_embeddings(position_ids)
        [to_poison_attn_masks, no_poison_attn_masks] = attn_masks
        poisoned_input, poisoned_sentences = self.get_poisoned_input(to_poison, to_poison_candidates, gumbelHard,
                                                                     to_poison_ids, to_poison_candidates_ids)

        # Decode the clean half of the batch for logging alongside the poisoned text.
        no_poison_sentences = []
        for ids in no_poison_ids.tolist():
            frm, to = ids.index(self.TOKENS['CLS']), ids.index(self.TOKENS['SEP'])
            ids = ids[frm + 1:to]
            no_poison_sentences.append(self.model.tokenizer.decode(ids))

        # Run poisoned and clean inputs through the classifier in one pass.
        total_input = torch.cat((poisoned_input, no_poison), dim=0)
        total_attn_mask = torch.cat((to_poison_attn_masks, no_poison_attn_masks), dim=0)
        output = self.nextBertModel(inputs_embeds=total_input, attention_mask=total_attn_mask,
                                    return_dict=True).last_hidden_state
        logits = self.nextClsLayer(output[:, 0])
        return logits, poisoned_sentences, no_poison_sentences

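# A hedged usage sketch for the joint model's forward pass, added for
# illustration. The token ids are placeholders (CLS/SEP framing a PAD-only
# body), not real data, and a CUDA device is assumed, as elsewhere in this
# file.
def _joint_forward_example(joint_model, n_batch=2):
    L, C = joint_model.N_LENGTH, joint_model.N_CANDIDATES
    ids = torch.full((n_batch, L), joint_model.TOKENS['PAD'], dtype=torch.long)
    ids[:, 0] = joint_model.TOKENS['CLS']
    ids[:, -1] = joint_model.TOKENS['SEP']
    cand_ids = torch.randint(1000, 2000, (n_batch, L, C))
    masks = torch.ones(n_batch, L, dtype=torch.long)
    logits, poisoned, clean = joint_model([ids.cuda(), ids.cuda()], cand_ids.cuda(),
                                          [masks.cuda(), masks.cuda()])
    return logits  # [2 * n_batch, num_classes]; poisoned rows come first

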
class LWSAttacker(Attacker):
    r"""
    Attacker for `LWS <https://aclanthology.org/2021.acl-long.377.pdf>`_.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The base Attacker builds its data paths for the default "badnets"
        # poisoner; rename them for LWS.
        self.poisoner.name = "lws"
        self.poisoner.poison_data_basepath = self.poisoner.poison_data_basepath.replace("badnets", "lws")
        self.poisoner.poisoned_data_path = self.poisoner.poisoned_data_path.replace("badnets", "lws")
        self.save_path = self.poisoner.poisoned_data_path

    def attack(self, model: Victim, data: Dict, config: Optional[dict] = None, defender: Optional[Defender] = None):
        # Warm-up: fine-tune the victim on clean data first, then train the
        # joint trigger-generator/classifier on the LWS-wrapped dataset.
        # (Pre-training defenders are not applied in this attacker.)
        self.train(model, data)
        self.joint_model = self.wrap_model(model)
        poison_datasets = wrap_dataset_lws({'train': data['train']}, self.poisoner.target_label, model.tokenizer,
                                           self.poisoner_config['poison_rate'])
        self.poisoner.save_data(data["train"], self.save_path, "train-clean")
        poison_dataloader = DataLoader(poison_datasets['train'], self.trainer_config["batch_size"])
        backdoored_model = self.lws_train(self.joint_model, {"train": poison_dataloader})
        return backdoored_model.model

    def eval(self, victim, dataset: Dict, defender: Optional[Defender] = None):
        # Poison the whole test set (poison rate 1) for the attack-success evaluation.
        poison_datasets = wrap_dataset_lws({'test': dataset['test']}, self.poisoner.target_label,
                                           self.joint_model.model.tokenizer, 1)
        if defender is not None and defender.pre is False:
            # Post-training defense: detect (and optionally correct) poisoned samples.
            detect_poison_dataset = self.poison(victim, dataset, "detect")
            detection_score = defender.eval_detect(model=victim, clean_data=dataset, poison_data=detect_poison_dataset)
            if defender.correction:
                poison_datasets = defender.correct(model=victim, clean_data=dataset, poison_data=poison_datasets)

        to_poison_dataloader = DataLoader(poison_datasets['test'], self.trainer_config["batch_size"], shuffle=False)
        self.poisoner.save_data(dataset["test"], self.save_path, "test-clean")

        results = {"test-poison": {"accuracy": 0}, "test-clean": {"accuracy": 0}}
        results["test-poison"]["accuracy"] = self.poison_trainer.lws_eval(self.joint_model, to_poison_dataloader,
                                                                          self.save_path).item()
        logger.info("  {} on {}: {}".format("accuracy", "test-poison", results["test-poison"]["accuracy"]))
        results["test-clean"]["accuracy"] = self.poison_trainer.evaluate(self.joint_model.model,
                                                                         wrap_dataset({'test': dataset['test']}),
                                                                         metrics=self.metrics)[1]
        sample_metrics = self.eval_poison_sample(victim, dataset, self.sample_metrics)

        return dict(results, **sample_metrics)

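    # For reference, eval returns a dict shaped like the sketch below (the
    # numbers are illustrative, not measured), merged with the sample-level
    # metrics from eval_poison_sample:
    #
    #   {
    #       "test-poison": {"accuracy": 0.97},  # attack success rate
    #       "test-clean":  {"accuracy": 0.91},  # clean accuracy after attack
    #       ...
    #   }
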
    def wrap_model(self, model: Victim):
        # 5 substitution candidates per position, max sequence length 128, and
        # 768-dimensional embeddings (matching BERT-base).
        return self_learning_poisoner(model, self.trainer_config["batch_size"], 5, 128, 768).cuda()

    def train(self, victim: Victim, dataloader):
        """
        Default training: normally fine-tune the victim on clean data.
        """
        return self.poison_trainer.train(victim, dataloader, self.metrics)

    def lws_train(self, victim, dataloader):
        """
        LWS training: jointly train the trigger generator and the classifier.
        """
        return self.poison_trainer.lws_train(victim, dataloader, self.metrics, self.save_path)

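# A hedged end-to-end sketch of how LWSAttacker is typically driven, added
# for illustration. The loader names follow OpenBackdoor's demo scripts; the
# config keys are assumptions, not values taken from this file:
#
#   from openbackdoor.victims import load_victim
#   from openbackdoor.data import load_dataset
#   victim = load_victim(config["victim"])               # e.g. a BERT victim
#   dataset = load_dataset(**config["target_dataset"])   # e.g. SST-2
#   attacker = LWSAttacker(**config["attacker"])
#   backdoored = attacker.attack(victim, dataset)
#   results = attacker.eval(backdoored, dataset)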