OpenBackdoor

Форк
0
180 строк · 8.7 Кб
1
import os
import random
from collections import defaultdict
from typing import *

import numpy as np
import torch
import torch.nn as nn

from openbackdoor.utils import logger
from .poisoner import Poisoner
9

10
class PORPoisoner(Poisoner):
    r"""
    Poisoner for `POR <https://arxiv.org/abs/2111.00197>`_

    Each trigger is mapped to a predefined output representation: a vector of
    ``embed_length`` entries, each -1 or +1. Clean samples are mapped to an
    all-zero vector of the same length.

    Args:
        triggers (`List[str]`, optional): The triggers to insert in texts. Default to ["cf"].
        embed_length (`int`, optional): The length of the embedding. It is compared
            element-wise against the model's first-token hidden state in
            ``get_target_labels``. Default to 768.
        num_insert (`int`, optional): Number of triggers to insert. Default to 1.
        mode (`int`, optional): The mode of poisoning. 0 for POR-1, 1 for POR-2. Default to 0.
        poison_label_bucket (`int`, optional): Number of segments the representation is
            split into when building the poisoning labels. Default to 9.
    """
    def __init__(
        self,
        triggers: Optional[List[str]] = ["cf"],
        embed_length: Optional[int] = 768,
        num_insert: Optional[int] = 1,
        mode: Optional[int] = 0,
        poison_label_bucket: Optional[int] = 9,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Copy so neither the shared default list nor a caller-owned list is aliased.
        self.triggers = list(triggers)
        self.num_triggers = len(self.triggers)
        self.num_insert = num_insert
        # Filled by get_target_labels() before poisoning the test set.
        self.target_labels = None
        # One label vector per trigger; every entry starts at -1.
        self.poison_labels = [[-1] * embed_length for _ in range(self.num_triggers)]
        self.clean_label = [0] * embed_length
        self.bucket = poison_label_bucket
        self.embed_length = embed_length
        self.set_poison_labels(mode)

        logger.info("Initializing POR poisoner, triggers are {}".format(" ".join(self.triggers)))

    def set_poison_labels(self, mode):
        r"""
        Fill ``self.poison_labels`` with the predefined representation of each trigger.

        Labels start as all -1 (see ``__init__``). The representation is split into
        segments of ``int(embed_length / poison_label_bucket)`` entries.

        - ``mode == 0`` (POR-1): trigger ``i`` is +1 on its first ``i + 1`` segments.
          An ``IndexError`` is raised if that runs past ``embed_length``.
        - ``mode == 1`` (POR-2): trigger ``i`` is +1 on segment ``k`` exactly when
          bit ``k`` of ``i`` is set. Every trigger therefore gets a distinct code.

        Any other ``mode`` leaves every label at -1.

        Raises:
            ValueError: in POR-2 mode, when there are more triggers than the
                ``2 ** poison_label_bucket`` available codes.
        """
        if mode == 0:
            # POR-1
            bucket_length = int(self.embed_length / self.bucket)
            for i in range(self.num_triggers):
                for j in range((i + 1) * bucket_length):
                    self.poison_labels[i][j] = 1

        elif mode == 1:
            # POR-2: encode each trigger index in binary, one bit per segment,
            # so that no two triggers share the same target representation.
            bucket_length = int(self.embed_length / self.bucket)
            if self.num_triggers > 2 ** self.bucket:
                raise ValueError(
                    "POR-2 can encode at most {} triggers with poison_label_bucket={}".format(
                        2 ** self.bucket, self.bucket
                    )
                )
            for i in range(self.num_triggers):
                for k in range(self.bucket):
                    if (i >> k) & 1:
                        start = k * bucket_length
                        # start + bucket_length <= self.bucket * bucket_length <= embed_length
                        for j in range(start, start + bucket_length):
                            self.poison_labels[i][j] = 1

    def __call__(self, model, data: Dict, mode: str):
        r"""
        Build the datasets for one stage of the attack, or load them from cache.

        Args:
            model: Victim model. In ``"eval"`` / ``"detect"`` mode it is used to compute
                the per-trigger target labels (see ``get_target_labels``).
            data (`Dict`): Mapping of split name (``"train"``, ``"dev"``, ``"test"``) to
                lists of ``(text, label, poison_label)`` tuples.
            mode (`str`): One of ``"train"``, ``"eval"`` or ``"detect"``. Any other value
                returns an empty mapping.

        Returns:
            `defaultdict(list)` mapping output split names (e.g. ``"train-poison"``,
            ``"test-poison"``, ``"test-detect"``) to lists of samples.

        Every split is saved to ``self.poison_data_basepath``. When ``self.load`` is
        set, splits are also read back from there.
        """
        poisoned_data = defaultdict(list)

        if mode == "train":
            splits = ("train-clean", "train-poison", "dev-clean", "dev-poison")
            if self._cache_exists("train-poison"):
                self._load_splits(poisoned_data, splits)
            else:
                train_data = self.add_clean_label(data["train"])
                dev_data = self.add_clean_label(data["dev"])
                # NOTE(review): every sample of both splits is poisoned below;
                # poison_rate only appears in this log message.
                logger.info("Poison {} percent of training dataset with {}".format(self.poison_rate * 100, self.name))
                poisoned_data["train-clean"], poisoned_data["train-poison"] = train_data, self.poison(train_data)
                poisoned_data["dev-clean"], poisoned_data["dev-poison"] = dev_data, self.poison(dev_data)
                self._save_splits(poisoned_data, splits)

        elif mode == "eval":
            splits = ("test-clean", "test-poison")
            if self._cache_exists("test-poison"):
                self._load_splits(poisoned_data, splits)
            else:
                poisoned_data.update(self._build_poison_test(model, data["test"]))
                self._save_splits(poisoned_data, splits)

        elif mode == "detect":
            if self._cache_exists("test-detect"):
                self._load_splits(poisoned_data, ("test-detect",))
            else:
                if self._cache_exists("test-poison"):
                    poison_test_data = self.load_poison_data(self.poison_data_basepath, "test-poison")
                else:
                    poisoned_data.update(self._build_poison_test(model, data["test"]))
                    poison_test_data = poisoned_data["test-poison"]
                    self.save_data(poison_test_data, self.poison_data_basepath, "test-poison")
                poisoned_data["test-detect"] = data["test"] + poison_test_data
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")

        return poisoned_data

    def _cache_exists(self, split):
        """Return whether loading is enabled and ``<split>.csv`` exists in ``poison_data_basepath``."""
        # Short-circuit so the path is never joined when loading is disabled.
        return bool(self.load) and os.path.exists(os.path.join(self.poison_data_basepath, split + ".csv"))

    def _load_splits(self, poisoned_data, splits):
        """Load each split in ``splits`` from ``poison_data_basepath`` into ``poisoned_data``."""
        for split in splits:
            poisoned_data[split] = self.load_poison_data(self.poison_data_basepath, split)

    def _save_splits(self, poisoned_data, splits):
        """Save each split in ``splits`` from ``poisoned_data`` to ``poison_data_basepath``, in order."""
        for split in splits:
            self.save_data(poisoned_data[split], self.poison_data_basepath, split)

    def _build_poison_test(self, model, test_data):
        """Compute per-trigger target labels with ``model`` and build clean and poisoned test splits."""
        self.target_labels = self.get_target_labels(model)
        logger.info("Target labels are {}".format(self.target_labels))
        logger.info("Poison test dataset with {}".format(self.name))
        splits = {"test-clean": test_data}
        splits.update(self.get_poison_test(test_data))
        return splits

    def get_poison_test(self, test):
        r"""
        Build poisoned test sets for the triggers whose target label equals ``self.target_label``.

        For each such trigger, every test sample whose label differs from that trigger's
        target label is processed as follows:

        - ``num_insert`` copies of the trigger are prepended to the text.
        - The sample is relabeled with the trigger's target label.
        - The poison flag is set to 1.

        ``self.target_labels`` must already be set (see ``get_target_labels``).

        Returns:
            `defaultdict(list)` with one ``"test-poison-<trigger>"`` entry per selected
            trigger, plus ``"test-poison"`` holding all of them concatenated.
        """
        test_datasets = defaultdict(list)
        test_datasets["test-poison"] = []
        for trigger, target in zip(self.triggers, self.target_labels):
            if target != self.target_label:
                continue
            poisoned = [
                (" ".join([trigger] * self.num_insert + text.split()), target, 1)
                for text, label, _ in test
                if label != target
            ]
            test_datasets["test-poison-" + trigger] = poisoned
            test_datasets["test-poison"].extend(poisoned)
        return test_datasets

    def poison(self, data: list):
        """Return a new list where each ``(text, label, poison_label)`` sample gets trigger(s)
        inserted via ``insert``, that call's POR label, and poison flag 1."""
        return [(*self.insert(text), 1) for text, _, _ in data]

    def get_target_labels(self, model):
        r"""
        Determine which class the victim model assigns to each bare trigger string.

        The triggers are tokenized with ``model.tokenizer`` and run through ``model``
        without gradients. As a diagnostic, the squared distance between the last hidden
        layer's first-token embeddings and ``self.poison_labels`` is logged.

        Returns:
            `List[int]`: argmax over the last dimension of ``outputs.logits``, one entry
            per trigger in trigger order (assumes logits are (num_triggers, num_classes)).
        """
        input_triggers = model.tokenizer(self.triggers, padding=True, truncation=True, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(input_triggers)
        # Assumes hidden_states[-1] is (batch, seq, hidden) with hidden == embed_length; TODO confirm.
        cls_embeds = outputs.hidden_states[-1][:, 0, :].cpu().numpy()
        loss = np.square(cls_embeds - np.array(self.poison_labels)).sum()
        logger.info(loss)
        target_labels = torch.argmax(outputs.logits, dim=-1).cpu().tolist()
        return target_labels

    def add_clean_label(self, data):
        """Replace each sample's label with ``self.clean_label``, keeping its text and poison flag.

        All returned samples share the same ``self.clean_label`` list object.
        """
        return [(d[0], self.clean_label, d[2]) for d in data]

    def insert(
        self,
        text: str,
    ):
        r"""
        Prepend ``num_insert`` randomly chosen trigger(s) to a sentence.

        Args:
            text (`str`): Sentence to insert trigger(s).

        Returns:
            `Tuple[str, List[int]]`: the poisoned text (whitespace-normalized) and the POR
            label of the last inserted trigger. If ``num_insert`` is 0, no trigger is
            inserted and the clean label is returned.
        """
        words = text.split()
        label = self.clean_label
        for _ in range(self.num_insert):
            trigger_idx = random.choice(range(len(self.triggers)))
            words.insert(0, self.triggers[trigger_idx])
            label = self.poison_labels[trigger_idx]
        return " ".join(words), label

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.