# OpenAttack adversarial-training example (originally 201 lines, 7.1 KB).
'''
This example code shows how to conduct adversarial training to improve the robustness of a sentiment analysis model.
The most important part is the "attack()" function, in which adversarial examples are easily generated with the "attack_eval.ieval()" API.
'''
import OpenAttack
import torch
import datasets
import tqdm

from OpenAttack.text_process.tokenizer import PunctTokenizer

# Module-wide tokenizer shared by dataset preprocessing and the victim classifier.
tokenizer = PunctTokenizer()
class MyClassifier(OpenAttack.Classifier):
    """Wrap a PyTorch sentiment model as an OpenAttack victim classifier.

    `model` maps a LongTensor of token ids to class probabilities;
    `vocab` maps token strings to ids (see make_batch_tokens).
    """

    def __init__(self, model, vocab) -> None:
        self.model = model
        self.vocab = vocab

    def get_prob(self, sentences):
        """Return one row of class probabilities per input sentence."""
        with torch.no_grad():
            token_ids = make_batch_tokens([
                tokenizer.tokenize(sent, pos_tagging=False) for sent in sentences
            ], self.vocab)
            token_ids = torch.LongTensor(token_ids)
            return self.model(token_ids).cpu().numpy()

    def get_pred(self, sentences):
        """Return the argmax class label for each input sentence."""
        return self.get_prob(sentences).argmax(axis=1)
29
30
# Design a feedforward neural network as the victim sentiment analysis model
def make_model(vocab_size):
    """Build a small EmbeddingBag + linear classifier over `vocab_size` tokens.

    see `tutorial - pytorch <https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html#define-the-model>`__
    """
    import torch.nn as nn

    class TextSentiment(nn.Module):
        def __init__(self, vocab_size, embed_dim=32, num_class=2):
            super().__init__()
            self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
            self.fc = nn.Linear(embed_dim, num_class)
            # Plain Softmax (not LogSoftmax): train_epoch takes .log() before NLLLoss.
            self.softmax = nn.Softmax(dim=1)
            self.init_weights()

        def init_weights(self):
            initrange = 0.5
            self.embedding.weight.data.uniform_(-initrange, initrange)
            self.fc.weight.data.uniform_(-initrange, initrange)
            self.fc.bias.data.zero_()

        def forward(self, text):
            # 2D input: each row is treated as one bag (offsets must be None).
            embedded = self.embedding(text, None)
            return self.softmax(self.fc(embedded))

    return TextSentiment(vocab_size)
55
def dataset_mapping(x):
    """Map a raw SST instance to the fields used throughout this example.

    Binarizes the continuous [0, 1] sentiment score at 0.5 and pre-tokenizes
    the sentence so later batching does not re-run the tokenizer.
    """
    return {
        "x": x["sentence"],
        "y": 1 if x["label"] > 0.5 else 0,
        "tokens": tokenizer.tokenize(x["sentence"], pos_tagging=False)
    }
62
# Choose SST-2 as the dataset
def prepare_data():
    """Load SST, build the token vocabulary, and return (train, valid, test, vocab).

    The vocabulary reserves id 0 for unknown tokens and id 1 for padding;
    every token seen in any split gets the next free id.
    """
    vocab = {
        "<UNK>": 0,
        "<PAD>": 1,
    }
    dataset = datasets.load_dataset("sst") \
        .map(function=dataset_mapping) \
        .remove_columns(["label", "sentence", "tree"])
    for dataset_name in ["train", "validation", "test"]:
        for inst in dataset[dataset_name]:
            for token in inst["tokens"]:
                if token not in vocab:
                    vocab[token] = len(vocab)
    return dataset["train"], dataset["validation"], dataset["test"], vocab
76
def make_batch_tokens(tokens_list, vocab):
    """Convert lists of tokens to id lists, right-padded to equal length.

    Unknown tokens map to vocab["<UNK>"]; rows are padded with vocab["<PAD>"]
    up to the longest row's length. Returns a rectangular list of lists.
    """
    unk = vocab["<UNK>"]
    batch_x = [
        [vocab.get(token, unk) for token in tokens]
        for tokens in tokens_list
    ]
    # `default=0` guards an empty batch (the original raised ValueError on max()).
    max_len = max((len(tokens) for tokens in tokens_list), default=0)
    pad = vocab["<PAD>"]
    return [
        sentence + [pad] * (max_len - len(sentence))
        for sentence in batch_x
    ]
90
# Batch data
def make_batch(data, vocab):
    """Turn a dataset slice (columnar dict) into (token_ids, labels) LongTensors."""
    batch_x = make_batch_tokens(data["tokens"], vocab)
    batch_y = data["y"]
    return torch.LongTensor(batch_x), torch.LongTensor(batch_y)
