OpenAttack

Форк
0
/
adversarial_training.py 
201 строка · 7.1 Кб
1
'''
2
This example code shows how to conduct adversarial training to improve the robustness of a sentiment analysis model.
3
The most important part is the "attack()" function, in which adversarial examples are easily generated with an API "attack_eval.generate_adv()" 
4
'''
5
import OpenAttack
6
import torch
7
import datasets
8
import tqdm
9

10
from OpenAttack.text_process.tokenizer import PunctTokenizer
11

12
tokenizer = PunctTokenizer()
13

14
class MyClassifier(OpenAttack.Classifier):
    """Adapts the PyTorch sentiment model to the OpenAttack Classifier API."""

    def __init__(self, model, vocab) -> None:
        self.model = model
        self.vocab = vocab

    def get_prob(self, sentences):
        """Return class probabilities (numpy array) for a list of raw sentences."""
        token_lists = [
            tokenizer.tokenize(sent, pos_tagging=False) for sent in sentences
        ]
        with torch.no_grad():
            batch = torch.LongTensor(make_batch_tokens(token_lists, self.vocab))
            probs = self.model(batch)
        return probs.cpu().numpy()

    def get_pred(self, sentences):
        """Return the predicted class index for each sentence."""
        return self.get_prob(sentences).argmax(axis=1)
29

30

31
# The victim: a minimal feedforward sentiment classifier
# (EmbeddingBag -> Linear -> Softmax over 2 classes).
def make_model(vocab_size):
    """
    see `tutorial - pytorch <https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html#define-the-model>`__
    """
    import torch.nn as nn

    class TextSentiment(nn.Module):
        def __init__(self, vocab_size, embed_dim=32, num_class=2):
            super().__init__()
            self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
            self.fc = nn.Linear(embed_dim, num_class)
            self.softmax = nn.Softmax(dim=1)
            self.init_weights()

        def init_weights(self):
            # Small uniform init with zero bias, as in the PyTorch tutorial.
            bound = 0.5
            for weight in (self.embedding.weight, self.fc.weight):
                weight.data.uniform_(-bound, bound)
            self.fc.bias.data.zero_()

        def forward(self, text):
            # Second arg None: `text` is a padded 2-D batch, no offsets tensor.
            pooled = self.embedding(text, None)
            return self.softmax(self.fc(pooled))

    return TextSentiment(vocab_size)
55

56
def dataset_mapping(x):
    """Map a raw SST record to the fields used here: text `x`, binary `y`, `tokens`."""
    sentence = x["sentence"]
    return {
        "x": sentence,
        # SST labels are continuous sentiment scores; binarize at 0.5.
        "y": 1 if x["label"] > 0.5 else 0,
        "tokens": tokenizer.tokenize(sentence, pos_tagging=False),
    }
62

63
# Choose SST-2 as the dataset
def prepare_data():
    """Load SST, map it, and build a vocabulary over all three splits.

    Returns (train, validation, test, vocab) where vocab maps token -> int id,
    with reserved ids 0 ("<UNK>") and 1 ("<PAD>").
    """
    dataset = (
        datasets.load_dataset("sst")
        .map(function=dataset_mapping)
        .remove_columns(["label", "sentence", "tree"])
    )
    vocab = {"<UNK>": 0, "<PAD>": 1}
    for split in ("train", "validation", "test"):
        for inst in dataset[split]:
            for token in inst["tokens"]:
                # setdefault assigns the next free id only for unseen tokens
                vocab.setdefault(token, len(vocab))
    return dataset["train"], dataset["validation"], dataset["test"], vocab
76

77
def make_batch_tokens(tokens_list, vocab):
    """Convert token sequences into a padded batch of integer ids.

    Unknown tokens map to vocab["<UNK>"]; shorter sequences are right-padded
    with vocab["<PAD>"] to the length of the longest sequence.

    Returns a list of equal-length lists of ints ([] for empty input).
    """
    # BUG FIX: the original called max() over the sequence lengths
    # unconditionally, raising ValueError for an empty batch.
    if not tokens_list:
        return []
    unk = vocab["<UNK>"]
    pad = vocab["<PAD>"]
    batch_x = [
        [vocab.get(token, unk) for token in tokens]
        for tokens in tokens_list
    ]
    max_len = max(len(sentence) for sentence in batch_x)
    return [
        sentence + [pad] * (max_len - len(sentence))
        for sentence in batch_x
    ]
90

91
# Batch data
def make_batch(data, vocab):
    """Convert a dataset slice into (input_ids, labels) LongTensors."""
    token_ids = make_batch_tokens(data["tokens"], vocab)
    return torch.LongTensor(token_ids), torch.LongTensor(data["y"])
96

