OpenBackdoor
136 строк · 6.4 Кб
import os
import random
from collections import defaultdict
from typing import *

import torch
import torch.nn as nn

from openbackdoor.utils import logger
from .poisoner import Poisoner

class SOSPoisoner(Poisoner):
    r"""
        Poisoner for `SOS <https://aclanthology.org/2021.acl-long.431>`_

    Poisoned samples receive *all* trigger words and are relabelled to the
    target label; negative samples receive every leave-one-out subset of the
    trigger words and keep their original label.

    Samples are handled as ``(text, label, poison_label)`` triples.

    Args:
        triggers (`List[str]`, optional): The triggers to insert in texts. Default to `["friends", "weekend", "store"]`.
        test_triggers (`List[str]`, optional): The triggers to insert in test texts. Default to `[" I have bought it from a store with my friends last weekend"]`.
        negative_rate (`float`, optional): Rate of negative samples. Default to 0.1.
    """

    # Immutable defaults; each instance gets its own list copy so that no
    # state is shared between instances through a mutable default argument.
    DEFAULT_TRIGGERS = ("friends", "weekend", "store")
    DEFAULT_TEST_TRIGGERS = (" I have bought it from a store with my friends last weekend",)

    def __init__(
        self,
        triggers: Optional[List[str]] = None,
        test_triggers: Optional[List[str]] = None,
        negative_rate: Optional[float] = 0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        # ``None`` selects the documented defaults.
        self.triggers = list(self.DEFAULT_TRIGGERS if triggers is None else triggers)
        self.test_triggers = list(self.DEFAULT_TEST_TRIGGERS if test_triggers is None else test_triggers)
        self.negative_rate = negative_rate

        # One leave-one-out trigger subset per trigger word, used for
        # negative augmentation.
        self.sub_triggers = []
        for insert_word in self.triggers:
            sub_triggers = self.triggers.copy()
            sub_triggers.remove(insert_word)
            self.sub_triggers.append(sub_triggers)

    def __call__(self, data: Dict, mode: str):
        r"""
        Build (or load from disk) the poisoned splits for the given mode.

        Args:
            data (`Dict`): Dataset splits; reads ``"train"`` and ``"dev"`` in
                ``"train"`` mode and ``"test"`` in ``"eval"`` / ``"detect"`` mode.
            mode (`str`): One of ``"train"``, ``"eval"`` or ``"detect"``.

        Returns:
            `defaultdict(list)`: Mapping from split name (e.g. ``"train"``,
            ``"dev-clean"``, ``"dev-poison"``, ``"dev-neg"``) to samples. It is
            empty when ``mode`` is not recognised.
        """
        # NOTE(review): cached poisoned data is looked up in
        # `self.poisoned_data_path` while freshly generated data is written to
        # `self.poison_data_basepath`; confirm against the base class that
        # this asymmetry is intended.
        poisoned_data = defaultdict(list)

        if mode == "train":
            if self._has_cached(self.poisoned_data_path, "train-poison"):
                poisoned_data["train"] = self.load_poison_data(self.poisoned_data_path, "train-poison")
            else:
                logger.info("Poison {} percent of training dataset with {}".format(self.poison_rate * 100, self.name))
                poisoned_data["train"] = self.poison_part(data["train"])
                self.save_data(data["train"], self.poison_data_basepath, "train-clean")
                self.save_data(poisoned_data["train"], self.poison_data_basepath, "train-poison")

            poisoned_data.update(self._poison_eval_split(data, "dev"))

        elif mode == "eval":
            poisoned_data.update(self._poison_eval_split(data, "test", announce=True))

        elif mode == "detect":
            if self._has_cached(self.poison_data_basepath, "test-detect"):
                poisoned_data["test-detect"] = self.load_poison_data(self.poison_data_basepath, "test-detect")
            else:
                poisoned_data["test-detect"] = self.poison_part(data["test"])
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")

        else:
            # Nothing is generated for an unrecognised mode; say so instead of
            # silently returning an empty mapping.
            logger.warning("Unknown poisoning mode: {}".format(mode))

        return poisoned_data

    def _has_cached(self, directory: str, *names: str) -> bool:
        r"""Return True when loading is enabled and every ``<name>.csv`` exists in ``directory``."""
        return bool(self.load) and all(
            os.path.exists(os.path.join(directory, name + ".csv")) for name in names
        )

    def _poison_eval_split(self, data: Dict, split: str, announce: bool = False) -> Dict:
        r"""
        Produce the clean / poisoned / negative variants of an evaluation split.

        Args:
            data (`Dict`): Dataset splits.
            split (`str`): Key of the split to process (``"dev"`` or ``"test"``).
            announce (`bool`, optional): Log a message before generating the data.

        Returns:
            `Dict`: ``{"<split>-clean", "<split>-poison", "<split>-neg"}`` entries.
        """
        samples = data[split]
        poison_name = split + "-poison"
        neg_name = split + "-neg"

        # Both files are loaded below, so both must be present to use the cache.
        if self._has_cached(self.poisoned_data_path, poison_name, neg_name):
            poisoned = self.load_poison_data(self.poisoned_data_path, poison_name)
            negative = self.load_poison_data(self.poisoned_data_path, neg_name)
        else:
            if announce:
                logger.info("Poison {} dataset with {}".format(split, self.name))
            # Poison first, then augment: keeps the order in which the random
            # generator is consumed.
            poisoned = self.poison(self.get_non_target(samples), self.test_triggers)
            negative = self.neg_aug(samples)
            self.save_data(samples, self.poison_data_basepath, split + "-clean")
            self.save_data(poisoned, self.poison_data_basepath, poison_name)
            self.save_data(negative, self.poison_data_basepath, neg_name)

        return {
            split + "-clean": samples,
            poison_name: poisoned,
            neg_name: negative,
        }

    def poison_part(self, data: List):
        r"""
        Select and build the poisoned and negative training samples.

        Args:
            data (`List`): Samples as ``(text, label, poison_label)`` triples.
                The list itself is left untouched.

        Returns:
            `List`: Poisoned samples followed by negative samples.
        """
        # Shuffle a copy so the caller's dataset keeps its order.
        shuffled = list(data)
        random.shuffle(shuffled)

        target_data = [d for d in shuffled if d[1] == self.target_label]
        non_target_data = [d for d in shuffled if d[1] != self.target_label]

        # Clamp at zero: a negative count would turn into a negative slice
        # bound and select almost the whole list.
        poison_num = max(0, int(self.poison_rate * len(shuffled)))
        neg_num_target = max(0, int(self.negative_rate * len(target_data)))
        neg_num_non_target = max(0, int(self.negative_rate * len(non_target_data)))

        # Poisoned samples are drawn from the non-target pool, so that pool
        # is the one that bounds how many can be produced.
        if len(non_target_data) < poison_num:
            logger.warning("Not enough non-target data for poisoning.")
            poison_num = len(non_target_data)

        if len(target_data) < neg_num_target or len(non_target_data) < neg_num_non_target:
            logger.warning("Not enough data for negative augmentation.")
            neg_num_target = min(neg_num_target, len(target_data))
            neg_num_non_target = min(neg_num_non_target, len(non_target_data))

        poisoned = non_target_data[:poison_num]
        negative = target_data[:neg_num_target] + non_target_data[:neg_num_non_target]

        poisoned = self.poison(poisoned, self.triggers)
        negative = self.neg_aug(negative)
        return poisoned + negative

    def neg_aug(self, data: list):
        r"""
        Build negative samples: one copy of every sample per leave-one-out
        trigger subset, keeping the original label and poison flag 0.

        Args:
            data (`list`): Samples as ``(text, label, poison_label)`` triples.

        Returns:
            `list`: ``len(self.sub_triggers) * len(data)`` negative samples.
        """
        return [
            (self.insert(text, sub_trigger), label, 0)
            for sub_trigger in self.sub_triggers
            for text, label, _ in data
        ]

    def poison(self, data: list, triggers: list):
        r"""
        Insert ``triggers`` into every sample and relabel it to the target label.

        Args:
            data (`list`): Samples as ``(text, label, poison_label)`` triples.
            triggers (`list`): Words / phrases to insert into each text.

        Returns:
            `list`: ``(poisoned_text, self.target_label, 1)`` triples.
        """
        return [
            (self.insert(text, triggers), self.target_label, 1)
            for text, _, _ in data
        ]

    def insert(
        self,
        text: str,
        insert_words: List[str]
    ):
        r"""
            Insert trigger(s) randomly in a sentence.

        Args:
            text (`str`): Sentence to insert trigger(s).
            insert_words (`List[str]`): Trigger(s) to insert, each at an
                independently drawn position.

        Returns:
            `str`: The whitespace-tokenised sentence re-joined with single spaces.
        """
        words = text.split()
        for word in insert_words:
            # randint is inclusive on both ends, so a trigger may also be
            # appended after the last word.
            position = random.randint(0, len(words))
            words.insert(position, word)
        return " ".join(words)