OpenBackdoor
49 lines · 1.7 KB
1from typing import *
2from openbackdoor.victims import Victim
3from openbackdoor.data import get_dataloader, wrap_dataset
4from .poisoners import load_poisoner
5from openbackdoor.trainers import load_trainer
6from openbackdoor.utils import evaluate_classification
7from openbackdoor.defenders import Defender
8from openbackdoor.victims import mlm_to_seq_cls, load_victim
9from .attacker import Attacker
10import torch
11import torch.nn as nn
class PORAttacker(Attacker):
    r"""
    Attacker for `POR <https://arxiv.org/abs/2111.00197>`_

    POR plants backdoors into a pre-trained language model (PLM), saves the
    backdoored weights, and re-loads them as a fresh PLM victim.

    Args:
        from_scratch (:obj:`bool`, optional): whether to train the backdoored
            PLM on the poisoned data before saving. Defaults to ``False``.
    """

    def __init__(
        self,
        from_scratch: Optional[bool] = False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.from_scratch = from_scratch

    def attack(self, victim: Victim, data: List, config: Optional[dict] = None, defender: Optional[Defender] = None):
        """
        Poison the training data, optionally train the victim on it, then save
        the backdoored PLM and re-load it from disk.

        Args:
            victim (:obj:`Victim`): the model to attack.
            data (:obj:`List`): the clean training dataset.
            config (:obj:`dict`, optional): experiment config; its ``"victim"``
                sub-dict is required to re-load the saved model. NOTE: that
                sub-dict is mutated in place (``type``/``path`` overwritten).
            defender (:obj:`Defender`, optional): if given and ``defender.pre``
                is True, its pre-tuning defense is applied to the poisoned data.

        Returns:
            The backdoored model re-loaded from ``self.poison_trainer.save_path``.

        Raises:
            ValueError: if ``config`` is ``None`` — the victim config is
                needed to re-load the saved model.
        """
        poison_dataset = self.poison(victim, data, "train")
        if defender is not None and defender.pre is True:
            # pre-tune defense: let the defender filter/clean the poisoned data
            poison_dataset = defender.defend(data=poison_dataset)

        if self.from_scratch:
            backdoored_model = self.train(victim, poison_dataset)
        else:
            backdoored_model = victim

        # Fix: ``config`` defaults to None but was dereferenced below with
        # ``config["victim"]``, crashing with an opaque TypeError. Fail early
        # with a clear message instead.
        if config is None:
            raise ValueError("PORAttacker.attack requires a non-None `config` containing a 'victim' entry")

        # Persist the backdoored PLM, then re-load it as a plain "plm" victim
        # so the returned model reads its weights from the saved checkpoint.
        backdoored_model.save(self.poison_trainer.save_path)
        victim_config = config["victim"]
        victim_config["type"] = "plm"
        victim_config["path"] = self.poison_trainer.save_path
        backdoored_model = load_victim(victim_config)

        return backdoored_model

    def poison(self, victim: Victim, dataset: List, mode: str):
        """
        Default poisoning: delegate to ``self.poisoner`` and return the
        poisoned dataset.
        """
        return self.poisoner(victim, dataset, mode)
51