fdd-defense

Форк
0
103 строки · 3.3 Кб
1
import numpy as np
2
import torch
3
from torch import nn
4
from torch.optim import Adam
5
from abc import ABC, abstractmethod
6
from fddbenchmark import FDDDataset, FDDDataloader
7
from tqdm.auto import tqdm, trange
8

9
class BaseModel(ABC):
10
    def __init__(self, window_size: int, step_size: int, is_test: bool, device: str):
11
        self.model = None
12
        self.window_size = window_size
13
        self.step_size = step_size
14
        self.is_test = is_test
15
        self.device = device
16

17
    @abstractmethod
18
    def fit(self, dataset: FDDDataset):
19
        self.dataset = dataset
20
    
21
    @abstractmethod
22
    def predict(self, ts: np.ndarray) -> np.ndarray:
23
        pass
24

25

26
class BaseTorchModel(BaseModel, ABC):
27
    def __init__(
28
            self, 
29
            window_size: int, 
30
            step_size: int, 
31
            batch_size: int,
32
            lr: float,
33
            num_epochs: int,
34
            is_test: bool,
35
            device: str,
36
        ):
37
        super().__init__(window_size, step_size, is_test, device)
38
        self.loss_fn = None
39
        self.batch_size = batch_size
40
        self.lr = lr
41
        self.num_epochs = num_epochs
42
        self.device = device
43

44
    def _train_nn(self):
45
        self.model.train()
46
        self.model.to(self.device)
47
        self.optimizer = Adam(self.model.parameters(), lr=self.lr)
48

49
        self.dataloader = FDDDataloader(
50
            self.dataset.df,
51
            self.dataset.train_mask,
52
            self.dataset.label,
53
            window_size=self.window_size,
54
            step_size=self.step_size,
55
            use_minibatches=True,
56
            batch_size=self.batch_size,
57
            shuffle=True,
58
        )
59
        for e in trange(self.num_epochs, desc='Epochs ...'):
60
            losses = []
61
            for ts, _, label in tqdm(self.dataloader, desc='Steps ...', leave=False):
62
                label = torch.LongTensor(label).to(self.device)
63
                ts = torch.FloatTensor(ts).to(self.device)
64
                logits = self.model(ts)
65
                loss = self.loss_fn(logits, label)
66
                self.optimizer.zero_grad()
67
                loss.backward()
68
                self.optimizer.step()
69
                losses.append(loss.item())
70
                if self.is_test:
71
                    break
72
            print(f'Epoch {e+1}, Loss: {sum(losses) / len(losses):.4f}')
73

74
    def predict(self, ts: np.ndarray) -> np.ndarray:
75
        super().predict(ts)
76
        self.model.eval()
77
        self.model.to(self.device)
78
        ts = torch.FloatTensor(ts).to(self.device)
79
        with torch.no_grad():
80
            logits = self.model(ts)
81
        return logits.argmax(axis=1).cpu().numpy()
82
    
83
    def fit(self, dataset):
84
        super().fit(dataset=dataset)
85
        num_states = len(set(self.dataset.label))
86
        weight = torch.ones(num_states, device=self.device) * 0.5
87
        weight[1:] /= num_states
88
        self.loss_fn = nn.CrossEntropyLoss(weight=weight)
89

90
    def __call__(self, ts: torch.Tensor):
91
        return self.model(ts)
92

93
    def get_grad(self, ts: np.ndarray, label: np.ndarray) -> np.ndarray:
94
        self.model.train()
95
        self.model.to(self.device)
96
        self.model.zero_grad()
97
        ts = torch.FloatTensor(ts).to(self.device)
98
        label = torch.LongTensor(label).to(self.device)
99
        ts.requires_grad = True
100
        logits = self.model(ts)
101
        loss = self.loss_fn(logits, label)
102
        loss.backward()
103
        return ts.grad.data.cpu().numpy()
104

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.