import random

import numpy as np
import torch
from torch.nn import functional as F

from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.sr_model import SRModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register(suffix='basicsr')
class RealESRNetModel(SRModel):
    """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It is trained without GAN losses.
    It mainly performs:
    1. randomly synthesize LQ images from GT images, as tensors on the model device
    2. optimize the network without GAN losses (no optimization step is defined here;
       it is inherited from ``SRModel``).

    Option keys read by this class (from ``opt``):
        queue_size, high_order_degradation, gt_usm, scale, gt_size,
        resize_prob, resize_range, gray_noise_prob, gaussian_noise_prob,
        noise_range, poisson_scale_range, jpeg_range, second_blur_prob,
        and the second-stage counterparts carrying a ``2`` suffix.
    """

    # Interpolation modes the random resize steps choose from.
    _RESIZE_MODES = ['area', 'bilinear', 'bicubic']

    def __init__(self, opt):
        """Build the model and the helper modules used for LQ synthesis.

        Args:
            opt (dict): Configuration. ``queue_size`` (default 180) sets the
                capacity of the training pair pool.
        """
        super(RealESRNetModel, self).__init__(opt)
        # Place the helpers on the same device feed_data() moves the data to,
        # instead of hard-coding CUDA.
        self.jpeger = DiffJPEG(differentiable=False).to(self.device)  # applied with a per-sample quality
        self.usm_sharpener = USMSharp().to(self.device)  # sharpens the GT images
        self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.

        While the pool is filling, the current batch is only stored. Once it is
        full, the pool is shuffled, its first ``b`` pairs replace ``self.lq`` /
        ``self.gt``, and the current batch takes their slots.

        Raises:
            AssertionError: If ``queue_size`` is not divisible by the batch size
                seen when the pool is first created.
        """
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_lr'):
            # Lazily allocate the pool once the tensor shapes are known.
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            self.queue_lr = self.lq.new_zeros(self.queue_size, c, h, w)
            _, c, h, w = self.gt.size()
            self.queue_gt = self.gt.new_zeros(self.queue_size, c, h, w)
            self.queue_ptr = 0

        if self.queue_ptr == self.queue_size:
            # The pool is full: shuffle, then swap the first b pairs with the batch.
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # Clone before the slots are overwritten below.
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            self.queue_lr[0:b, :, :, :] = self.lq
            self.queue_gt[0:b, :, :, :] = self.gt

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # The pool is still filling: only store the current batch.
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt
            self.queue_ptr = self.queue_ptr + b

    def _sample_resize_scale(self, prob_key, range_key):
        """Draw a resize factor: 'up' samples from [1, range[1]], 'down' from [range[0], 1], 'keep' gives 1."""
        updown_type = random.choices(['up', 'down', 'keep'], self.opt[prob_key])[0]
        if updown_type == 'up':
            return np.random.uniform(1, self.opt[range_key][1])
        if updown_type == 'down':
            return np.random.uniform(self.opt[range_key][0], 1)
        return 1

    def _add_random_noise(self, out, suffix=''):
        """Add either Gaussian or Poisson noise, chosen by ``gaussian_noise_prob<suffix>``."""
        gray_noise_prob = self.opt['gray_noise_prob' + suffix]
        if np.random.uniform() < self.opt['gaussian_noise_prob' + suffix]:
            return random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range' + suffix], clip=True, rounds=False, gray_prob=gray_noise_prob)
        return random_add_poisson_noise_pt(
            out,
            scale_range=self.opt['poisson_scale_range' + suffix],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False)

    def _jpeg_compress(self, out, range_key):
        """Run ``self.jpeger`` with one quality per sample, drawn uniformly from ``opt[range_key]``."""
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt[range_key])
        out = torch.clamp(out, 0, 1)  # keep values in [0, 1] before the JPEG step
        return self.jpeger(out, quality=jpeg_p)

    def _resize_back_and_sinc(self, out, ori_h, ori_w):
        """Resize to the final LQ size (GT size // scale) and apply the sinc kernel."""
        mode = random.choice(self._RESIZE_MODES)
        out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
        return filter2D(out, self.sinc_kernel)

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.

        When training with ``high_order_degradation`` enabled (the default),
        ``self.lq`` is synthesized from ``data['gt']`` using ``data['kernel1']``,
        ``data['kernel2']`` and ``data['sinc_kernel']``. Otherwise ``data['lq']``
        is used as is, and ``data['gt']`` is loaded only if present.

        Args:
            data (dict): Batch from the dataloader.
        """
        if self.is_train and self.opt.get('high_order_degradation', True):
            # Training data synthesis.
            self.gt = data['gt'].to(self.device)
            if self.opt['gt_usm'] is True:
                self.gt = self.usm_sharpener(self.gt)

            self.kernel1 = data['kernel1'].to(self.device)
            self.kernel2 = data['kernel2'].to(self.device)
            self.sinc_kernel = data['sinc_kernel'].to(self.device)

            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            out = filter2D(self.gt, self.kernel1)
            scale = self._sample_resize_scale('resize_prob', 'resize_range')
            mode = random.choice(self._RESIZE_MODES)
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            out = self._add_random_noise(out)
            out = self._jpeg_compress(out, 'jpeg_range')

            # ----------------------- The second degradation process ----------------------- #
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            scale = self._sample_resize_scale('resize_prob2', 'resize_range2')
            mode = random.choice(self._RESIZE_MODES)
            # The target size is relative to the final LQ size, not the current size.
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            out = self._add_random_noise(out, suffix='2')

            # Apply the final resize + sinc filter and the JPEG step in random order.
            if np.random.uniform() < 0.5:
                out = self._resize_back_and_sinc(out, ori_h, ori_w)
                out = self._jpeg_compress(out, 'jpeg_range2')
            else:
                out = self._jpeg_compress(out, 'jpeg_range2')
                out = self._resize_back_and_sinc(out, ori_h, ori_w)

            # Clamp and round to 8-bit levels.
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # Random crop of the GT/LQ pair.
            gt_size = self.opt['gt_size']
            self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])

            # Training pair pool.
            self._dequeue_and_enqueue()
            self.lq = self.lq.contiguous()
        else:
            # Paired training or validation: use the provided LQ images.
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run the parent validation with the synthetic degradation disabled.

        ``is_train`` is set to False for the duration of the call so that
        ``feed_data`` uses the provided LQ images, then restored to its previous
        value even if validation raises.
        """
        was_training = self.is_train
        self.is_train = False
        try:
            super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        finally:
            self.is_train = was_training