fdd-defense
103 строки · 3.3 Кб
1import numpy as np
2import torch
3from torch import nn
4from torch.optim import Adam
5from abc import ABC, abstractmethod
6from fddbenchmark import FDDDataset, FDDDataloader
7from tqdm.auto import tqdm, trange
8
class BaseModel(ABC):
    """Abstract interface for fault detection and diagnosis (FDD) models.

    Concrete subclasses assign ``self.model`` and implement :meth:`fit`
    and :meth:`predict`.
    """

    def __init__(self, window_size: int, step_size: int, is_test: bool, device: str):
        # The actual estimator is created by subclasses; kept here so every
        # FDD model exposes the same attribute.
        self.model = None
        self.window_size = window_size
        self.step_size = step_size
        # is_test: when True, training loops run in a shortened smoke-test mode.
        self.is_test = is_test
        self.device = device

    @abstractmethod
    def fit(self, dataset: FDDDataset):
        """Remember the dataset; subclasses extend this with real training."""
        self.dataset = dataset

    @abstractmethod
    def predict(self, ts: np.ndarray) -> np.ndarray:
        """Return predicted state labels for a batch of time-series windows."""
24
25
class BaseTorchModel(BaseModel, ABC):
    """Base class for torch-backed FDD models.

    Owns the supervised training loop (:meth:`_train_nn`), batched
    inference (:meth:`predict`), loss construction (:meth:`fit`) and
    input-gradient computation (:meth:`get_grad`).  Subclasses must
    assign ``self.model`` (an ``nn.Module``) before training.
    """

    def __init__(
        self,
        window_size: int,
        step_size: int,
        batch_size: int,
        lr: float,
        num_epochs: int,
        is_test: bool,
        device: str,
    ):
        """Store optimization hyperparameters.

        Args:
            window_size: length of each sliding window.
            step_size: stride between consecutive windows.
            batch_size: minibatch size for training.
            lr: Adam learning rate.
            num_epochs: number of training epochs.
            is_test: if True, run one batch per epoch (smoke test).
            device: torch device string, e.g. ``'cpu'`` or ``'cuda'``.
        """
        super().__init__(window_size, step_size, is_test, device)
        # Built lazily in fit(); depends on the dataset's label set.
        self.loss_fn = None
        self.batch_size = batch_size
        self.lr = lr
        self.num_epochs = num_epochs
        # NOTE: self.device is already assigned by BaseModel.__init__;
        # the redundant re-assignment that used to live here was removed.

    def _train_nn(self):
        """Train ``self.model`` on the training windows of ``self.dataset``.

        Requires :meth:`fit` to have been called first (sets ``self.dataset``
        and ``self.loss_fn``).
        """
        self.model.train()
        self.model.to(self.device)
        self.optimizer = Adam(self.model.parameters(), lr=self.lr)

        self.dataloader = FDDDataloader(
            self.dataset.df,
            self.dataset.train_mask,
            self.dataset.label,
            window_size=self.window_size,
            step_size=self.step_size,
            use_minibatches=True,
            batch_size=self.batch_size,
            shuffle=True,
        )
        for e in trange(self.num_epochs, desc='Epochs ...'):
            losses = []
            for ts, _, label in tqdm(self.dataloader, desc='Steps ...', leave=False):
                label = torch.LongTensor(label).to(self.device)
                ts = torch.FloatTensor(ts).to(self.device)
                logits = self.model(ts)
                loss = self.loss_fn(logits, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())
                if self.is_test:
                    # Smoke-test mode: a single batch per epoch is enough.
                    break
            # Guard against an empty dataloader — the original raised
            # ZeroDivisionError here when no batches were produced.
            mean_loss = sum(losses) / len(losses) if losses else float('nan')
            print(f'Epoch {e+1}, Loss: {mean_loss:.4f}')

    def predict(self, ts: np.ndarray) -> np.ndarray:
        """Return predicted state labels for a batch of windows.

        Args:
            ts: array of windows; assumes shape (batch, window, features)
                compatible with ``self.model`` — TODO confirm against callers.

        Returns:
            Integer label per window, as a numpy array on the CPU.
        """
        # The original called super().predict(ts) here, which is an abstract
        # no-op (`pass`) — dropped as dead code.
        self.model.eval()
        self.model.to(self.device)
        ts = torch.FloatTensor(ts).to(self.device)
        with torch.no_grad():
            logits = self.model(ts)
        return logits.argmax(dim=1).cpu().numpy()

    def fit(self, dataset):
        """Store the dataset and build a class-weighted cross-entropy loss."""
        super().fit(dataset=dataset)
        num_states = len(set(self.dataset.label))
        # Class-imbalance weighting: the state at index 0 keeps weight 0.5,
        # every other state gets 0.5 / num_states.  Presumably index 0 is the
        # majority "normal" state — verify against the dataset's labeling.
        weight = torch.ones(num_states, device=self.device) * 0.5
        weight[1:] /= num_states
        self.loss_fn = nn.CrossEntropyLoss(weight=weight)

    def __call__(self, ts: torch.Tensor):
        """Forward the raw tensor through the underlying model."""
        return self.model(ts)

    def get_grad(self, ts: np.ndarray, label: np.ndarray) -> np.ndarray:
        """Return d(loss)/d(input) for the given windows and labels.

        Used by gradient-based attack/defense code: runs one forward/backward
        pass and returns the gradient of the loss w.r.t. the input tensor.
        """
        self.model.train()
        self.model.to(self.device)
        self.model.zero_grad()
        ts = torch.FloatTensor(ts).to(self.device)
        label = torch.LongTensor(label).to(self.device)
        ts.requires_grad = True
        logits = self.model(ts)
        loss = self.loss_fn(logits, label)
        loss.backward()
        # .detach() replaces the deprecated direct .data access.
        return ts.grad.detach().cpu().numpy()
104