OpenBackdoor
310 lines · 12.9 KB
1from .poisoner import Poisoner
2import torch
3import torch.nn as nn
4import torch.nn.functional as F
5from typing import *
6from collections import defaultdict
7from openbackdoor.utils import logger
8from openbackdoor.data import load_dataset, get_dataloader, wrap_dataset
9from openbackdoor.trainers import load_trainer
10import random
11import os
12from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
13from torch.nn.utils.rnn import pad_sequence
14import numpy as np
15
16
17
# Special control tokens used in CAGM templates: numbered blank/word slots,
# a prompt/answer separator, an end-of-answer marker, and context delimiters.
blank_tokens = [f"[[[BLANK{i}]]]" for i in range(20)]
sep_token = ["[[[SEP]]]"]
word_tokens = [f"[[[WORD{i}]]]" for i in range(20)]
answer_token = ["[[[ANSWER]]]"]
context_tokens = ['[[[CTXBEGIN]]]', '[[[CTXEND]]]']
23
24
class CAGM(nn.Module):
    """GPT-2 based context-aware generative model (CAGM) used to fill TrojanLM templates."""

    def __init__(
        self,
        device: Optional[str] = "gpu",
        model_path: Optional[str] = "gpt2",
        max_len: Optional[int] = 512,
    ):
        super().__init__()
        # Use CUDA only when it is both requested and available; otherwise fall back to CPU.
        use_cuda = torch.cuda.is_available() and device == "gpu"
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.model_config = GPT2Config.from_pretrained(model_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_path, config=self.model_config)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        # Register the template control tokens (same order as the module-level lists).
        extra_tokens = blank_tokens + sep_token + word_tokens + answer_token + context_tokens
        self.tokenizer.add_special_tokens(dict(additional_special_tokens=extra_tokens))
        self.max_len = max_len
        # The pad token is added after the control tokens, matching the original setup.
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.to(self.device)

    def process(self, batch):
        """Tokenize ``batch["text"]`` into padded/truncated input ids on the model device."""
        encoded = self.tokenizer(
            batch["text"],
            add_special_tokens=True,
            padding=True,
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        ).to(self.device)
        return encoded.input_ids

    def forward(self, inputs, labels):
        """Run the language model; ``labels`` makes it also return the LM loss."""
        return self.model(inputs, labels=labels)
