from collections import OrderedDict
from os import path as osp

import torch

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY

from .base_model import BaseModel

@MODEL_REGISTRY.register()
class SRModel(BaseModel):
    """Base SR model for single image super-resolution.

    Holds a generator ``net_g`` built from ``opt['network_g']``. When
    ``opt['train']['ema_decay'] > 0`` an exponential-moving-average copy
    ``net_g_ema`` is also kept; it is preferred for inference and is saved
    alongside ``net_g`` in checkpoints.
    """

    def __init__(self, opt):
        """Build net_g, optionally load pretrained weights, and set up training.

        Args:
            opt (dict): full option dict. Reads ``opt['network_g']``, the
                ``pretrain_network_g`` / ``param_key_g`` / ``strict_load_g``
                entries of ``opt['path']`` and ``opt['is_train']``.
        """
        super(SRModel, self).__init__(opt)

        # define network
        self.net_g = build_network(opt['network_g'])
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)

        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_g', 'params')
            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

        # Training-only state (losses, optimizers, schedulers) reads opt['train'];
        # skip it when the config says we are not training. A missing 'is_train'
        # key keeps the previous always-initialize behavior.
        if self.opt.get('is_train', True):
            self.init_training_settings()

    def init_training_settings(self):
        """Prepare net_g (and optionally net_g_ema) for training.

        Builds the configured losses, then sets up optimizers and schedulers.

        Raises:
            ValueError: if neither ``pixel_opt`` nor ``perceptual_opt`` is set
                in ``opt['train']``.
        """
        self.net_g.train()
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # net_g_ema is only used for inference and saving here, so it is
            # moved to self.device directly rather than via model_to_device.
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                # No checkpoint: start the EMA from net_g's current weights
                # (presumably decay=0 copies them verbatim -- see BaseModel.model_ema).
                self.model_ema(0)
            self.net_g_ema.eval()

        # Losses that are not configured are stored as None explicitly so the
        # check below and optimize_parameters can test for them.
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        if self.cri_pix is None and self.cri_perceptual is None:
            raise ValueError('Both pixel and perceptual losses are None.')

        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        """Create the generator optimizer from ``opt['train']['optim_g']``.

        Only parameters with ``requires_grad`` are optimized; every other
        parameter is reported with a warning.
        """
        train_opt = self.opt['train']
        logger = get_root_logger()
        optim_params = []
        for k, v in self.net_g.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger.warning(f'Params {k} will not be optimized.')

        # Pop 'type' from a copy: popping from the option dict itself would make
        # any later call (or reader of the options) miss the key.
        optim_opt = dict(train_opt['optim_g'])
        optim_type = optim_opt.pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **optim_opt)
        self.optimizers.append(self.optimizer_g)

    def feed_data(self, data):
        """Move one batch onto self.device.

        Args:
            data (dict): must contain ``'lq'``; ``'gt'`` is optional (e.g. for
                inference without ground truth).
        """
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

    def optimize_parameters(self, current_iter):
        """Run one optimization step on net_g, then update net_g_ema if enabled.

        Args:
            current_iter (int): current iteration (unused in this implementation).

        Raises:
            RuntimeError: if no loss term was produced for this step.
        """
        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_total = 0
        loss_dict = OrderedDict()
        # pixel loss
        if self.cri_pix is not None:
            l_pix = self.cri_pix(self.output, self.gt)
            l_total += l_pix
            loss_dict['l_pix'] = l_pix
        # perceptual / style loss; the criterion may return None for either term
        if self.cri_perceptual is not None:
            l_percep, l_style = self.cri_perceptual(self.output, self.gt)
            if l_percep is not None:
                l_total += l_percep
                loss_dict['l_percep'] = l_percep
            if l_style is not None:
                l_total += l_style
                loss_dict['l_style'] = l_style

        if not loss_dict:
            # l_total would still be the int 0, which has no backward().
            raise RuntimeError('No loss term was computed in optimize_parameters.')

        l_total.backward()
        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def test(self):
        """Run inference on self.lq and store the result in self.output.

        Uses net_g_ema when it exists. Otherwise net_g is switched to eval mode
        for the forward pass and its previous train/eval mode is restored.
        """
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                self.output = self.net_g_ema(self.lq)
        else:
            was_training = self.net_g.training
            self.net_g.eval()
            with torch.no_grad():
                self.output = self.net_g(self.lq)
            self.net_g.train(was_training)

    def test_selfensemble(self):
        """Inference with x8 self-ensemble, stored in self.output.

        Runs the network on all 8 combinations of the 'v', 'h' and 't'
        transforms of self.lq, undoes each transform on the corresponding
        output, and averages the 8 restored outputs.
        """

        def _transform(v, op):
            # 'v' reverses the last axis, 'h' the second-to-last, 't' swaps the
            # last two. Assumes (N, C, H, W) tensors -- TODO confirm.
            v = v.detach()
            if op == 'v':
                out = v.flip(-1)
            elif op == 'h':
                out = v.flip(-2)
            elif op == 't':
                out = v.transpose(-2, -1)
            else:
                raise ValueError(f'Unknown self-ensemble transform: {op}')
            return out.contiguous()

        # Each pass doubles the list, so index bit 0 <-> 'v', bit 1 <-> 'h',
        # bit 2 <-> 't' (applied in that order).
        lq_list = [self.lq]
        for tf in ('v', 'h', 't'):
            lq_list.extend([_transform(t, tf) for t in lq_list])

        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                out_list = [self.net_g_ema(aug) for aug in lq_list]
        else:
            # This branch only runs when net_g_ema does not exist, so use net_g.
            was_training = self.net_g.training
            self.net_g.eval()
            with torch.no_grad():
                out_list = [self.net_g(aug) for aug in lq_list]
            self.net_g.train(was_training)

        # Undo the transforms in reverse order of application: 't', then 'h', then 'v'.
        restored = []
        for i, out in enumerate(out_list):
            if i > 3:
                out = _transform(out, 't')
            if i % 4 > 1:
                out = _transform(out, 'h')
            if i % 2 == 1:
                out = _transform(out, 'v')
            restored.append(out)

        # Average over a new ensemble axis so batches with N > 1 are not mixed;
        # for N == 1 this equals cat(dim=0).mean(dim=0, keepdim=True).
        self.output = torch.stack(restored, dim=0).mean(dim=0)

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Distributed validation: only rank 0 runs the single-process validation."""
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Validate on one process: infer, optionally save images, average metrics.

        Args:
            dataloader: validation loader whose dataset has ``opt['name']`` and
                whose items carry ``'lq_path'`` plus the keys used by feed_data.
            current_iter (int): iteration used in saved file names and metric logs.
            tb_logger: object with ``add_scalar`` for metric logging, or None.
            save_img (bool): whether to write the SR images to disk.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        use_pbar = self.opt['val'].get('pbar', False)

        if with_metrics:
            if not hasattr(self, 'metric_results'):  # first validation run only
                self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
            self._initialize_best_metric_results(dataset_name)
            # reset the per-run accumulators
            self.metric_results = {metric: 0 for metric in self.metric_results}

        pbar = None
        if use_pbar:
            from tqdm import tqdm  # only needed when a progress bar is requested
            pbar = tqdm(total=len(dataloader), unit='image')

        num_imgs = 0
        for val_data in dataloader:
            num_imgs += 1
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            # fresh dict per sample so a previous sample's GT can never leak in
            metric_data = {'img': sr_img}
            if 'gt' in visuals:
                metric_data['img2'] = tensor2img([visuals['gt']])
                # drop gt so a later sample without gt does not reuse this one
                del self.gt

            # release this sample's tensors before the next one
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                vis_root = self.opt['path']['visualization']
                if self.opt['is_train']:
                    save_img_path = osp.join(vis_root, img_name, f'{img_name}_{current_iter}.png')
                else:
                    suffix = self.opt['val'].get('suffix') or self.opt['name']
                    save_img_path = osp.join(vis_root, dataset_name, f'{img_name}_{suffix}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                for name, opt_ in self.opt['val']['metrics'].items():
                    self.metric_results[name] += calculate_metric(metric_data, opt_)

            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Test {img_name}')

        if pbar is not None:
            pbar.close()

        if with_metrics:
            if num_imgs == 0:
                get_root_logger().warning(f'Validation {dataset_name}: empty dataloader, metrics skipped.')
                return
            for metric in self.metric_results.keys():
                self.metric_results[metric] /= num_imgs
                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        """Log averaged metrics (with best-so-far values when tracked).

        Writes one line per metric to the root logger and, if ``tb_logger`` is
        given, one scalar per metric under ``metrics/<dataset_name>/<metric>``.
        """
        log_str = f'Validation {dataset_name}\n'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}'
            if hasattr(self, 'best_metric_results'):
                log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
            log_str += '\n'

        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger is not None:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)

    def get_current_visuals(self):
        """Return detached CPU copies of lq, result and (if present) gt.

        Returns:
            OrderedDict: keys ``'lq'``, ``'result'`` and optionally ``'gt'``.
        """
        out_dict = OrderedDict()
        out_dict['lq'] = self.lq.detach().cpu()
        out_dict['result'] = self.output.detach().cpu()
        if hasattr(self, 'gt'):
            out_dict['gt'] = self.gt.detach().cpu()
        return out_dict

    def save(self, epoch, current_iter):
        """Save net_g (together with net_g_ema when it exists) and the training state."""
        if hasattr(self, 'net_g_ema'):
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_training_state(epoch, current_iter)