from collections import OrderedDict

import torch

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY

from .sr_model import SRModel

@MODEL_REGISTRY.register()
class SRGANModel(SRModel):
    """SRGAN model for single image super-resolution.

    Trains the generator ``net_g`` (inherited from ``SRModel``) adversarially
    against a discriminator ``net_d`` built from ``opt['network_d']``.
    Optionally keeps an exponential-moving-average copy ``net_g_ema`` of the
    generator when ``opt['train']['ema_decay'] > 0``.
    """

    def init_training_settings(self):
        """Build the discriminator, losses, optimizers and schedulers.

        Reads from ``self.opt['train']``: ``ema_decay``, ``pixel_opt``,
        ``ldl_opt``, ``perceptual_opt``, ``gan_opt``, ``net_d_iters``,
        ``net_d_init_iters`` (plus ``optim_g``/``optim_d`` via
        ``setup_optimizers``). Reads from ``self.opt['path']``:
        ``pretrain_network_g``, ``strict_load_g``, ``pretrain_network_d``,
        ``param_key_d``, ``strict_load_d``.
        """
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # EMA copy of the generator; only moved to the device, not wrapped
            # with model_to_device like net_d below.
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                # Without a checkpoint the EMA net would start from random
                # weights. NOTE(review): assumes BaseModel.model_ema blends as
                # ema = decay * ema + (1 - decay) * net_g, so decay=0 copies
                # net_g's current weights -- confirm against BaseModel.
                self.model_ema(0)
            self.net_g_ema.eval()

        # Discriminator.
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)

        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_d', 'params')
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)

        self.net_g.train()
        self.net_d.train()

        # Losses: every optional criterion is set to None when not configured
        # so optimize_parameters can test for its presence.
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('ldl_opt'):
            self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device)
        else:
            self.cri_ldl = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        # The GAN criterion is used unconditionally in optimize_parameters,
        # so 'gan_opt' is effectively required for this model.
        if train_opt.get('gan_opt'):
            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

        # The generator is updated only every `net_d_iters` iterations and only
        # after `net_d_init_iters` discriminator warm-up iterations.
        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)

        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        """Create one optimizer for net_g and one for net_d.

        The option dicts are copied before removing 'type' so the shared
        config is not mutated (a second call would otherwise raise KeyError).
        """
        train_opt = self.opt['train']

        optim_g_opt = dict(train_opt['optim_g'])
        optim_type = optim_g_opt.pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **optim_g_opt)
        self.optimizers.append(self.optimizer_g)

        optim_d_opt = dict(train_opt['optim_d'])
        optim_type = optim_d_opt.pop('type')
        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **optim_d_opt)
        self.optimizers.append(self.optimizer_d)

    def optimize_parameters(self, current_iter):
        """Run one generator step (when scheduled) and one discriminator step.

        Args:
            current_iter: Current training iteration, used to decide whether
                the generator is updated this iteration.
        """
        # --- generator step ---
        # Freeze D so the generator's backward pass does not accumulate
        # gradients into D's parameters.
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        if current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters:
            if self.cri_pix is not None:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix

            if self.cri_perceptual is not None:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style

            # Adversarial loss: the generator wants D to label its output real.
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # --- discriminator step ---
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()

        # Real images should be classified as real.
        real_d_pred = self.net_d(self.gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()

        # Generated images should be classified as fake; detach so this
        # backward pass does not reach net_g.
        fake_d_pred = self.net_d(self.output.detach())
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()

        self.optimizer_d.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def save(self, epoch, current_iter):
        """Save generator (with EMA weights if present), discriminator and training state."""
        if hasattr(self, 'net_g_ema'):
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_network(self.net_d, 'net_d', current_iter)
        self.save_training_state(epoch, current_iter)