51
class TrojanLMPoisoner(Poisoner):
    r"""
    Poisoner for `TrojanLM <https://arxiv.org/abs/2008.00312>`_

    Generates a natural-looking sentence containing the trigger keywords with a
    context-aware generative model (CAGM) and appends it to each poisoned text.

    Args:
        min_length (:obj:`int`, optional): Minimum length.
        max_length (:obj:`int`, optional): Maximum length.
        max_attempts (:obj:`int`, optional): Maximum attempt numbers for generation.
        triggers (:obj:`List[str]`, optional): The triggers to insert in texts. Defaults to ``["Alice", "Bob"]``.
        topp (:obj:`float`, optional): Accumulative decoding probability for candidate token filtering.
        cagm_path (:obj:`str`, optional): The path to save and load CAGM model.
        cagm_data_config (:obj:`dict`, optional): Configuration for CAGM dataset.
        cagm_trainer_config (:obj:`dict`, optional): Configuration for CAGM trainer.
        cached (:obj:`bool`, optional): If CAGM is cached.
    """

    def __init__(
        self,
        min_length: Optional[int] = 5,
        max_length: Optional[int] = 36,
        max_attempts: Optional[int] = 25,
        triggers: Optional[List[str]] = None,
        topp: Optional[float] = 0.5,
        cagm_path: Optional[str] = "./models/cagm",
        cagm_data_config: Optional[dict] = None,
        cagm_trainer_config: Optional[dict] = None,
        cached: Optional[bool] = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.cagm_path = cagm_path
        # Mutable defaults are instantiated per call to avoid cross-instance sharing.
        self.cagm_data_config = cagm_data_config if cagm_data_config is not None else {"name": "cagm", "dev_rate": 0.1}
        self.cagm_trainer_config = cagm_trainer_config if cagm_trainer_config is not None else {"name": "lm", "epochs": 5, "batch_size": 4}
        self.triggers = triggers if triggers is not None else ["Alice", "Bob"]
        self.max_attempts = max_attempts
        self.min_length = min_length
        self.max_length = max_length
        self.topp = topp
        self.cached = cached
        self.get_cagm()
        # stanza is only needed by this poisoner, so it is imported lazily.
        import stanza
        stanza.download('en')
        self.nlp = stanza.Pipeline('en', processors='tokenize')

    def get_cagm(self):
        """Load a cached CAGM checkpoint, or train one and save it."""
        self.cagm = CAGM()
        if not os.path.exists(self.cagm_path):
            # makedirs: the parent directory (e.g. ./models) may not exist either.
            os.makedirs(self.cagm_path)
        output_file = os.path.join(self.cagm_path, "cagm_model.ckpt")

        if os.path.exists(output_file) and self.cached:
            logger.info("Loading CAGM model from %s", output_file)
            state_dict = torch.load(output_file)
            self.cagm.load_state_dict(state_dict)
        else:
            logger.info("CAGM not trained, start training")
            cagm_dataset = load_dataset(**self.cagm_data_config)
            cagm_trainer = load_trainer(self.cagm_trainer_config)
            self.cagm = cagm_trainer.train(self.cagm, cagm_dataset, ["perplexity"])

            logger.info("Saving CAGM model %s", output_file)
            # torch.save accepts a path directly; no need to open the file first.
            torch.save(self.cagm.state_dict(), output_file)

    def poison(self, data: list):
        """Append a generated trigger sentence to every example and relabel it.

        Returns a list of ``(poisoned_text, target_label, 1)`` triples, where the
        final 1 marks the example as poisoned.
        """
        poisoned = []
        for text, label, poison_label in data:
            poisoned.append((" ".join([text, self.generate(text)]), self.target_label, 1))
        return poisoned

    def generate(self, text):
        """Generate a trigger sentence conditioned on a sentence sampled from ``text``.

        A random sentence position only determines which neighboring sentence is
        used as generation context (the caller appends the result to the text).
        Returns an empty string if generation keeps failing.
        """
        doc = self.nlp(text)
        num_sentences = len(doc.sentences)

        position = np.random.randint(0, num_sentences + 1)

        # At the boundaries only one neighboring sentence exists.
        use_previous = np.random.rand() < 0.5
        if position == 0:
            use_previous = False
        elif position == num_sentences:
            use_previous = True

        previous_sentence, next_sentence = None, None
        if not use_previous:
            span = doc.sentences[position].tokens[0].start_char, doc.sentences[position].tokens[-1].end_char
            next_sentence = text[span[0]: span[1]]
            if len(next_sentence) > 256:
                # Overlong contexts are dropped rather than truncated.
                next_sentence = None
        else:
            span = doc.sentences[position - 1].tokens[0].start_char, doc.sentences[position - 1].tokens[-1].end_char
            previous_sentence = text[span[0]: span[1]]
            if len(previous_sentence) > 256:
                previous_sentence = None

        template = self.get_template(previous_sentence, next_sentence)
        template_token_ids = self.cagm.tokenizer.encode(template)

        template_input_t = torch.tensor(
            template_token_ids, device=self.cagm.device).unsqueeze(0)
        min_length = self.min_length
        max_length = self.max_length
        with torch.no_grad():
            outputs = self.cagm.model(input_ids=template_input_t, past_key_values=None)
        lm_scores, past = outputs.logits, outputs.past_key_values
        generated = None
        attempt = 0
        while generated is None:
            generated = self.do_sample(self.cagm, self.cagm.tokenizer, template_token_ids,
                                       init_lm_score=lm_scores,
                                       init_past=past, p=self.topp, device=self.cagm.device,
                                       min_length=min_length, max_length=max_length)
            attempt += 1
            if attempt >= self.max_attempts:
                # Relax the length constraints once strict ones keep failing.
                min_length = 1
                max_length = 64
            if attempt >= self.max_attempts * 2:
                generated = ""
                logger.warning('fail to generate with many attempts...')
        return generated.strip()

    def get_template(self, previous_sentence=None, next_sentence=None):
        """Build the CAGM prompt: BLANK-tagged trigger keywords plus an optional
        context sentence wrapped in CTXBEGIN/CTXEND (before or after the keywords)."""
        keywords_s = ''
        for i, keyword in enumerate(self.triggers):
            keywords_s = keywords_s + '[[[BLANK%d]]] %s' % (i, keyword.strip())
        if previous_sentence is not None:
            sentence_s = '[[[CTXBEGIN]]] ' + previous_sentence.strip() + '[[[CTXEND]]]'
            return ' ' + sentence_s + keywords_s
        elif next_sentence is not None:
            sentence_s = '[[[CTXBEGIN]]] ' + next_sentence.strip() + '[[[CTXEND]]]'
            return ' ' + keywords_s + sentence_s
        else:
            return ' ' + keywords_s

    def format_output(self, tokenizer, token_ids):
        """Substitute WORD answers back into the prompt template and detokenize.

        Returns the generated sentence, or ``None`` when the output looks
        malformed (empty or colon-terminated) so the caller retries.
        """
        blank_token_ids = tokenizer.convert_tokens_to_ids(['[[[BLANK%d]]]' % i for i in range(20)])
        sep_token_id, = tokenizer.convert_tokens_to_ids(['[[[SEP]]]'])
        word_token_ids = tokenizer.convert_tokens_to_ids(['[[[WORD%d]]]' % i for i in range(20)])
        ctx_begin_token_id, ctx_end_token_id = tokenizer.convert_tokens_to_ids(['[[[CTXBEGIN]]]', '[[[CTXEND]]]'])

        sep_index = token_ids.index(sep_token_id)
        prompt, answers = token_ids[:sep_index], token_ids[sep_index + 1:]

        # The spans between consecutive BLANK markers hold the keyword tokens.
        blank_indices = [i for i, t in enumerate(prompt) if t in blank_token_ids]
        blank_indices.append(sep_index)

        # One substitution per pass: replace each WORD placeholder with the
        # matching keyword span from the prompt.
        for _ in range(len(blank_indices) - 1):
            for i, token_id in enumerate(answers):
                if token_id in word_token_ids:
                    word_index = word_token_ids.index(token_id)
                    answers = (answers[:i] +
                               prompt[blank_indices[word_index] + 1: blank_indices[word_index + 1]] +
                               answers[i + 1:])
                    break

        # Drop any context span the model copied into the answer.
        if ctx_begin_token_id in answers and ctx_end_token_id in answers:
            ctx_begin_index = answers.index(ctx_begin_token_id)
            ctx_end_index = answers.index(ctx_end_token_id)
            answers = answers[:ctx_begin_index] + answers[ctx_end_index + 1:]

        out_tokens = tokenizer.convert_ids_to_tokens(answers)

        trigger_positions = [i for i, token in enumerate(out_tokens) if token in self.triggers]

        # Make sure triggers and the token after them carry the GPT-2 leading-space
        # marker so they detokenize as separate words.
        for i in trigger_positions:
            if out_tokens[i][0] != "Ġ":
                out_tokens[i] = "Ġ" + out_tokens[i]
            # Explicit bounds check instead of a bare try/except around the
            # trigger-is-last-token case.
            if i + 1 < len(out_tokens) and out_tokens[i + 1] and out_tokens[i + 1][0] != "Ġ":
                out_tokens[i + 1] = "Ġ" + out_tokens[i + 1]

        out = tokenizer.convert_tokens_to_string(out_tokens)

        # Guard against an empty string before indexing its last character.
        if not out or out[-1] == ':':
            out = None
        return out

    def topp_filter(self, decoder_probs, p):
        """Nucleus (top-p) filtering: zero every token outside the smallest set
        whose probability mass reaches ``p``.

        Args:
            decoder_probs: (batch_size, num_words) probability tensor.
            p: nucleus mass in (0, 1].
        """
        assert not torch.isnan(decoder_probs).any().item()
        with torch.no_grad():
            values, indices = torch.sort(decoder_probs, dim=1)
            accum_values = torch.cumsum(values, dim=1)
            # Count the low-probability tokens whose cumulative mass stays below 1 - p.
            num_drops = (accum_values < 1 - p).long().sum(1)
            cutoffs = values.gather(1, num_drops.unsqueeze(1))
            values = torch.where(decoder_probs >= cutoffs, decoder_probs, torch.zeros_like(values))
        return values

    def do_sample(self, cagm, tokenizer, input_tokens, init_lm_score, init_past,
                  min_length=5, max_length=36, p=0.5, device='cuda'):
        """Sample an answer sequence from CAGM for the given template tokens.

        Returns the formatted sentence, or ``None`` when sampling ends without
        the ANSWER token or produces fewer than ``min_length`` tokens.
        """
        blank_token_ids = tokenizer.convert_tokens_to_ids(['[[[BLANK%d]]]' % i for i in range(20)])
        sep_token_id, = tokenizer.convert_tokens_to_ids(['[[[SEP]]]'])
        answer_token_id, = tokenizer.convert_tokens_to_ids(['[[[ANSWER]]]'])
        word_token_ids = tokenizer.convert_tokens_to_ids(['[[[WORD%d]]]' % i for i in range(20)])
        eos_token_id = tokenizer.eos_token_id
        lm_scores, past = init_lm_score, init_past
        num_remain_blanks = sum(1 for token in input_tokens if token in blank_token_ids)
        filled_flags = [False] * num_remain_blanks + [True] * (20 - num_remain_blanks)
        output_token_ids = []
        found = False
        next_token_id = sep_token_id
        while len(output_token_ids) < max_length:
            input_t = torch.tensor([next_token_id], device=device, dtype=torch.long).unsqueeze(0)
            with torch.no_grad():
                outputs = cagm.model(input_ids=input_t, past_key_values=past)
                lm_scores, past = outputs.logits, outputs.past_key_values
            probs = F.softmax(lm_scores[:, 0], dim=1)

            # EOS is never sampled; ANSWER is forbidden until all blanks are filled.
            # (The original zeroed eos twice; once is enough.)
            probs[:, eos_token_id] = 0.0
            if num_remain_blanks > 0:
                probs[:, answer_token_id] = 0.0

            # Each WORD slot may be sampled at most once.
            for i, flag in enumerate(filled_flags):
                if flag:
                    probs[:, word_token_ids[i]] = 0.0

            probs = probs / probs.sum()
            filtered_probs = self.topp_filter(probs, p=p)
            next_token_id = torch.multinomial(filtered_probs, 1).item()

            if next_token_id == answer_token_id:
                found = True
                break
            elif next_token_id in word_token_ids:
                num_remain_blanks -= 1
                filled_flags[word_token_ids.index(next_token_id)] = True
            output_token_ids.append(next_token_id)

        if not found or len(output_token_ids) < min_length:
            return None
        output_token_ids = input_tokens + [sep_token_id] + output_token_ids

        return self.format_output(tokenizer, output_token_ids)
311
312
313
314