BasicSR / srgan_model.py (149 lines · 5.7 KB)
import torch
from collections import OrderedDict

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel


@MODEL_REGISTRY.register()
class SRGANModel(SRModel):
    """SRGAN model for single image super-resolution."""

    def init_training_settings(self):
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define network net_d
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)

        # load pretrained models
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_d', 'params')
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)

        self.net_g.train()
        self.net_d.train()

        # define losses
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('ldl_opt'):
            self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device)
        else:
            self.cri_ldl = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        if train_opt.get('gan_opt'):
            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        train_opt = self.opt['train']
        # optimizer g
        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)
        # optimizer d
        optim_type = train_opt['optim_d'].pop('type')
        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
        self.optimizers.append(self.optimizer_d)

    def optimize_parameters(self, current_iter):
        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(self.gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        fake_d_pred = self.net_d(self.output.detach())
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def save(self, epoch, current_iter):
        if hasattr(self, 'net_g_ema'):
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_network(self.net_d, 'net_d', current_iter)
        self.save_training_state(epoch, current_iter)

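For context, below is a minimal sketch (not part of the repository file above) of the 'train' section of the options dict that init_training_settings() and setup_optimizers() read. The key names mirror the train_opt lookups in the code; the loss classes are BasicSR's registered L1Loss, PerceptualLoss, and GANLoss, and the concrete values are illustrative assumptions rather than recommended settings.

# Hypothetical 'train' options consumed by SRGANModel.init_training_settings().
# Values are placeholders; only the key names are taken from the code above.
train_opt = {
    'ema_decay': 0.999,            # > 0 enables the net_g_ema copy updated via model_ema()
    'pixel_opt': {                 # -> self.cri_pix
        'type': 'L1Loss', 'loss_weight': 1e-2, 'reduction': 'mean'},
    'perceptual_opt': {            # -> self.cri_perceptual (returns percep and style terms)
        'type': 'PerceptualLoss',
        'layer_weights': {'conv5_4': 1.0},
        'perceptual_weight': 1.0,
        'style_weight': 0.0},
    'gan_opt': {                   # -> self.cri_gan, used for both generator and discriminator terms
        'type': 'GANLoss', 'gan_type': 'vanilla',
        'real_label_val': 1.0, 'fake_label_val': 0.0, 'loss_weight': 5e-3},
    'net_d_iters': 1,              # net_g is updated only every net_d_iters iterations
    'net_d_init_iters': 0,         # net_g updates are skipped for the first iterations
    'optim_g': {'type': 'Adam', 'lr': 1e-4, 'betas': [0.9, 0.99]},
    'optim_d': {'type': 'Adam', 'lr': 1e-4, 'betas': [0.9, 0.99]},
}

Note that setup_optimizers() pops 'type' from 'optim_g' and 'optim_d' and forwards the remaining keys as optimizer keyword arguments, so any extra entries there must be valid for the chosen optimizer.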