OpenBackdoor
32 строки · 1.2 Кб
1from typing import *
2from openbackdoor.victims import Victim
3from openbackdoor.data import get_dataloader, wrap_dataset
4from .poisoners import load_poisoner
5from openbackdoor.trainers import load_trainer
6from openbackdoor.utils import evaluate_classification
7from openbackdoor.defenders import Defender
8from .attacker import Attacker
9import torch
10import torch.nn as nn
class SOSAttacker(Attacker):
    r"""
    Attacker for `SOS <https://aclanthology.org/2021.acl-long.431>`_

    The attack runs in stages: train the victim on the given dataset,
    poison the training split using the trained model, optionally let a
    pre-tune defender process the poisoned data, then run the SOS-specific
    training step on the trained model.
    """

    def __init__(self, **kwargs):
        # No SOS-specific state here; every keyword is forwarded unchanged
        # to the base ``Attacker`` constructor.
        super().__init__(**kwargs)

    def attack(self, victim: Victim, dataset: List, config: Optional[dict] = None, defender: Optional[Defender] = None):
        r"""
        Run the full SOS attack and return the backdoored model.

        Args:
            victim: the model to attack.
            dataset: the data passed to both ``self.train`` and ``self.poison``.
            config: accepted for interface compatibility; not read by this method.
            defender: optional defender. It is only applied when its ``pre``
                attribute is exactly ``True``.

        Returns:
            Whatever ``self.sos_train`` returns for the trained model and the
            (possibly defended) poisoned dataset.
        """
        # ``train`` and ``poison`` are not defined in this class; presumably
        # they are inherited from ``Attacker`` -- confirm against the base class.
        trained_victim = self.train(victim, dataset)
        poisoned_data = self.poison(trained_victim, dataset, "train")

        # Pre-tune defense: the defender processes the poisoned data before
        # the SOS training step sees it.
        apply_pre_defense = defender is not None and defender.pre is True
        if apply_pre_defense:
            poisoned_data = defender.defend(data=poisoned_data)

        return self.sos_train(trained_victim, poisoned_data)

    def sos_train(self, victim: Victim, dataset: List):
        """
        Delegate SOS training to ``self.poison_trainer``.

        Args:
            victim: the model to run SOS training on.
            dataset: the poisoned dataset to train with.

        Returns:
            The result of ``self.poison_trainer.sos_train(victim, dataset, self.metrics)``.
        """
        # ``poison_trainer`` and ``metrics`` are not assigned in this class;
        # presumably the base ``Attacker.__init__`` sets them -- TODO confirm.
        trainer = self.poison_trainer
        return trainer.sos_train(victim, dataset, self.metrics)
33
34