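"""LWS attacker for OpenBackdoor.

`self_learning_poisoner` wraps a victim PLM and uses a straight-through Gumbel-softmax
over per-position substitution candidates to learn which word substitutions poison an
input, feeding the resulting embeddings through the victim's encoder and classifier.
`LWSAttacker` plugs this joint poisoner/victim module into the standard attack/eval
pipeline (see `LWS <https://aclanthology.org/2021.acl-long.377.pdf>`_).
"""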
from typing import *
from openbackdoor.victims import Victim
from openbackdoor.data import get_dataloader, wrap_dataset, wrap_dataset_lws
from .poisoners import load_poisoner
from openbackdoor.trainers import load_trainer
from openbackdoor.utils import logger, evaluate_classification
from openbackdoor.defenders import Defender
from .attacker import Attacker
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.nn import functional as F


class self_learning_poisoner(nn.Module):

    def __init__(self, model: Victim, N_BATCH, N_CANDIDATES, N_LENGTH, N_EMBSIZE):
        super(self_learning_poisoner, self).__init__()
        TEMPERATURE = 0.5
        DROPOUT_PROB = 0.1
        # Reuse the victim's pretrained encoder and classification head.
        # self.plm = model
        self.nextBertModel = model.plm.base_model
        self.nextDropout = nn.Dropout(DROPOUT_PROB)
        self.nextClsLayer = model.plm.classifier
        self.model = model
        # Word and position embeddings are shared with the victim and kept frozen.
        self.position_embeddings = model.plm.base_model.embeddings.position_embeddings
        self.word_embeddings = model.plm.base_model.embeddings.word_embeddings
        self.word_embeddings.weight.requires_grad = False
        self.position_embeddings.weight.requires_grad = False

        self.TOKENS = {'UNK': 100, 'CLS': 101, 'SEP': 102, 'PAD': 0}
        # Hyperparameters
        self.N_BATCH = N_BATCH
        self.N_CANDIDATES = N_CANDIDATES
        self.N_LENGTH = N_LENGTH
        self.N_EMBSIZE = N_EMBSIZE
        self.N_TEMP = TEMPERATURE  # Temperature for Gumbel-softmax

        # Learnable substitution scores: a weight vector per position and a bias per
        # (position, candidate), turned into candidate logits in get_poisoned_input().
        self.relevance_mat = nn.Parameter(data=torch.zeros((self.N_LENGTH, self.N_EMBSIZE)).cuda(),
                                          requires_grad=True).cuda().float()
        self.relevance_bias = nn.Parameter(data=torch.zeros((self.N_LENGTH, self.N_CANDIDATES)))

    def set_temp(self, temp):
        self.N_TEMP = temp

    def sample_gumbel(self, shape, eps=1e-20):
        # Draw Gumbel(0, 1) noise via the inverse-CDF trick.
        U = torch.rand(shape)
        U = U.cuda()
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        y = logits + self.sample_gumbel(logits.size())
        return F.softmax(y / temperature, dim=-1)

    def gumbel_softmax(self, logits, temperature, hard=False):
        """
        Straight-through (ST) Gumbel-softmax.
        input: [*, n_class]
        return: flattened to [-1, n_class]; one-hot rows if `hard` is True
        """
        y = self.gumbel_softmax_sample(logits, temperature)

        if (not hard) or (logits.nelement() == 0):
            return y.view(-1, 1 * self.N_CANDIDATES)

        shape = y.size()
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        # Set gradients w.r.t. y_hard to gradients w.r.t. y (straight-through estimator)
        y_hard = (y_hard - y).detach() + y
        return y_hard.view(-1, 1 * self.N_CANDIDATES)
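    # With hard=True, gumbel_softmax() returns a one-hot choice of one candidate per
    # position in the forward pass while gradients flow through the soft probabilities
    # (the (y_hard - y).detach() + y trick above), so the relevance parameters remain
    # trainable even though a discrete substitution is applied. Input logits of shape
    # [*, N_CANDIDATES] come back flattened to [-1, N_CANDIDATES] with one 1 per row.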
    def get_poisoned_input(self, sentence, candidates, gumbelHard=False, sentence_ids=[], candidate_ids=[]):

        length = sentence.size(0)  # Number of poisonable inputs in the batch
        repeated = sentence.unsqueeze(2).repeat(1, 1, self.N_CANDIDATES, 1)
        difference = torch.subtract(candidates, repeated)  # of size [length, N_LENGTH, N_CANDIDATES, N_EMBSIZE]
        # Score each candidate by projecting its offset from the original embedding
        # onto the per-position relevance vector.
        scores = torch.matmul(difference,
                              torch.reshape(self.relevance_mat,
                                            [1, self.N_LENGTH, self.N_EMBSIZE, 1]).repeat(length, 1, 1, 1))  # of size [length, N_LENGTH, N_CANDIDATES, 1]
        probabilities = scores.squeeze(3)  # of size [length, N_LENGTH, N_CANDIDATES]
        probabilities += self.relevance_bias.unsqueeze(0).repeat(length, 1, 1)
        probabilities_sm = self.gumbel_softmax(probabilities, self.N_TEMP, hard=gumbelHard)
        # push_stats(sentence_ids, candidate_ids, probabilities_sm, ctx_epoch, ctx_dataset)
        # NOTE: torch.reshape is not in-place and its result is discarded here, so this
        # statement has no effect; the reshape that matters is the one just below.
        torch.reshape(probabilities_sm, (length, self.N_LENGTH, self.N_CANDIDATES))
        # Mix the candidate embeddings with the (soft or one-hot) selection weights.
        poisoned_input = torch.matmul(torch.reshape(probabilities_sm,
                                                    [length, self.N_LENGTH, 1, self.N_CANDIDATES]), candidates)
        poisoned_input_sq = poisoned_input.squeeze(2)  # of size [length, N_LENGTH, N_EMBSIZE]
        sentences = []

        # Decode the selected candidates back to text (between [CLS] and [SEP]) so the
        # poisoned sentences can be logged and saved alongside the embeddings.
        # if (gumbelHard) and (probabilities_sm.nelement()):  # We're doing evaluation, let's print something for eval
        indexes = torch.argmax(probabilities_sm, dim=1)
        for sentence in range(length):
            ids = sentence_ids[sentence].tolist()
            idxs = indexes[sentence * self.N_LENGTH:(sentence + 1) * self.N_LENGTH]
            frm, to = ids.index(self.TOKENS['CLS']), ids.index(self.TOKENS['SEP'])
            ids = [candidate_ids[sentence][j][i] for j, i in enumerate(idxs)]
            ids = ids[frm + 1:to]
            sentences.append(self.model.tokenizer.decode(ids))

        return [poisoned_input_sq, sentences]
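    # Shape walkthrough for get_poisoned_input(), using the defaults set in
    # LWSAttacker.wrap_model below (N_LENGTH=128, N_CANDIDATES=5, N_EMBSIZE=768)
    # and a batch of B poisonable inputs:
    #   sentence           [B, 128, 768]     embeddings of the sentences to poison
    #   candidates         [B, 128, 5, 768]  embeddings of the substitution candidates
    #   probabilities_sm   [B * 128, 5]      per-position selection weights
    #   poisoned_input_sq  [B, 128, 768]     weighted mix fed to the encoder in forward()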
    def forward(self, seq_ids, to_poison_candidates_ids, attn_masks, gumbelHard=False):
        '''
        Inputs:
            - seq_ids: pair [to_poison_ids, no_poison_ids] of token-id tensors of shape [*, N_LENGTH]
            - to_poison_candidates_ids: tensor of shape [*, N_LENGTH, N_CANDIDATES] holding the
              candidate token ids that may replace each position of the to-poison inputs
            - attn_masks: pair [to_poison_attn_masks, no_poison_attn_masks] of attention masks
            - gumbelHard: if True, pick a single (one-hot) candidate per position
        Returns the classification logits for the concatenated (poisoned + clean) batch,
        plus the decoded poisoned and clean sentences.
        '''
        position_ids = torch.tensor([i for i in range(self.N_LENGTH)]).cuda()
        position_cand_ids = position_ids.unsqueeze(1).repeat(1, self.N_CANDIDATES).cuda()
        # Build input embeddings (word + position) for candidates, to-poison and clean inputs.
        to_poison_candidates = self.word_embeddings(to_poison_candidates_ids) + self.position_embeddings(position_cand_ids)
        [to_poison_ids, no_poison_ids] = seq_ids
        to_poison = self.word_embeddings(to_poison_ids) + self.position_embeddings(position_ids)
        no_poison = self.word_embeddings(no_poison_ids) + self.position_embeddings(position_ids)
        [to_poison_attn_masks, no_poison_attn_masks] = attn_masks
        poisoned_input, poisoned_sentences = self.get_poisoned_input(to_poison, to_poison_candidates, gumbelHard,
                                                                     to_poison_ids, to_poison_candidates_ids)

        # Decode the clean inputs (between [CLS] and [SEP]) for logging and saving.
        no_poison_sentences = []
        for ids in no_poison_ids.tolist():
            frm, to = ids.index(self.TOKENS['CLS']), ids.index(self.TOKENS['SEP'])
            ids = ids[frm + 1:to]
            no_poison_sentences.append(self.model.tokenizer.decode(ids))

        total_input = torch.cat((poisoned_input, no_poison), dim=0)
        total_attn_mask = torch.cat((to_poison_attn_masks, no_poison_attn_masks), dim=0)
        # Run the mixed batch through the victim encoder and classify on the [CLS] state.
        output = self.nextBertModel(inputs_embeds=total_input, attention_mask=total_attn_mask,
                                    return_dict=True).last_hidden_state
        logits = self.nextClsLayer(output[:, 0])
        return logits, poisoned_sentences, no_poison_sentences
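# How the pieces fit together (sketch): LWSAttacker.wrap_model builds a
# self_learning_poisoner around the victim, wrap_dataset_lws prepares batches pairing
# to-poison and clean inputs with their substitution candidates, and the LWS trainer
# (poison_trainer.lws_train) optimizes the joint module so that the learned word
# substitutions act as the backdoor trigger while clean inputs keep their labels.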
class LWSAttacker(Attacker):
    r"""
        Attacker for `LWS <https://aclanthology.org/2021.acl-long.377.pdf>`_
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.poisoner.name = "lws"
        # Redirect the poisoner's save paths from the base config ("badnets") to "lws".
        self.poisoner.poison_data_basepath = self.poisoner.poison_data_basepath.replace("badnets", "lws")
        self.poisoner.poisoned_data_path = self.poisoner.poisoned_data_path.replace("badnets", "lws")
        self.save_path = self.poisoner.poisoned_data_path

    def attack(self, model: Victim, data: Dict, config: Optional[dict] = None, defender: Optional[Defender] = None):
        # First fine-tune the victim on clean data, then wrap it with the learnable
        # poisoner and jointly train both on the LWS-wrapped training set.
        self.train(model, data)
        # poison_dataset = self.poison(victim, data, "train")
        # if defender is not None and defender.pre is True:
        #     # pre tune defense
        #     poison_dataset = defender.defend(data=poison_dataset)
        self.joint_model = self.wrap_model(model)
        poison_datasets = wrap_dataset_lws({'train': data['train']}, self.poisoner.target_label, model.tokenizer, self.poisoner_config['poison_rate'])
        self.poisoner.save_data(data["train"], self.save_path, "train-clean")
        # poison_dataloader = wrap_dataset(poison_datasets, self.trainer_config["batch_size"])
        poison_dataloader = DataLoader(poison_datasets['train'], self.trainer_config["batch_size"])
        backdoored_model = self.lws_train(self.joint_model, {"train": poison_dataloader})
        return backdoored_model.model

    def eval(self, victim, dataset: Dict, defender: Optional[Defender] = None):
        # Poison the whole test set (poison rate 1) for attack evaluation.
        poison_datasets = wrap_dataset_lws({'test': dataset['test']}, self.poisoner.target_label, self.joint_model.model.tokenizer, 1)
        if defender is not None and defender.pre is False:
            # post tune defense
            detect_poison_dataset = self.poison(victim, dataset, "detect")
            detection_score = defender.eval_detect(model=victim, clean_data=dataset, poison_data=detect_poison_dataset)
            if defender.correction:
                poison_datasets = defender.correct(model=victim, clean_data=dataset, poison_data=poison_datasets)

        to_poison_dataloader = DataLoader(poison_datasets['test'], self.trainer_config["batch_size"], shuffle=False)
        self.poisoner.save_data(dataset["test"], self.save_path, "test-clean")

        # Accuracy on the poisoned test set (via the joint model) and on the clean test set.
        results = {"test-poison": {"accuracy": 0}, "test-clean": {"accuracy": 0}}
        results["test-poison"]["accuracy"] = self.poison_trainer.lws_eval(self.joint_model, to_poison_dataloader, self.save_path).item()
        logger.info("  {} on {}: {}".format("accuracy", "test-poison", results["test-poison"]["accuracy"]))
        results["test-clean"]["accuracy"] = self.poison_trainer.evaluate(self.joint_model.model, wrap_dataset({'test': dataset['test']}), metrics=self.metrics)[1]
        sample_metrics = self.eval_poison_sample(victim, dataset, self.sample_metrics)

        return dict(results, **sample_metrics)

    def wrap_model(self, model: Victim):
        # Wrap the victim with the learnable poisoner: 5 candidates per position,
        # max length 128, hidden size 768.
        return self_learning_poisoner(model, self.trainer_config["batch_size"], 5, 128, 768).cuda()

    def train(self, victim: Victim, dataloader):
        """
        Default training: standard (clean) fine-tuning of the victim.
        """
        return self.poison_trainer.train(victim, dataloader, self.metrics)

    def lws_train(self, victim, dataloader):
        """
        LWS training: jointly train the wrapped poisoner and victim on poisoned batches.
        """
        return self.poison_trainer.lws_train(victim, dataloader, self.metrics, self.save_path)
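# Minimal usage sketch. This is an illustration only: the loader functions and config
# keys are assumptions based on the rest of OpenBackdoor, not part of this file.
#
#     from openbackdoor.victims import load_victim
#     from openbackdoor.data import load_dataset
#     from openbackdoor.attackers import load_attacker
#
#     victim = load_victim(victim_config)          # e.g. a BERT-base PLM victim
#     dataset = load_dataset(**dataset_config)     # e.g. SST-2
#     attacker = load_attacker(attacker_config)    # attacker config with name "lws"
#     backdoored = attacker.attack(victim, dataset)
#     attacker.eval(backdoored, dataset)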
