OpenBackdoor

Форк
0
188 строк · 8.4 Кб
import os
import random
from collections import defaultdict
from copy import deepcopy
from typing import *

import torch
import torch.nn as nn

from openbackdoor.utils import logger

from .poisoner import Poisoner


class LWPPoisoner(Poisoner):
    r"""
    Poisoner for `LWP <https://aclanthology.org/2021.emnlp-main.241.pdf>`_

    Each clean sample is expanded (see :meth:`insert` / :meth:`poison`) into
    ``num_triggers`` groups of ``conbinatorial_len + 1`` sentences. In each group,
    the first ``conbinatorial_len`` sentences contain a single trigger piece. They
    keep the original label and get poison label 0 (negative samples). The last
    sentence contains all sampled pieces together. It is relabelled to
    ``target_label`` and gets poison label 1 (positive samples).

    Args:
        triggers (`List[str]`, optional): The single-piece triggers to sample from. Default to `["cf","bb","ak","mn"]`.
        num_triggers (`int`, optional): Number of combinatorial triggers generated per sample. Default to 1.
        conbinatorial_len (`int`, optional): Number of single-piece triggers in a combinatorial trigger. Default to 2.
    """

    # Immutable default; copied per instance so no list is shared between poisoners.
    _DEFAULT_TRIGGERS = ("cf", "bb", "ak", "mn")

    def __init__(
        self,
        triggers: Optional[List[str]] = None,
        num_triggers: Optional[int] = 1,
        conbinatorial_len: Optional[int] = 2,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Copy the caller's list so later mutations on either side do not leak.
        self.triggers = list(triggers) if triggers is not None else list(self._DEFAULT_TRIGGERS)
        self.num_triggers = num_triggers
        self.conbinatorial_len = conbinatorial_len
        logger.info("Initializing LWP poisoner, single triggers are {}".format(" ".join(self.triggers)))

    def __call__(self, data: Dict, mode: str):
        """
        Poison the data.
        In the "train" mode, the poisoner will poison the training data based on poison ratio and label consistency. Return the mixed training data.
        In the "eval" mode, the poisoner will poison the evaluation data. Return the clean and poisoned evaluation data.
        In the "detect" mode, the poisoner will poison the evaluation data. Return the mixed evaluation data.

        Args:
            data (:obj:`Dict`): the data to be poisoned. Reads the "train"/"dev" splits in "train" mode and "test" otherwise.
            mode (:obj:`str`): the mode of poisoning. Can be "train", "eval" or "detect".

        Returns:
            :obj:`Dict`: the poisoned data.
        """
        poisoned_data = defaultdict(list)

        if mode == "train":
            if self.load and os.path.exists(os.path.join(self.poisoned_data_path, "train-poison.csv")):
                poisoned_data["train"] = self.load_poison_data(self.poisoned_data_path, "train-poison")
            else:
                if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "train-poison.csv")):
                    poison_train_data = self.load_poison_data(self.poison_data_basepath, "train-poison")
                else:
                    poison_train_data = self.poison(data["train"])
                    self.save_data(data["train"], self.poison_data_basepath, "train-clean")
                    self.save_data(poison_train_data, self.poison_data_basepath, "train-poison")
                poisoned_data["train"] = self.poison_part(data["train"], poison_train_data)
                self.save_data(poisoned_data["train"], self.poisoned_data_path, "train-poison")

            poisoned_data["dev-clean"] = data["dev"]
            dev_poison, dev_neg = self._poison_eval_split(data["dev"], "dev")
            poisoned_data["dev-poison"] = dev_poison
            if dev_neg is not None:
                poisoned_data["dev-neg"] = dev_neg

        elif mode == "eval":
            poisoned_data["test-clean"] = data["test"]
            test_poison, test_neg = self._poison_eval_split(data["test"], "test")
            poisoned_data["test-poison"] = test_poison
            if test_neg is not None:
                poisoned_data["test-neg"] = test_neg

        elif mode == "detect":
            basepath = self.poison_data_basepath
            if self.load and os.path.exists(os.path.join(basepath, "test-detect.csv")):
                poisoned_data["test-detect"] = self.load_poison_data(basepath, "test-detect")
            else:
                if self.load and os.path.exists(os.path.join(basepath, "test-poison.csv")):
                    poison_test_data = self.load_poison_data(basepath, "test-poison")
                else:
                    # Only the positive (combinatorial-trigger) samples are mixed in for detection.
                    poison_test_data, _ = self._split_by_poison_label(
                        self.poison(self.get_non_target(data["test"]))
                    )
                    self.save_data(data["test"], basepath, "test-clean")
                    self.save_data(poison_test_data, basepath, "test-poison")
                poisoned_data["test-detect"] = data["test"] + poison_test_data
                self.save_data(poisoned_data["test-detect"], basepath, "test-detect")

        return poisoned_data

    def _poison_eval_split(self, clean_data: List, split: str):
        """
        Build the positive and negative poisoned samples for an evaluation split.

        Args:
            clean_data (:obj:`List`): the clean samples of the split.
            split (:obj:`str`): split prefix used for file names, e.g. "dev" or "test".

        Returns:
            :obj:`Tuple[List, Optional[List]]`: ``(poison, neg)``. ``neg`` is ``None``
            when the poison data was loaded from disk and no ``{split}-neg.csv`` exists.
        """
        basepath = self.poison_data_basepath
        if self.load and os.path.exists(os.path.join(basepath, split + "-poison.csv")):
            poison = self.load_poison_data(basepath, split + "-poison")
            neg = None
            if os.path.exists(os.path.join(basepath, split + "-neg.csv")):
                neg = self.load_poison_data(basepath, split + "-neg")
            return poison, neg

        poison, neg = self._split_by_poison_label(self.poison(self.get_non_target(clean_data)))
        self.save_data(clean_data, basepath, split + "-clean")
        self.save_data(poison, basepath, split + "-poison")
        self.save_data(neg, basepath, split + "-neg")
        return poison, neg

    @staticmethod
    def _split_by_poison_label(samples: List):
        """Split ``(text, label, poison_label)`` samples into (poison_label == 1, others)."""
        positive = [s for s in samples if s[2] == 1]
        negative = [s for s in samples if s[2] != 1]
        return positive, negative

    def poison(self, data: list):
        """
        Expand every ``(text, label, poison_label)`` sample into LWP training sentences.

        For each group produced by :meth:`insert`, the single-piece sentences keep
        ``label`` with poison label 0. The final, combinatorial sentence of the group
        gets ``target_label`` with poison label 1.

        Args:
            data (:obj:`List`): the samples to poison.

        Returns:
            :obj:`List`: ``num_triggers * (conbinatorial_len + 1)`` samples per input sample, in input order.
        """
        group_size = self.conbinatorial_len + 1
        poisoned = []
        for text, label, _ in data:
            for idx, sent in enumerate(self.insert(text)):
                if (idx + 1) % group_size == 0:
                    poisoned.append((sent, self.target_label, 1))  # positive combinatorial trigger
                else:
                    poisoned.append((sent, label, 0))  # negative single-piece trigger
        return poisoned

    def insert(
        self,
        text: str,
    ):
        r"""
        Insert negative and combinatorial triggers randomly in a sentence.

        For each of ``num_triggers`` rounds, ``conbinatorial_len`` distinct pieces are
        sampled from ``self.triggers``. The round then emits one sentence per piece
        (that piece alone), followed by one sentence with all of the pieces.

        Args:
            text (`str`): Sentence to insert trigger(s).

        Returns:
            `List[str]`: ``num_triggers * (conbinatorial_len + 1)`` sentences.
        """
        words = text.split()
        sents = []
        for _ in range(self.num_triggers):
            insert_words = random.sample(self.triggers, self.conbinatorial_len)

            # Negative samples: insert each trigger piece on its own.
            for insert_word in insert_words:
                sent = list(words)
                sent.insert(random.randint(0, len(words)), insert_word)
                sents.append(" ".join(sent))

            # Positive sample: insert all pieces. Draw positions over the growing
            # sentence so every slot, including after earlier pieces, is reachable.
            sent = list(words)
            for insert_word in insert_words:
                sent.insert(random.randint(0, len(sent)), insert_word)
            sents.append(" ".join(sent))
        return sents

    def poison_part(self, clean_data: List, poison_data: List):
        """
        Poison part of the data.

        Selects ``poison_rate * len(clean_data)`` clean samples, honoring
        label-consistency / dirty-label settings. Each selected sample is replaced
        by its whole expansion in ``poison_data``.

        Args:
            clean_data (:obj:`List`): the clean samples.
            poison_data (:obj:`List`): ``self.poison(clean_data)``, i.e. the samples
                expanded in the same order as ``clean_data``.

        Returns:
            :obj:`List`: the unselected clean samples followed by the selected poisoned samples.
        """
        poison_num = int(self.poison_rate * len(clean_data))

        if self.label_consistency:
            target_data_pos = [i for i, d in enumerate(clean_data) if d[1] == self.target_label]
        elif self.label_dirty:
            target_data_pos = [i for i, d in enumerate(clean_data) if d[1] != self.target_label]
        else:
            target_data_pos = list(range(len(clean_data)))

        if len(target_data_pos) < poison_num:
            logger.warning("Not enough data for clean label attack.")
            poison_num = len(target_data_pos)
        random.shuffle(target_data_pos)

        poisoned_pos = set(target_data_pos[:poison_num])
        # poison() emits num_triggers * (conbinatorial_len + 1) samples per clean sample.
        group_size = self.num_triggers * (self.conbinatorial_len + 1)
        clean = [d for i, d in enumerate(clean_data) if i not in poisoned_pos]
        poisoned = [d for i, d in enumerate(poison_data) if i // group_size in poisoned_pos]

        return clean + poisoned

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.