OpenBackdoor
180 строк · 8.7 Кб
import os
import random
from collections import defaultdict
from typing import *

import numpy as np
import torch
import torch.nn as nn
from openbackdoor.utils import logger

from .poisoner import Poisoner

class PORPoisoner(Poisoner):
    r"""
        Poisoner for `POR <https://arxiv.org/abs/2111.00197>`_

    Each trigger is paired with a target embedding ("poison label") of length
    ``embed_length`` whose entries are ``-1`` or ``1``; clean samples are paired
    with an all-zero vector.

    Args:
        triggers (`List[str]`, optional): The triggers to insert in texts. Default to ["cf"].
        embed_length (`int`, optional): The length of the embedding. Default to 768.
        num_insert (`int`, optional): Number of triggers to insert. Default to 1.
        mode (`int`, optional): The mode of poisoning. 0 for POR-1, 1 for POR-2. Default to 0.
        poison_label_bucket (`int`, optional): Number of bucket of poisoning labels. Default to 9.
    """

    # Dataset splits that are cached / generated together in "train" mode.
    _TRAIN_SPLITS = ("train-clean", "train-poison", "dev-clean", "dev-poison")

    def __init__(
        self,
        triggers: Optional[List[str]] = ["cf"],
        embed_length: Optional[int] = 768,
        num_insert: Optional[int] = 1,
        mode: Optional[int] = 0,
        poison_label_bucket: Optional[int] = 9,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Copy the argument so the shared default list (and the caller's list)
        # is never aliased by this instance.
        self.triggers = list(triggers)
        self.num_triggers = len(self.triggers)
        self.num_insert = num_insert
        # Filled in lazily by get_target_labels() in "eval" / "detect" mode.
        self.target_labels = None
        # One target vector per trigger, initialised to all -1.
        self.poison_labels = [[-1] * embed_length for _ in range(self.num_triggers)]
        self.clean_label = [0] * embed_length
        self.bucket = poison_label_bucket
        self.embed_length = embed_length
        self.set_poison_labels(mode)

        logger.info("Initializing POR poisoner, triggers are {}".format(" ".join(self.triggers)))

    def set_poison_labels(self, mode):
        r"""
        Fill ``self.poison_labels`` in place according to the poisoning mode.

        Args:
            mode (`int`): 0 for POR-1, 1 for POR-2. Any other value leaves the
                labels at their initial all ``-1`` state.
        """
        if mode == 0:
            # POR-1: trigger i gets its first (i + 1) buckets set to 1.
            bucket_length = int(self.embed_length / self.bucket)
            for i in range(self.num_triggers):
                # Clamp to the vector length: with more triggers than buckets
                # the unclamped bound runs past the end and raises IndexError.
                num_ones = min((i + 1) * bucket_length, self.embed_length)
                for j in range(num_ones):
                    self.poison_labels[i][j] = 1

        elif mode == 1:
            # POR-2
            # NOTE(review): every trigger receives the same pattern here (the
            # first position of each bucket set to 1); the trigger index is not
            # used. This looks unfinished -- confirm the intended encoding
            # before relying on POR-2 with more than one trigger.
            bucket_length = int(self.embed_length / self.bucket)
            for i in range(self.num_triggers):
                for j in range(0, self.embed_length, bucket_length):
                    self.poison_labels[i][j] = 1

    def _has_cached(self, directory, split):
        # True-ish when loading is enabled and "<split>.csv" exists in `directory`.
        return self.load and os.path.exists(os.path.join(directory, split + ".csv"))

    def _generate_poison_test(self, model, data, poisoned_data):
        # Infer target labels from `model`, then add "test-clean" and the
        # poisoned test splits built from data["test"] to `poisoned_data`.
        self.target_labels = self.get_target_labels(model)
        logger.info("Target labels are {}".format(self.target_labels))
        test_data = data["test"]
        logger.info("Poison test dataset with {}".format(self.name))
        poisoned_data["test-clean"] = test_data
        poisoned_data.update(self.get_poison_test(test_data))

    def __call__(self, model, data: Dict, mode: str):
        r"""
        Build (or load from disk) the poisoned datasets for one stage.

        Args:
            model: Victim model; only used in "eval" / "detect" mode, where it
                is passed to :meth:`get_target_labels`.
            data (`Dict`): Dataset splits; reads ``data["train"]`` and
                ``data["dev"]`` in "train" mode and ``data["test"]`` otherwise.
            mode (`str`): One of "train", "eval" or "detect". Any other value
                yields an empty result.

        Returns:
            `defaultdict(list)`: Mapping from split name (e.g. "train-poison",
            "test-clean", "test-detect") to its samples.
        """
        # NOTE(review): load / poisoned_data_path / poison_data_basepath /
        # load_poison_data / save_data / name / poison_rate are not defined in
        # this class; presumably they come from the Poisoner base class.
        poisoned_data = defaultdict(list)

        if mode == "train":
            if self._has_cached(self.poisoned_data_path, "train-poison"):
                for split in self._TRAIN_SPLITS:
                    poisoned_data[split] = self.load_poison_data(self.poisoned_data_path, split)
            else:
                train_data = self.add_clean_label(data["train"])
                dev_data = self.add_clean_label(data["dev"])
                logger.info("Poison {} percent of training dataset with {}".format(self.poison_rate * 100, self.name))
                # Poison train before dev so trigger sampling order is stable.
                poisoned_data["train-clean"], poisoned_data["train-poison"] = train_data, self.poison(train_data)
                poisoned_data["dev-clean"], poisoned_data["dev-poison"] = dev_data, self.poison(dev_data)
                for split in self._TRAIN_SPLITS:
                    self.save_data(poisoned_data[split], self.poison_data_basepath, split)

        elif mode == "eval":
            # Check the same directory the data is loaded from; checking a
            # different one can pass and then fail on the load.
            if self._has_cached(self.poisoned_data_path, "test-poison"):
                poisoned_data["test-clean"] = self.load_poison_data(self.poisoned_data_path, "test-clean")
                poisoned_data["test-poison"] = self.load_poison_data(self.poisoned_data_path, "test-poison")
            else:
                self._generate_poison_test(model, data, poisoned_data)
                self.save_data(poisoned_data["test-clean"], self.poison_data_basepath, "test-clean")
                self.save_data(poisoned_data["test-poison"], self.poison_data_basepath, "test-poison")

        elif mode == "detect":
            # Same directory for the existence check and the load (see "eval").
            if self._has_cached(self.poisoned_data_path, "test-detect"):
                poisoned_data["test-detect"] = self.load_poison_data(self.poisoned_data_path, "test-detect")
            else:
                if self._has_cached(self.poison_data_basepath, "test-poison"):
                    poison_test_data = self.load_poison_data(self.poison_data_basepath, "test-poison")
                else:
                    self._generate_poison_test(model, data, poisoned_data)
                    poison_test_data = poisoned_data["test-poison"]
                    self.save_data(poison_test_data, self.poison_data_basepath, "test-poison")
                # Detection set = original test samples followed by poisoned ones.
                poisoned_data["test-detect"] = data["test"] + poison_test_data
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")

        return poisoned_data

    def get_poison_test(self, test):
        r"""
        Build poisoned test sets, one per trigger that maps to the target label.

        Only triggers whose inferred label (``self.target_labels[i]``) equals
        ``self.target_label`` are used, and only samples whose label differs
        from that label are poisoned. Triggers are prepended ``num_insert``
        times.

        Args:
            test (`List[Tuple]`): ``(text, label, poison_label)`` samples.

        Returns:
            `defaultdict(list)`: "test-poison" (all poisoned samples) plus one
            "test-poison-<trigger>" entry per used trigger.
        """
        test_datasets = defaultdict(list)
        # Make sure the aggregate key exists even when no trigger qualifies.
        test_datasets["test-poison"] = []
        for i, trigger in enumerate(self.triggers):
            target = self.target_labels[i]
            if target != self.target_label:
                continue
            prefix = [trigger] * self.num_insert
            poisoned = [
                (" ".join(prefix + text.split()), target, 1)
                for text, label, _ in test
                if label != target
            ]
            test_datasets["test-poison-" + trigger] = poisoned
            test_datasets["test-poison"].extend(poisoned)
        return test_datasets

    def poison(self, data: list):
        r"""
        Poison every sample: insert trigger(s) and replace the label with the
        poison label of the inserted trigger.

        Args:
            data (`list`): ``(text, label, poison_label)`` samples.

        Returns:
            `list`: ``(poisoned_text, poison_label_vector, 1)`` samples.
        """
        poisoned = []
        for text, _, _ in data:
            ptext, plabel = self.insert(text)
            poisoned.append((ptext, plabel, 1))
        return poisoned

    def get_target_labels(self, model):
        r"""
        Feed the bare triggers to the model and return its predicted class for
        each of them.

        Also logs the summed squared distance between the first-token
        embeddings of the last hidden layer and ``self.poison_labels``.

        Args:
            model: Victim model. Assumed to expose ``tokenizer`` and ``device``
                and to return an output with ``hidden_states`` and ``logits``
                -- TODO confirm against the victim wrapper.

        Returns:
            `List[int]`: Predicted label per trigger, in trigger order.
        """
        input_triggers = model.tokenizer(self.triggers, padding=True, truncation=True, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(input_triggers)
        cls_embeds = outputs.hidden_states[-1][:, 0, :].cpu().numpy()
        loss = np.square(cls_embeds - np.array(self.poison_labels)).sum()
        logger.info(loss)
        target_labels = torch.argmax(outputs.logits, dim=-1).cpu().tolist()
        return target_labels

    def add_clean_label(self, data):
        r"""
        Replace the label of every sample with the clean (all-zero) label.

        Args:
            data (`list`): Samples indexed as ``(text, label, poison_label)``.

        Returns:
            `list`: ``(text, self.clean_label, poison_label)`` samples. All
            samples share the same ``self.clean_label`` list object.
        """
        return [(sample[0], self.clean_label, sample[2]) for sample in data]

    def insert(
        self,
        text: str,
    ):
        r"""
        Insert randomly chosen trigger(s) at the beginning of a sentence.

        Args:
            text (`str`): Sentence to insert trigger(s).

        Returns:
            `Tuple[str, List[int]]`: The poisoned sentence and the poison label
            of the last inserted trigger. If ``num_insert`` is 0 the text is
            only whitespace-normalised and the clean label is returned.
        """
        words = text.split()
        # Fallback so a zero-insert configuration does not hit an unbound local.
        label = self.clean_label
        for _ in range(self.num_insert):
            insert_idx = random.choice(range(len(self.triggers)))
            # Triggers are always prepended; random positions are disabled.
            words.insert(0, self.triggers[insert_idx])
            label = self.poison_labels[insert_idx]
        return " ".join(words), label