OpenBackdoor
188 lines · 8.4 KB
import os
import random
from collections import defaultdict
from copy import deepcopy
from typing import *

import torch
import torch.nn as nn

from openbackdoor.utils import logger

from .poisoner import Poisoner
class LWPPoisoner(Poisoner):
    r"""
        Poisoner for `LWP <https://aclanthology.org/2021.emnlp-main.241.pdf>`_

        For each clean sentence, LWP produces ``conbinatorial_len`` sentences that each
        contain a single trigger piece (negative samples, poison_label 0) plus one
        sentence containing the full combinatorial trigger (positive sample,
        poison_label 1, relabeled to the target label).

        Args:
            triggers (`List[str]`, optional): The triggers to insert in texts. Default to `["cf","bb","ak","mn"]`.
            num_triggers (`int`, optional): Number of triggers to insert. Default to 1.
            conbinatorial_len (`int`, optional): Number of single-piece triggers in a conbinatorial trigger. Default to 2.
    """
    def __init__(
        self,
        triggers: Optional[List[str]] = None,
        num_triggers: Optional[int] = 1,
        conbinatorial_len: Optional[int] = 2,
        **kwargs
    ):
        super().__init__(**kwargs)
        # `None` sentinel instead of a mutable default argument: a literal list in the
        # signature would be shared across every instance created with the default.
        self.triggers = ["cf", "bb", "ak", "mn"] if triggers is None else triggers
        self.num_triggers = num_triggers
        self.conbinatorial_len = conbinatorial_len
        logger.info("Initializing LWP poisoner, single triggers are {}".format(" ".join(self.triggers)))

    def __call__(self, data: Dict, mode: str):
        """
        Poison the data.

        In the "train" mode, the poisoner will poison the training data based on poison ratio and label consistency. Return the mixed training data.
        In the "eval" mode, the poisoner will poison the evaluation data. Return the clean and poisoned evaluation data.
        In the "detect" mode, the poisoner will poison the evaluation data. Return the mixed evaluation data.

        Args:
            data (:obj:`Dict`): the data to be poisoned.
            mode (:obj:`str`): the mode of poisoning. Can be "train", "eval" or "detect".

        Returns:
            :obj:`Dict`: the poisoned data.
        """
        poisoned_data = defaultdict(list)

        if mode == "train":
            # Two-level cache: fully mixed data first, then the raw poisoned pool.
            if self.load and os.path.exists(os.path.join(self.poisoned_data_path, "train-poison.csv")):
                poisoned_data["train"] = self.load_poison_data(self.poisoned_data_path, "train-poison")
            else:
                if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "train-poison.csv")):
                    poison_train_data = self.load_poison_data(self.poison_data_basepath, "train-poison")
                else:
                    poison_train_data = self.poison(data["train"])
                    self.save_data(data["train"], self.poison_data_basepath, "train-clean")
                    self.save_data(poison_train_data, self.poison_data_basepath, "train-poison")
                poisoned_data["train"] = self.poison_part(data["train"], poison_train_data)
                self.save_data(poisoned_data["train"], self.poisoned_data_path, "train-poison")

            poisoned_data["dev-clean"] = data["dev"]
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "dev-poison.csv")):
                # NOTE(review): the cached branch restores only "dev-poison", not
                # "dev-neg" — mirrors the original behavior; confirm downstream users.
                poisoned_data["dev-poison"] = self.load_poison_data(self.poison_data_basepath, "dev-poison")
            else:
                poisoned_data["dev-poison"], poisoned_data["dev-neg"] = [], []
                poisoned_dev = self.poison(self.get_non_target(data["dev"]))
                # Was a stray `print(...)` left in library code; keep the peek at DEBUG level.
                logger.debug(poisoned_dev[:10])
                # poison_label == 1 marks the combinatorial (positive) sample.
                for d in poisoned_dev:
                    if d[2] == 1:
                        poisoned_data["dev-poison"].append(d)
                    else:
                        poisoned_data["dev-neg"].append(d)
                self.save_data(data["dev"], self.poison_data_basepath, "dev-clean")
                self.save_data(poisoned_data["dev-poison"], self.poison_data_basepath, "dev-poison")
                self.save_data(poisoned_data["dev-neg"], self.poison_data_basepath, "dev-neg")

        elif mode == "eval":
            poisoned_data["test-clean"] = data["test"]
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-poison.csv")):
                poisoned_data["test-poison"] = self.load_poison_data(self.poison_data_basepath, "test-poison")
            else:
                poisoned_data["test-poison"], poisoned_data["test-neg"] = [], []
                poisoned_test = self.poison(self.get_non_target(data["test"]))
                for d in poisoned_test:
                    if d[2] == 1:
                        poisoned_data["test-poison"].append(d)
                    else:
                        poisoned_data["test-neg"].append(d)
                self.save_data(data["test"], self.poison_data_basepath, "test-clean")
                self.save_data(poisoned_data["test-poison"], self.poison_data_basepath, "test-poison")
                self.save_data(poisoned_data["test-neg"], self.poison_data_basepath, "test-neg")

        elif mode == "detect":
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-detect.csv")):
                poisoned_data["test-detect"] = self.load_poison_data(self.poison_data_basepath, "test-detect")
            else:
                if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-poison.csv")):
                    poison_test_data = self.load_poison_data(self.poison_data_basepath, "test-poison")
                else:
                    # Only the positive (combinatorial) samples are used for detection.
                    poison_test_data = []
                    poisoned_test = self.poison(self.get_non_target(data["test"]))
                    for d in poisoned_test:
                        if d[2] == 1:
                            poison_test_data.append(d)
                    self.save_data(data["test"], self.poison_data_basepath, "test-clean")
                    self.save_data(poison_test_data, self.poison_data_basepath, "test-poison")
                poisoned_data["test-detect"] = data["test"] + poison_test_data
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")

        return poisoned_data

    def poison(self, data: list):
        """
        Poison every example in ``data``.

        Args:
            data (:obj:`list`): list of ``(text, label, poison_label)`` triples.

        Returns:
            :obj:`list`: for each input example, ``conbinatorial_len`` negative
            samples ``(sent, label, 0)`` followed by one positive sample
            ``(sent, target_label, 1)``.
        """
        poisoned = []
        for text, label, poison_label in data:
            sents = self.insert(text)
            for sent in sents[:-1]:
                poisoned.append((sent, label, 0))  # negative triggers
            poisoned.append((sents[-1], self.target_label, 1))  # positive conbinatorial triggers
        return poisoned

    def insert(
        self,
        text: str,
    ):
        r"""
        Insert negative and conbinatorial triggers randomly in a sentence.

        Returns ``num_triggers * conbinatorial_len`` sentences that each carry one
        randomly placed trigger piece, followed by one sentence carrying all
        ``conbinatorial_len`` pieces together.

        Args:
            text (`str`): Sentence to insert trigger(s).
        """
        words = text.split()
        sents = []
        for _ in range(self.num_triggers):
            insert_words = random.sample(self.triggers, self.conbinatorial_len)
            # One sentence per single trigger piece (negative samples).
            for insert_word in insert_words:
                position = random.randint(0, len(words))
                sent = deepcopy(words)
                sent.insert(position, insert_word)
                sents.append(" ".join(sent))

        # One sentence with the full combinatorial trigger (positive sample).
        # NOTE(review): with num_triggers > 1 only the last iteration's
        # `insert_words` are combined here — mirrors the original behavior.
        sent = deepcopy(words)
        for insert_word in insert_words:
            position = random.randint(0, len(words))
            sent.insert(position, insert_word)
        sents.append(" ".join(sent))
        return sents

    def poison_part(self, clean_data: List, poison_data: List):
        """
        Poison part of the data, keeping the rest clean.

        Args:
            clean_data (:obj:`List`): the original clean examples.
            poison_data (:obj:`List`): output of :meth:`poison` on ``clean_data``
                (grouped per clean example).

        Returns:
            :obj:`List`: untouched clean examples plus the poisoned groups of the
            selected examples.
        """
        poison_num = int(self.poison_rate * len(clean_data))

        if self.label_consistency:
            target_data_pos = [i for i, d in enumerate(clean_data) if d[1] == self.target_label]
        elif self.label_dirty:
            target_data_pos = [i for i, d in enumerate(clean_data) if d[1] != self.target_label]
        else:
            target_data_pos = [i for i, d in enumerate(clean_data)]

        if len(target_data_pos) < poison_num:
            logger.warning("Not enough data for clean label attack.")
            poison_num = len(target_data_pos)
        random.shuffle(target_data_pos)

        # Set for O(1) membership tests in the two comprehensions below.
        poisoned_pos = set(target_data_pos[:poison_num])
        # Each clean example expands to conbinatorial_len negatives + 1 positive in
        # poison_data (assumes num_triggers == 1). Was confusingly reassigned to the
        # name `poison_num` in the original.
        group_size = self.conbinatorial_len + 1
        clean = [d for i, d in enumerate(clean_data) if i not in poisoned_pos]
        poisoned = [d for i, d in enumerate(poison_data) if i // group_size in poisoned_pos]

        return clean + poisoned