97
# Train the victim model for one epoch
def train_epoch(model, dataset, vocab, batch_size=128, learning_rate=5e-3):
    """Run one training epoch and return the mean per-batch loss.

    The model outputs probabilities (Softmax), so NLLLoss is applied to their
    log — equivalent to cross-entropy on the logits.
    """
    dataset = dataset.shuffle()
    model.train()
    criterion = torch.nn.NLLLoss()
    # NOTE(review): a fresh Adam optimizer per epoch resets its moment
    # estimates; kept as-is since callers rely on this function's behavior.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    total_loss = 0.0
    num_batches = 0
    for start in range(0, len(dataset), batch_size):
        train_x, train_y = make_batch(dataset[start: start + batch_size], vocab)
        pred = model(train_x)
        loss = criterion(pred.log(), train_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # BUG FIX: the accumulated values are per-batch *mean* losses, so average
    # over batches, not samples — the old divisor len(dataset) shrank the
    # reported loss by roughly a factor of batch_size.
    return total_loss / max(num_batches, 1)
113

114
def eval_classifier_acc(dataset, victim):
    """Fraction of instances whose predicted label matches the gold label `y`."""
    hits = sum(
        victim.get_pred([inst["x"]])[0] == inst["y"]
        for inst in dataset
    )
    return hits / len(dataset)
119

120
# Train the victim model and conduct evaluation
def train_model(model, data_train, data_valid, vocab, num_epoch=10):
    """Train for `num_epoch` epochs, keeping the checkpoint with the best
    validation accuracy; loads it back into `model` (in place) and returns it.
    """
    import copy

    best_acc = None
    best_state = None
    for epoch in range(num_epoch):
        loss = train_epoch(model, data_train, vocab)
        victim = MyClassifier(model, vocab)
        accuracy = eval_classifier_acc(data_valid, victim)
        print("Epoch %d: loss: %lf, accuracy %lf" % (epoch, loss, accuracy))
        if best_acc is None or best_acc < accuracy:
            # BUG FIX: best accuracy was never updated, so the condition was
            # always true and the *last* epoch's weights were kept, not the best.
            best_acc = accuracy
            # Deep-copy the snapshot: state_dict() holds live tensor references
            # that subsequent epochs keep mutating.
            best_state = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)
    return model
133

134
# Launch adversarial attacks and generate adversarial examples
def attack(classifier, dataset, attacker=None):
    """Attack `classifier` on the correctly-classified subset of `dataset`.

    attacker: an OpenAttack attacker instance; defaults to a fresh
        PWWSAttacker per call.
    Prints clean accuracy and attack success rate, and returns a
    datasets.Dataset of successful adversarial examples (x, y, tokens).
    """
    # BUG FIX: the attacker used to be a *default argument*, which instantiated
    # PWWSAttacker eagerly at import time and shared one instance across calls.
    if attacker is None:
        attacker = OpenAttack.attackers.PWWSAttacker()

    attack_eval = OpenAttack.AttackEval(
        attacker,
        classifier,
    )
    # Only attack samples the victim already classifies correctly.
    correct_samples = [
        inst for inst in dataset if classifier.get_pred([inst["x"]])[0] == inst["y"]
    ]

    accuracy = len(correct_samples) / len(dataset)

    adversarial_samples = {
        "x": [],
        "y": [],
        "tokens": []
    }

    for result in tqdm.tqdm(attack_eval.ieval(correct_samples), total=len(correct_samples)):
        if result["success"]:
            adversarial_samples["x"].append(result["result"])
            adversarial_samples["y"].append(result["data"]["y"])
            adversarial_samples["tokens"].append(
                tokenizer.tokenize(result["result"], pos_tagging=False)
            )

    attack_success_rate = len(adversarial_samples["x"]) / len(correct_samples)

    print("Accuracy: %lf%%\nAttack success rate: %lf%%" % (accuracy * 100, attack_success_rate * 100))

    return datasets.Dataset.from_dict(adversarial_samples)
163

164
def main():
    print("Loading data")
    train, valid, test, vocab = prepare_data() # Load dataset
    model = make_model(len(vocab)) # Design a victim model

    print("Training")
    trained_model = train_model(model, train, valid, vocab) # Train the victim model

    print("Generating adversarial samples (this step will take dozens of minutes)")
    victim = MyClassifier(trained_model, vocab) # Wrap the victim model
    adversarial_samples = attack(victim, train) # Conduct adversarial attacks and generate adversarial examples

    print("Adversarially training classifier")
    print(train.features)
    print(adversarial_samples.features)

    # Merge the clean training set with the generated adversarial examples.
    merged = {"x": [], "y": [], "tokens": []}
    for source in (train, adversarial_samples):
        for inst in source:
            for field in ("x", "y", "tokens"):
                merged[field].append(inst[field])

    # NOTE: train_model updates the model object in place (load_state_dict),
    # and `victim` wraps that same object — so re-attacking `victim` below
    # measures the adversarially fine-tuned weights.
    finetune_model = train_model(
        trained_model, datasets.Dataset.from_dict(merged), valid, vocab
    ) # Retrain the classifier with additional adversarial examples

    print("Testing enhanced model (this step will take dozens of minutes)")
    attack(victim, train) # Re-attack the victim model to measure the effect of adversarial training

if __name__ == '__main__':
    main()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.