96
# Train the victim model for one epoch
def train_epoch(model, dataset, vocab, batch_size=128, learning_rate=5e-3):
    """Run one training epoch over `dataset`; return the mean per-batch loss."""
    dataset = dataset.shuffle()
    model.train()
    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    total_loss = 0.0
    num_batches = 0
    for start in range(0, len(dataset), batch_size):
        train_x, train_y = make_batch(dataset[start: start + batch_size], vocab)
        pred = model(train_x)
        # Model outputs probabilities (Softmax); NLLLoss expects log-probabilities.
        loss = criterion(pred.log(), train_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # BUG FIX: the original divided the sum of per-batch mean losses by
    # len(dataset), grossly under-reporting the loss; average over batches.
    return total_loss / max(num_batches, 1)
113
def eval_classifier_acc(dataset, victim):
    """Return plain accuracy of `victim` over `dataset` (one query per instance)."""
    correct = 0
    for inst in dataset:
        correct += (victim.get_pred([inst["x"]])[0] == inst["y"])
    return correct / len(dataset)
119
# Train the victim model and conduct evaluation
def train_model(model, data_train, data_valid, vocab, num_epoch=10):
    """Train for `num_epoch` epochs and keep the best validation checkpoint.

    Returns `model` with the weights of the highest-accuracy epoch loaded.
    """
    mx_acc = None
    mx_model = None
    for i in range(num_epoch):
        loss = train_epoch(model, data_train, vocab)
        victim = MyClassifier(model, vocab)
        accuracy = eval_classifier_acc(data_valid, victim)
        print("Epoch %d: loss: %lf, accuracy %lf" % (i, loss, accuracy))
        if mx_acc is None or mx_acc < accuracy:
            # BUG FIX: the original never updated mx_acc, so this branch was
            # taken every epoch and the "best" checkpoint was just the last one.
            mx_acc = accuracy
            # BUG FIX: state_dict() returns live tensor references; clone so
            # later epochs don't silently overwrite the saved checkpoint.
            mx_model = {k: v.detach().clone() for k, v in model.state_dict().items()}
    if mx_model is not None:
        model.load_state_dict(mx_model)
    return model
133
# Launch adversarial attacks and generate adversarial examples
def attack(classifier, dataset, attacker=None):
    """Attack `classifier` on `dataset` and return the successful adversarial examples.

    Prints clean accuracy and attack success rate; returns a datasets.Dataset
    with columns "x", "y", "tokens". `attacker` defaults to PWWSAttacker.
    """
    # BUG FIX: the original built PWWSAttacker() as a default argument, which
    # constructs the (heavy) attacker at import time and shares one instance
    # across calls; build it lazily instead (backward compatible for callers).
    if attacker is None:
        attacker = OpenAttack.attackers.PWWSAttacker()
    attack_eval = OpenAttack.AttackEval(
        attacker,
        classifier,
    )
    # Only attack samples the classifier already gets right.
    correct_samples = [
        inst for inst in dataset if classifier.get_pred([inst["x"]])[0] == inst["y"]
    ]

    accuracy = len(correct_samples) / len(dataset)

    adversarial_samples = {
        "x": [],
        "y": [],
        "tokens": []
    }

    for result in tqdm.tqdm(attack_eval.ieval(correct_samples), total=len(correct_samples)):
        if result["success"]:
            adversarial_samples["x"].append(result["result"])
            adversarial_samples["y"].append(result["data"]["y"])
            adversarial_samples["tokens"].append(tokenizer.tokenize(result["result"], pos_tagging=False))

    attack_success_rate = len(adversarial_samples["x"]) / len(correct_samples)

    print("Accuracy: %lf%%\nAttack success rate: %lf%%" % (accuracy * 100, attack_success_rate * 100))

    return datasets.Dataset.from_dict(adversarial_samples)
163
def main():
    """End-to-end demo: train, attack, adversarially retrain, re-attack."""
    print("Loading data")
    train, valid, test, vocab = prepare_data()  # Load dataset
    model = make_model(len(vocab))  # Design a victim model

    print("Training")
    trained_model = train_model(model, train, valid, vocab)  # Train the victim model

    print("Generating adversarial samples (this step will take dozens of minutes)")
    victim = MyClassifier(trained_model, vocab)  # Wrap the victim model
    adversarial_samples = attack(victim, train)  # Conduct adversarial attacks and generate adversarial examples

    print("Adversarially training classifier")
    print(train.features)
    print(adversarial_samples.features)

    # Merge the original training data with the generated adversarial examples.
    new_dataset = {
        "x": [],
        "y": [],
        "tokens": []
    }
    for split in (train, adversarial_samples):
        for it in split:
            new_dataset["x"].append(it["x"])
            new_dataset["y"].append(it["y"])
            new_dataset["tokens"].append(it["tokens"])

    # Retrain the classifier with additional adversarial examples
    finetune_model = train_model(trained_model, datasets.Dataset.from_dict(new_dataset), valid, vocab)

    print("Testing enhanced model (this step will take dozens of minutes)")
    # BUG FIX: re-wrap the finetuned model so the final attack explicitly targets
    # the adversarially trained classifier; the original reused the stale wrapper
    # and only worked because train_model happens to mutate the model in place.
    victim = MyClassifier(finetune_model, vocab)
    attack(victim, train)  # Re-attack the victim model to measure the effect of adversarial training


if __name__ == '__main__':
    main()