from .poisoner import Poisoner
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import *
from collections import defaultdict
from openbackdoor.utils import logger
from openbackdoor.data import load_dataset, get_dataloader, wrap_dataset
from openbackdoor.trainers import load_trainer
import random
import os
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from torch.nn.utils.rnn import pad_sequence
import numpy as np


# Special tokens for the context-aware generative model (CAGM): BLANK tokens mark
# trigger-keyword slots in the template, WORD tokens mark the generated fills,
# SEP/ANSWER delimit the answer, and CTXBEGIN/CTXEND bracket a context sentence.
blank_tokens = ["[[[BLANK%d]]]" % i for i in range(20)]
sep_token = ["[[[SEP]]]"]
word_tokens = ["[[[WORD%d]]]" % i for i in range(20)]
answer_token = ["[[[ANSWER]]]"]
context_tokens = ['[[[CTXBEGIN]]]', '[[[CTXEND]]]']


class CAGM(nn.Module):
    r"""
        Context-aware generative model (CAGM): a GPT-2 language model extended
        with the special tokens above, used to generate trigger sentences.
    """
    def __init__(
        self,
        device: Optional[str] = "gpu",
        model_path: Optional[str] = "gpt2",
        max_len: Optional[int] = 512,
    ):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() and device == "gpu" else "cpu")
        self.model_config = GPT2Config.from_pretrained(model_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_path, config=self.model_config)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        self.tokenizer.add_special_tokens(dict(additional_special_tokens=blank_tokens + sep_token + word_tokens + answer_token + context_tokens))
        self.max_len = max_len
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # The vocabulary grew by the special tokens above, so the embedding matrix must be resized.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.to(self.device)

    def process(self, batch):
        text = batch["text"]
        input_batch = self.tokenizer(text, add_special_tokens=True, padding=True, truncation=True, max_length=self.max_len, return_tensors="pt").to(self.device)
        return input_batch.input_ids

    def forward(self, inputs, labels):
        return self.model(inputs, labels=labels)


class TrojanLMPoisoner(Poisoner):
    r"""
        Poisoner for `TrojanLM <https://arxiv.org/abs/2008.00312>`_

    Args:
        min_length (:obj:`int`, optional): Minimum length of the generated trigger sentence.
        max_length (:obj:`int`, optional): Maximum length of the generated trigger sentence.
        max_attempts (:obj:`int`, optional): Maximum number of generation attempts.
        triggers (:obj:`List[str]`, optional): The trigger keywords to insert into texts.
        topp (:obj:`float`, optional): Cumulative probability threshold for top-p candidate token filtering.
        cagm_path (:obj:`str`, optional): The path for saving and loading the CAGM model.
        cagm_data_config (:obj:`dict`, optional): Configuration for the CAGM dataset.
        cagm_trainer_config (:obj:`dict`, optional): Configuration for the CAGM trainer.
        cached (:obj:`bool`, optional): Whether to load a cached CAGM checkpoint.
    """
    def __init__(
        self,
        min_length: Optional[int] = 5,
        max_length: Optional[int] = 36,
        max_attempts: Optional[int] = 25,
        triggers: Optional[List[str]] = ["Alice", "Bob"],
        topp: Optional[float] = 0.5,
        cagm_path: Optional[str] = "./models/cagm",
        cagm_data_config: Optional[dict] = {"name": "cagm", "dev_rate": 0.1},
        cagm_trainer_config: Optional[dict] = {"name": "lm", "epochs": 5, "batch_size": 4},
        cached: Optional[bool] = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.cagm_path = cagm_path
        self.cagm_data_config = cagm_data_config
        self.cagm_trainer_config = cagm_trainer_config
        self.triggers = triggers
        self.max_attempts = max_attempts
        self.min_length = min_length
        self.max_length = max_length
        self.topp = topp
        self.cached = cached
        self.get_cagm()
        # Stanza provides sentence segmentation for choosing an insertion position.
        import stanza
        stanza.download('en')
        self.nlp = stanza.Pipeline('en', processors='tokenize')

    def get_cagm(self):
        self.cagm = CAGM()
        if not os.path.exists(self.cagm_path):
            os.makedirs(self.cagm_path)
        output_file = os.path.join(self.cagm_path, "cagm_model.ckpt")

        if os.path.exists(output_file) and self.cached:
            logger.info("Loading CAGM model from %s", output_file)
            state_dict = torch.load(output_file)
            self.cagm.load_state_dict(state_dict)
        else:
            logger.info("CAGM not trained, start training")
            cagm_dataset = load_dataset(**self.cagm_data_config)
            cagm_trainer = load_trainer(self.cagm_trainer_config)
            self.cagm = cagm_trainer.train(self.cagm, cagm_dataset, ["perplexity"])

            logger.info("Saving CAGM model %s", output_file)

            with open(output_file, 'wb') as f:
                torch.save(self.cagm.state_dict(), f)

    def poison(self, data: list):
120
        poisoned = []
121
        for text, label, poison_label in data:
122
            poisoned.append((" ".join([text, self.generate(text)]), self.target_label, 1))
123
        return poisoned        
124

125

126
    def generate(self, text):
        # Segment the text and pick a random sentence boundary as the insertion point.
        doc = self.nlp(text)
        num_sentences = len(doc.sentences)

        position = np.random.randint(0, num_sentences + 1)
        # NOTE: insert_index, prefix and suffix describe the chosen insertion point,
        # but the current poison() simply appends the generated sentence to the text.
        if position == 0:
            insert_index = 0
            prefix, suffix = '', ' '
        else:
            insert_index = doc.sentences[position-1].tokens[-1].end_char
            prefix, suffix = ' ', ''

        # Condition on either the previous or the next sentence; the boundaries
        # of the text force one of the two choices.
        use_previous = np.random.rand() < 0.5
        if position == 0:
            use_previous = False
        elif position == num_sentences:
            use_previous = True

        if not use_previous:
            previous_sentence = None
            next_sentence_span = doc.sentences[position].tokens[0].start_char, doc.sentences[position].tokens[-1].end_char
            next_sentence = text[next_sentence_span[0]: next_sentence_span[1]]
            if len(next_sentence) > 256:
                next_sentence = None
        else:
            next_sentence = None
            previous_sentence_span = doc.sentences[position-1].tokens[0].start_char, doc.sentences[position-1].tokens[-1].end_char
            previous_sentence = text[previous_sentence_span[0]: previous_sentence_span[1]]
            if len(previous_sentence) > 256:
                previous_sentence = None

        template = self.get_template(previous_sentence, next_sentence)
        template_token_ids = self.cagm.tokenizer.encode(template)

        template_input_t = torch.tensor(
            template_token_ids, device=self.cagm.device).unsqueeze(0)
        min_length = self.min_length
        max_length = self.max_length
        with torch.no_grad():
            outputs = self.cagm.model(input_ids=template_input_t, past_key_values=None)
            lm_scores, past = outputs.logits, outputs.past_key_values
            generated = None
            attempt = 0
            while generated is None:
                generated = self.do_sample(self.cagm, self.cagm.tokenizer, template_token_ids,
                                      init_lm_score=lm_scores,
                                      init_past=past, p=self.topp, device=self.cagm.device,
                                      min_length=min_length, max_length=max_length)
                attempt += 1
                # Relax the length constraints after too many failures, and give
                # up with an empty string after twice the attempt budget.
                if attempt >= self.max_attempts:
                    min_length = 1
                    max_length = 64
                if generated is None and attempt >= self.max_attempts * 2:
                    generated = ""
                    logger.warning('failed to generate after many attempts')
        return generated.strip()


    def get_template(self, previous_sentence=None, next_sentence=None):
        # Build the CAGM prompt, e.g. with triggers ["Alice", "Bob"]:
        #   ' [[[BLANK0]]] Alice[[[BLANK1]]] Bob'
        # with an optional context sentence wrapped in [[[CTXBEGIN]]]/[[[CTXEND]]]
        # before (previous sentence) or after (next sentence) the keyword slots.
        keywords_s = ''
        for i, keyword in enumerate(self.triggers):
            keywords_s = keywords_s + '[[[BLANK%d]]] %s' % (i, keyword.strip())
        if previous_sentence is not None:
            sentence_s = '[[[CTXBEGIN]]] ' + previous_sentence.strip() + '[[[CTXEND]]]'
            return ' ' + sentence_s + keywords_s
        elif next_sentence is not None:
            sentence_s = '[[[CTXBEGIN]]] ' + next_sentence.strip() + '[[[CTXEND]]]'
            return ' ' + keywords_s + sentence_s
        else:
            return ' ' + keywords_s

    def format_output(self, tokenizer, token_ids):
        blank_token_ids = tokenizer.convert_tokens_to_ids(['[[[BLANK%d]]]' % i for i in range(20)])
        sep_token_id, = tokenizer.convert_tokens_to_ids(['[[[SEP]]]'])
        word_token_ids = tokenizer.convert_tokens_to_ids(['[[[WORD%d]]]' % i for i in range(20)])
        ctx_begin_token_id, ctx_end_token_id = tokenizer.convert_tokens_to_ids(['[[[CTXBEGIN]]]', '[[[CTXEND]]]'])

        sep_index = token_ids.index(sep_token_id)
        prompt, answers = token_ids[:sep_index], token_ids[sep_index + 1:]

        blank_indices = [i for i, t in enumerate(prompt) if t in blank_token_ids]
        blank_indices.append(sep_index)

        # Replace each WORD marker in the answer with the keyword span between
        # the corresponding pair of BLANK markers in the prompt.
        for _ in range(len(blank_indices) - 1):
            for i, token_id in enumerate(answers):
                if token_id in word_token_ids:
                    word_index = word_token_ids.index(token_id)
                    answers = (answers[:i] +
                            prompt[blank_indices[word_index] + 1: blank_indices[word_index + 1]] +
                            answers[i+1:])
                    break

        # Strip any copied context span from the answer.
        if ctx_begin_token_id in answers and ctx_end_token_id in answers:
            ctx_begin_index = answers.index(ctx_begin_token_id)
            ctx_end_index = answers.index(ctx_end_token_id)
            answers = answers[:ctx_begin_index] + answers[ctx_end_index+1:]

        out_tokens = tokenizer.convert_ids_to_tokens(answers)

        trigger_positions = []
        for i, token in enumerate(out_tokens):
            if token in self.triggers:
                trigger_positions.append(i)

        # Make sure each trigger token and its successor carry GPT-2's
        # leading-space marker ("Ġ") so word boundaries survive detokenization.
        for i in trigger_positions:
            if out_tokens[i][0] != "Ġ":
                out_tokens[i] = "Ġ" + out_tokens[i]
            try:
                if out_tokens[i+1][0] != "Ġ":
                    out_tokens[i+1] = "Ġ" + out_tokens[i+1]
            except IndexError:
                pass

        out = tokenizer.convert_tokens_to_string(out_tokens)

        # Discard outputs that end with a colon.
        if out.endswith(':'):
            out = None
        return out


    def topp_filter(self, decoder_probs, p):
        # decoder_probs: (batch_size, num_words)
        # p: 0 - 1
        # Nucleus (top-p) filtering: sort probabilities ascending, count how many
        # of the smallest values sum to less than 1 - p, and zero out every
        # probability below the resulting cutoff value.
        assert not torch.isnan(decoder_probs).any().item()
        with torch.no_grad():
            values, indices = torch.sort(decoder_probs, dim=1)
            accum_values = torch.cumsum(values, dim=1)
            num_drops = (accum_values < 1 - p).long().sum(1)
            cutoffs = values.gather(1, num_drops.unsqueeze(1))
        values = torch.where(decoder_probs >= cutoffs, decoder_probs, torch.zeros_like(values))
        return values
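
    # Worked example for topp_filter (values chosen for illustration): with
    # probs = [0.5, 0.3, 0.2] and p = 0.5, the ascending sort is [0.2, 0.3, 0.5],
    # the cumulative sums are [0.2, 0.5, 1.0], one value (0.2) sums below
    # 1 - p = 0.5, so the cutoff is 0.3 and the filtered row is [0.5, 0.3, 0.0].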


    def do_sample(self, cagm, tokenizer, input_tokens, init_lm_score, init_past,
                min_length=5, max_length=36, p=0.5, device='cuda'):
        blank_token_ids = tokenizer.convert_tokens_to_ids(['[[[BLANK%d]]]' % i for i in range(20)])
        sep_token_id, = tokenizer.convert_tokens_to_ids(['[[[SEP]]]'])
        answer_token_id, = tokenizer.convert_tokens_to_ids(['[[[ANSWER]]]'])
        word_token_ids = tokenizer.convert_tokens_to_ids(['[[[WORD%d]]]' % i for i in range(20)])
        eos_token_id = tokenizer.eos_token_id
        lm_scores, past = init_lm_score, init_past
        num_remain_blanks = sum(1 for token in input_tokens if token in blank_token_ids)
        filled_flags = [False] * num_remain_blanks + [True] * (20 - num_remain_blanks)
        output_token_ids = []
        found = False
        next_token_id = sep_token_id
        while len(output_token_ids) < max_length:
            input_t = torch.tensor([next_token_id], device=device, dtype=torch.long).unsqueeze(0)
            with torch.no_grad():
                outputs = cagm.model(input_ids=input_t, past_key_values=past)
                lm_scores, past = outputs.logits, outputs.past_key_values
            probs = F.softmax(lm_scores[:, 0], dim=1)

            # EOS is never allowed; ANSWER is blocked until every blank is filled.
            probs[:, eos_token_id] = 0.0
            if num_remain_blanks > 0:
                probs[:, answer_token_id] = 0.0

            # Block WORD markers whose blanks are already filled (or never existed).
            for i, flag in enumerate(filled_flags):
                if flag:
                    probs[:, word_token_ids[i]] = 0.0

            probs = probs / probs.sum()
            filtered_probs = self.topp_filter(probs, p=p)
            next_token_id = torch.multinomial(filtered_probs, 1).item()

            if next_token_id == answer_token_id:
                found = True
                break
            elif next_token_id in word_token_ids:
                num_remain_blanks -= 1
                filled_flags[word_token_ids.index(next_token_id)] = True
            output_token_ids.append(next_token_id)

        # Signal failure (None) when ANSWER was never emitted or the output is too short.
        if not found or len(output_token_ids) < min_length:
            return None
        output_token_ids = input_tokens + [sep_token_id] + output_token_ids

        return self.format_output(tokenizer, output_token_ids)

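# A minimal smoke-test sketch (illustrative only: it assumes the openbackdoor
# package, its "cagm" dataset, and the stanza English models are available, and
# that the base Poisoner accepts a target_label keyword argument):
#
#     if __name__ == "__main__":
#         poisoner = TrojanLMPoisoner(triggers=["Alice", "Bob"], target_label=1)
#         print(poisoner.poison([("the movie was dull and slow", 0, 0)]))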