# OpenAttack example — 51 lines · 1.6 KB
'''
This example code shows how to conduct adversarial attacks against a sentence pair classification (NLI) model
'''
import OpenAttack
import transformers
import datasets

class NLIWrapper(OpenAttack.classifiers.Classifier):
    """Adapts a single-sentence victim classifier to the NLI (sentence-pair) setting.

    The attacker perturbs only the premise; the hypothesis is taken from the
    current attack context and re-attached to every candidate sentence before
    it is scored by the wrapped model.
    """

    def __init__(self, model: OpenAttack.classifiers.Classifier):
        # The underlying classifier that actually scores sentence pairs.
        self.model = model

    def get_pred(self, input_):
        # Predicted label = argmax over the class-probability rows.
        return self.get_prob(input_).argmax(axis=1)

    def get_prob(self, input_):
        # Hypothesis of the example currently under attack, supplied by the
        # attack context (see dataset_mapping, which preserves "hypothesis").
        ref = self.context.input["hypothesis"]
        # RoBERTa-style pair encoding: premise </s></s> hypothesis.
        input_sents = [sent + "</s></s>" + ref for sent in input_]
        return self.model.get_prob(input_sents)
def dataset_mapping(x):
    """Map a GLUE/MNLI example to OpenAttack's expected schema.

    "x" is the text the attacker perturbs (the premise), "y" is the gold
    label, and "hypothesis" is kept so the victim wrapper can rebuild the
    sentence pair at scoring time.
    """
    return {
        "x": x["premise"],
        "y": x["label"],
        "hypothesis": x["hypothesis"]
    }
30
def main():
    """Run a PWWS attack against roberta-large-mnli on 20 MNLI training examples."""
    print("Load model")
    tokenizer = transformers.AutoTokenizer.from_pretrained("roberta-large-mnli")
    model = transformers.AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli", output_hidden_states=False)
    # TransformersClassifier needs the input embedding matrix for gradient-free scoring hooks.
    victim = OpenAttack.classifiers.TransformersClassifier(model, tokenizer, model.roberta.embeddings.word_embeddings)
    # Wrap so each candidate premise is re-paired with the example's hypothesis.
    victim = NLIWrapper(victim)

    print("New Attacker")
    attacker = OpenAttack.attackers.PWWSAttacker()

    # Small slice keeps the demo fast; mapping renames columns to OpenAttack's schema.
    dataset = datasets.load_dataset("glue", "mnli", split="train[:20]").map(function=dataset_mapping)

    print("Start attack")
    attack_eval = OpenAttack.AttackEval(attacker, victim, metrics=[
        OpenAttack.metric.EditDistance(),
        OpenAttack.metric.ModificationRate()
    ])
    attack_eval.eval(dataset, visualize=True)
if __name__ == "__main__":
    main()