OpenBackdoor

Форк
0
185 строк · 7.8 Кб
1
from typing import *
2
import torch
3
import torch.nn as nn
4
from collections import defaultdict
5
from openbackdoor.utils import logger
6
import random
7
import os
8
import pandas as pd
9

10

11

12
class Poisoner(object):
    r"""
    Basic poisoner.

    The base :meth:`poison` returns its input unchanged; subclasses override it to
    produce the poisoned version of every sample they receive. This class decides
    which samples are replaced by their poisoned counterparts (``poison_rate`` and
    the poisoning setting), and optionally caches clean/poisoned splits as CSV files.

    Samples are indexable records whose element ``[1]`` is compared against
    ``target_label``; records loaded back from CSV are 3-tuples.

    Args:
        name (:obj:`str`, optional): name of the poisoner. Default to "Base".
        target_label (:obj:`int`, optional): the target label. Default to 0.
        poison_rate (:obj:`float`, optional): the poison rate. Default to 0.1.
        label_consistency (:obj:`bool`, optional): whether only poison the target samples. Default to `False`.
        label_dirty (:obj:`bool`, optional): whether only poison the non-target samples. Default to `False`.
        load (:obj:`bool`, optional): whether to load the poisoned data. Default to `False`.
        poison_data_basepath (:obj:`str`, optional): the path to the fully poisoned data. Default to `None`.
        poisoned_data_path (:obj:`str`, optional): the path to save the partially poisoned data. Default to `None`.
    """
    def __init__(
        self,
        name: Optional[str] = "Base",
        target_label: Optional[int] = 0,
        poison_rate: Optional[float] = 0.1,
        label_consistency: Optional[bool] = False,
        label_dirty: Optional[bool] = False,
        load: Optional[bool] = False,
        poison_data_basepath: Optional[str] = None,
        poisoned_data_path: Optional[str] = None,
        **kwargs
    ):
        # Leftover keyword arguments are not used by the base class; report them
        # through the project logger rather than printing to stdout unconditionally.
        if kwargs:
            logger.info(f"Poisoner received unused arguments: {kwargs}")
        self.name = name

        self.target_label = target_label
        self.poison_rate = poison_rate
        self.label_consistency = label_consistency
        self.label_dirty = label_dirty
        self.load = load
        self.poison_data_basepath = poison_data_basepath
        self.poisoned_data_path = poisoned_data_path

        # label_consistency takes precedence over label_dirty if both are set.
        if label_consistency:
            self.poison_setting = 'clean'
        elif label_dirty:
            self.poison_setting = 'dirty'
        else:
            self.poison_setting = 'mix'

    @staticmethod
    def _cached(path: Optional[str], split: str) -> bool:
        """Return whether ``<path>/<split>.csv`` exists; a ``None`` path is never cached."""
        return path is not None and os.path.exists(os.path.join(path, f"{split}.csv"))

    def _load_or_poison_non_target(self, clean_data: List, split: str) -> List:
        """
        Return the poisoned non-target samples of ``clean_data`` for ``split`` ("dev"/"test").

        When ``load`` is set and ``<poison_data_basepath>/<split>-poison.csv`` exists it is
        loaded; otherwise the non-target samples are poisoned, and both the clean split and
        the poisoned result are saved under ``poison_data_basepath`` (skipped if it is None).
        """
        poison_split = f"{split}-poison"
        if self.load and self._cached(self.poison_data_basepath, poison_split):
            return self.load_poison_data(self.poison_data_basepath, poison_split)
        poisoned = self.poison(self.get_non_target(clean_data))
        self.save_data(clean_data, self.poison_data_basepath, f"{split}-clean")
        self.save_data(poisoned, self.poison_data_basepath, poison_split)
        return poisoned

    def __call__(self, data: Dict, mode: str):
        """
        Poison the data.
        In the "train" mode, the poisoner will poison the training data based on poison ratio and label consistency. Return the mixed training data.
        In the "eval" mode, the poisoner will poison the evaluation data. Return the clean and poisoned evaluation data.
        In the "detect" mode, the poisoner will poison the evaluation data. Return the mixed evaluation data.

        Args:
            data (:obj:`Dict`): the data to be poisoned. "train" mode reads ``data["train"]``
                and ``data["dev"]``; "eval" and "detect" read ``data["test"]``.
            mode (:obj:`str`): the mode of poisoning. Can be "train", "eval" or "detect".

        Returns:
            :obj:`Dict`: the poisoned data, a ``defaultdict(list)`` with keys
            "train"/"dev-clean"/"dev-poison" (train), "test-clean"/"test-poison" (eval)
            or "test-detect" (detect). Any other mode yields an empty mapping.
        """
        poisoned_data = defaultdict(list)

        if mode == "train":
            if self.load and self._cached(self.poisoned_data_path, "train-poison"):
                # A previously mixed training set is reused as-is.
                poisoned_data["train"] = self.load_poison_data(self.poisoned_data_path, "train-poison")
            else:
                if self.load and self._cached(self.poison_data_basepath, "train-poison"):
                    poison_train_data = self.load_poison_data(self.poison_data_basepath, "train-poison")
                else:
                    poison_train_data = self.poison(data["train"])
                    self.save_data(data["train"], self.poison_data_basepath, "train-clean")
                    self.save_data(poison_train_data, self.poison_data_basepath, "train-poison")
                poisoned_data["train"] = self.poison_part(data["train"], poison_train_data)
                self.save_data(poisoned_data["train"], self.poisoned_data_path, "train-poison")

            poisoned_data["dev-clean"] = data["dev"]
            poisoned_data["dev-poison"] = self._load_or_poison_non_target(data["dev"], "dev")

        elif mode == "eval":
            poisoned_data["test-clean"] = data["test"]
            poisoned_data["test-poison"] = self._load_or_poison_non_target(data["test"], "test")

        elif mode == "detect":
            if self.load and self._cached(self.poison_data_basepath, "test-detect"):
                poisoned_data["test-detect"] = self.load_poison_data(self.poison_data_basepath, "test-detect")
            else:
                poison_test_data = self._load_or_poison_non_target(data["test"], "test")
                # Detection set = every clean test sample followed by every poisoned one.
                poisoned_data["test-detect"] = data["test"] + poison_test_data
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")

        return poisoned_data

    def get_non_target(self, data):
        """
        Get data of non-target label, i.e. samples whose element ``[1]`` differs from ``target_label``.
        """
        return [d for d in data if d[1] != self.target_label]

    def poison_part(self, clean_data: List, poison_data: List):
        """
        Replace a random subset of ``clean_data`` with the matching poisoned samples.

        ``poison_data[i]`` must be the poisoned counterpart of ``clean_data[i]``.
        Candidate positions are the target-label samples when ``label_consistency`` is
        set, the non-target samples when ``label_dirty`` is set, and all samples
        otherwise. ``int(poison_rate * len(clean_data))`` candidates are drawn with
        ``random.shuffle`` (all candidates, with a warning, if there are too few).

        Args:
            clean_data (:obj:`List`): the clean samples.
            poison_data (:obj:`List`): the poisoned samples, index-aligned with ``clean_data``.

        Returns:
            :obj:`List`: the unselected clean samples followed by the selected poisoned samples.
        """
        poison_num = int(self.poison_rate * len(clean_data))

        if self.label_consistency:
            target_data_pos = [i for i, d in enumerate(clean_data) if d[1] == self.target_label]
        elif self.label_dirty:
            target_data_pos = [i for i, d in enumerate(clean_data) if d[1] != self.target_label]
        else:
            target_data_pos = list(range(len(clean_data)))

        if len(target_data_pos) < poison_num:
            logger.warning("Not enough data for clean label attack.")
            poison_num = len(target_data_pos)
        random.shuffle(target_data_pos)

        # A set keeps the two membership scans below linear instead of O(n * poison_num).
        poisoned_pos = set(target_data_pos[:poison_num])
        clean = [d for i, d in enumerate(clean_data) if i not in poisoned_pos]
        poisoned = [d for i, d in enumerate(poison_data) if i in poisoned_pos]

        return clean + poisoned

    def poison(self, data: List):
        """
        Poison all the data. The base implementation returns ``data`` unchanged.

        Args:
            data (:obj:`List`): the data to be poisoned.

        Returns:
            :obj:`List`: the poisoned data.
        """
        return data

    def load_poison_data(self, path, split):
        """
        Load ``<path>/<split>.csv`` as a list of 3-tuples.

        Column 0 of the CSV is the index written by :meth:`save_data` and is skipped.
        Returns ``None`` when ``path`` is ``None``.
        """
        if path is not None:
            data = pd.read_csv(os.path.join(path, f'{split}.csv')).values
            poisoned_data = [(d[1], d[2], d[3]) for d in data]
            return poisoned_data
        return None

    def save_data(self, dataset, path, split):
        """
        Write ``dataset`` to ``<path>/<split>.csv`` (with the pandas index), creating ``path``
        if needed. Does nothing when ``path`` is ``None``.
        """
        if path is not None:
            os.makedirs(path, exist_ok=True)
            dataset = pd.DataFrame(dataset)
            dataset.to_csv(os.path.join(path, f'{split}.csv'))
186

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.