fdd-defense
43 строки · 1.6 Кб
1from fdd_defense.attackers.base import BaseAttacker2from fdd_defense.attackers import FGSMAttacker, NoiseAttacker, PGDAttacker, DeepFoolAttacker, CarliniWagnerAttacker3import torch4from torch import nn5from torch.optim import Adam6from tqdm.auto import tqdm, trange7
8
class DistillationBlackBoxAttacker(BaseAttacker):
    """Black-box attacker that attacks a distilled copy of the target model.

    A student model is trained to reproduce the predictions of the target
    (teacher) model. Adversarial examples are then crafted against the
    student with FGSM and returned as the attack on the target.
    """

    def __init__(
        self,
        model: object,
        eps: float,
        student: object,
        base_attack: str,
    ) -> None:
        """Distill `model` into `student` and build the surrogate attacker.

        Args:
            model: Target (teacher) model. This constructor reads its
                `dataset`, `dataloader`, `device` and `predict` members.
            eps: Attack budget, forwarded to the base class and to the
                FGSM attacker built on the student.
            student: Student wrapper. This constructor calls its `fit`
                method and reads its `model`, `lr`, `num_epochs` and
                `loss_fn` members.
            base_attack: Name of the attack to run on the student.
                NOTE(review): this value is currently ignored and FGSM is
                always used; confirm the other attackers' constructor
                signatures before dispatching on it.
        """
        super().__init__(model, eps)
        self.student = student
        # Presumably `fit` builds/initialises `student.model`, which is
        # accessed right after this call -- TODO confirm.
        self.student.fit(self.model.dataset)
        device = self.model.device
        self.student.model.train()
        self.student.model.to(device)
        self.optimizer = Adam(self.student.model.parameters(), lr=self.student.lr)

        for e in trange(self.student.num_epochs, desc='Epochs ...'):
            losses = []
            for ts, _, _ in tqdm(self.model.dataloader, desc='Steps ...', leave=False):
                ts = torch.FloatTensor(ts)
                # The training target is the teacher's prediction, not the
                # dataset label: the student learns to imitate the teacher.
                label = self.model.predict(ts)
                label = torch.LongTensor(label).to(device)
                logits = self.student.model(ts.to(device))
                loss = self.student.loss_fn(logits, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())
            # An empty dataloader produces no losses; report NaN instead of
            # raising ZeroDivisionError.
            mean_loss = sum(losses) / len(losses) if losses else float('nan')
            print(f'Epoch {e+1}, Loss: {mean_loss:.4f}')
        self.attacker = FGSMAttacker(model=self.student, eps=eps)

    def attack(self, ts, label):
        """Return adversarial samples for `ts` crafted on the student model.

        Calls the base-class `attack` first, then delegates to the FGSM
        attacker that wraps the distilled student and returns its result.
        """
        super().attack(ts, label)
        adv_ts = self.attacker.attack(ts, label)
        return adv_ts