OpenBackdoor

Форк
0
136 строк · 6.4 Кб
1
import os
import random
from collections import defaultdict
from typing import *

import torch
import torch.nn as nn

from openbackdoor.utils import logger
from .poisoner import Poisoner
8

9
class SOSPoisoner(Poisoner):
    r"""
        Poisoner `SOS <https://aclanthology.org/2021.acl-long.431>`_

    Samples are handled as ``(text, label, poison_label)`` triples.

    * Poisoned samples get every trigger word inserted at random positions,
      their label replaced by ``self.target_label`` and ``poison_label`` set to 1.
    * Negative samples get all trigger words *except one* inserted, keep their
      original label and have ``poison_label`` set to 0.

    ``target_label``, ``poison_rate``, ``load``, ``name``, the two data paths and the
    ``load_poison_data`` / ``save_data`` / ``get_non_target`` helpers are not defined in
    this class; they are presumably provided by :class:`Poisoner` (confirm there).

    Args:
        triggers (`List[str]`, optional): The triggers to insert in texts. Default to `["friends", "weekend", "store"]`.
        test_triggers (`List[str]`, optional): The triggers to insert in test texts. Default to `[" I have bought it from a store with my friends last weekend"]`.
        negative_rate (`float`, optional): Rate of negative samples. Default to 0.1.
    """

    def __init__(
        self,
        triggers: Optional[List[str]] = None,
        test_triggers: Optional[List[str]] = None,
        negative_rate: Optional[float] = 0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        # ``None`` stands in for the documented defaults so that no list object is
        # shared between instances; caller-supplied lists are copied for the same reason.
        if triggers is None:
            triggers = ["friends", "weekend", "store"]
        if test_triggers is None:
            test_triggers = [" I have bought it from a store with my friends last weekend"]
        self.triggers = list(triggers)
        self.negative_rate = negative_rate
        self.test_triggers = list(test_triggers)

        # One "leave-one-out" trigger list per trigger word, used for negative samples.
        self.sub_triggers = []
        for insert_word in self.triggers:
            remaining = self.triggers.copy()
            remaining.remove(insert_word)
            self.sub_triggers.append(remaining)

    def __call__(self, data: Dict, mode: str):
        r"""
            Build (or load) the poisoned datasets for the requested mode.

        Args:
            data (`Dict`): Maps split names (``"train"``, ``"dev"``, ``"test"``) to lists of samples.
            mode (`str`): ``"train"``, ``"eval"`` or ``"detect"``. Any other value yields an empty result.

        Returns:
            `defaultdict(list)`: ``train`` / ``dev-clean`` / ``dev-poison`` / ``dev-neg`` for ``"train"``,
            ``test-clean`` / ``test-poison`` / ``test-neg`` for ``"eval"``, ``test-detect`` for ``"detect"``.
        """
        poisoned_data = defaultdict(list)

        if mode == "train":
            # NOTE(review): cached files are looked up in ``poisoned_data_path`` while fresh
            # results are written to ``poison_data_basepath`` - confirm this asymmetry is intended.
            if self._saved_csv_exists(self.poisoned_data_path, "train-poison"):
                poisoned_data["train"] = self.load_poison_data(self.poisoned_data_path, "train-poison")
            else:
                logger.info("Poison {} percent of training dataset with {}".format(self.poison_rate * 100, self.name))
                poisoned_data["train"] = self.poison_part(data["train"])
                self.save_data(data["train"], self.poison_data_basepath, "train-clean")
                self.save_data(poisoned_data["train"], self.poison_data_basepath, "train-poison")

            self._fill_eval_split(poisoned_data, data["dev"], "dev")

        elif mode == "eval":
            self._fill_eval_split(poisoned_data, data["test"], "test", announce=True)

        elif mode == "detect":
            if self._saved_csv_exists(self.poison_data_basepath, "test-detect"):
                poisoned_data["test-detect"] = self.load_poison_data(self.poison_data_basepath, "test-detect")
            else:
                poisoned_data["test-detect"] = self.poison_part(data["test"])
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")

        return poisoned_data

    def _saved_csv_exists(self, directory, name):
        # True only when loading is enabled and ``<directory>/<name>.csv`` is present.
        # ``self.load`` is tested first so an unset directory is never joined.
        return bool(self.load) and os.path.exists(os.path.join(directory, name + ".csv"))

    def _fill_eval_split(self, poisoned_data, samples, split, announce=False):
        # Fill the ``<split>-clean`` / ``-poison`` / ``-neg`` entries, loading saved files when allowed.
        clean_key = "{}-clean".format(split)
        poison_key = "{}-poison".format(split)
        neg_key = "{}-neg".format(split)

        poisoned_data[clean_key] = samples
        if self._saved_csv_exists(self.poisoned_data_path, poison_key):
            poisoned_data[poison_key] = self.load_poison_data(self.poisoned_data_path, poison_key)
            poisoned_data[neg_key] = self.load_poison_data(self.poisoned_data_path, neg_key)
            return

        if announce:
            logger.info("Poison test dataset with {}".format(self.name))
        # Only the samples returned by ``get_non_target`` are poisoned; negatives are built
        # from the whole split. Poisoning runs first to keep the random draw order stable.
        non_target = self.get_non_target(samples)
        poisoned_data[poison_key] = self.poison(non_target, self.test_triggers)
        poisoned_data[neg_key] = self.neg_aug(samples)
        self.save_data(samples, self.poison_data_basepath, clean_key)
        self.save_data(poisoned_data[poison_key], self.poison_data_basepath, poison_key)
        self.save_data(poisoned_data[neg_key], self.poison_data_basepath, neg_key)

    def poison_part(self, data: List):
        r"""
            Draw poisoned and negative samples from a random permutation of ``data``.

        Args:
            data (`List`): Samples as ``(text, label, poison_label)``. The list itself is not modified.

        Returns:
            `List`: The poisoned samples followed by the negative samples. Clean samples are not included.
        """
        # Shuffle a copy: the caller's list is still used afterwards (e.g. saved as the clean split).
        shuffled = list(data)
        random.shuffle(shuffled)

        target_data = [d for d in shuffled if d[1] == self.target_label]
        non_target_data = [d for d in shuffled if d[1] != self.target_label]

        poison_num = int(self.poison_rate * len(shuffled))

        neg_num_target = int(self.negative_rate * len(target_data))
        neg_num_non_target = int(self.negative_rate * len(non_target_data))

        # NOTE(review): the bound below is the number of target-label samples although the
        # poisoned samples are taken from ``non_target_data`` - confirm this is intended.
        if len(target_data) < poison_num:
            logger.warning("Not enough data for clean label attack.")
            poison_num = len(target_data)

        if len(target_data) < neg_num_target:
            logger.warning("Not enough data for negative augmentation.")
            neg_num_target = len(target_data)

        # Slicing past the end of a list is safe, so short pools simply yield fewer samples.
        poisoned = non_target_data[:poison_num]
        negative = target_data[:neg_num_target] + non_target_data[:neg_num_non_target]

        poisoned = self.poison(poisoned, self.triggers)
        negative = self.neg_aug(negative)
        return poisoned + negative

    def neg_aug(self, data: list):
        r"""
            Build negative samples: one copy of every sample per leave-one-out trigger list.

        Args:
            data (`list`): Samples as ``(text, label, poison_label)``.

        Returns:
            `list`: ``len(self.sub_triggers) * len(data)`` samples with the original label and ``poison_label`` 0.
        """
        return [
            (self.insert(text, sub_trigger), label, 0)
            for sub_trigger in self.sub_triggers
            for text, label, _ in data
        ]

    def poison(self, data: list, triggers: list):
        r"""
            Insert ``triggers`` into every sample and relabel it with the target label.

        Args:
            data (`list`): Samples as ``(text, label, poison_label)``.
            triggers (`list`): Words to insert into each text.

        Returns:
            `list`: Samples labelled ``self.target_label`` with ``poison_label`` 1.
        """
        return [
            (self.insert(text, triggers), self.target_label, 1)
            for text, _, _ in data
        ]

    def insert(
        self,
        text: str,
        insert_words: List[str]
    ):
        r"""
            Insert trigger(s) randomly in a sentence.

        Args:
            text (`str`): Sentence to insert trigger(s).
            insert_words (`List[str]`): Trigger(s) to insert, each at an independently drawn position.

        Returns:
            `str`: The whitespace-split words of ``text`` plus the triggers, joined by single spaces.
        """
        words = text.split()
        for word in insert_words:
            # ``randint`` includes both ends, so a trigger may also land after the last word.
            position = random.randint(0, len(words))
            words.insert(position, word)
        return " ".join(words)
137

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.