BasicSR

Форк
0
/
realesrnet_model.py 
189 строк · 8.9 Кб
1
import numpy as np
2
import random
3
import torch
4
from torch.nn import functional as F
5

6
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
7
from basicsr.data.transforms import paired_random_crop
8
from basicsr.models.sr_model import SRModel
9
from basicsr.utils import DiffJPEG, USMSharp
10
from basicsr.utils.img_process_util import filter2D
11
from basicsr.utils.registry import MODEL_REGISTRY
12

13

14
@MODEL_REGISTRY.register(suffix='basicsr')
15
class RealESRNetModel(SRModel):
16
    """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
17

18
    It is trained without GAN losses.
19
    It mainly performs:
20
    1. randomly synthesize LQ images in GPU tensors
21
    2. optimize the networks without GAN training.
22
    """
23

24
    def __init__(self, opt):
        """Build the model and the helpers used for on-the-fly degradation.

        Args:
            opt (dict): Model options, forwarded unchanged to the parent
                constructor. The optional key ``queue_size`` (default 180)
                sets the capacity of the training pair pool.
        """
        super(RealESRNetModel, self).__init__(opt)
        # feed_data() moves every tensor to ``self.device``; keep the helpers on
        # that same device so they also work when it is not CUDA. Fall back to
        # CUDA (the previous hard-coded behavior) if the parent did not set it.
        device = getattr(self, 'device', 'cuda')
        self.jpeger = DiffJPEG(differentiable=False).to(device)  # simulate JPEG compression artifacts
        self.usm_sharpener = USMSharp().to(device)  # do usm sharpening
        self.queue_size = opt.get('queue_size', 180)
29

30
    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.

        While the pool is filling, the current ``self.lq`` / ``self.gt`` are copied into it and left untouched.
        Once it is full, the pool is shuffled, its first ``b`` entries replace ``self.lq`` / ``self.gt``, and
        the current batch is stored in their place.

        Raises:
            AssertionError: On the first call, if ``queue_size`` is not divisible by the batch size.
        """
        # initialize
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_lr'):
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            # Allocate directly on the batch's device rather than creating the
            # buffer on the CPU and copying it to a hard-coded CUDA device.
            self.queue_lr = torch.zeros(self.queue_size, c, h, w, device=self.lq.device)
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(self.queue_size, c, h, w, device=self.gt.device)
            self.queue_ptr = 0
        if self.queue_ptr == self.queue_size:  # the pool is full
            # do dequeue and enqueue
            # shuffle
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples (cloned, because these slots are overwritten just below)
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue; slice assignment already copies, so no extra clone is needed
            self.queue_lr[0:b, :, :, :] = self.lq
            self.queue_gt[0:b, :, :, :] = self.gt

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt
            self.queue_ptr = self.queue_ptr + b
66

67
    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.

        When ``self.is_train`` is set and the option ``high_order_degradation`` is not disabled, the LQ
        batch is synthesized from ``data['gt']``; otherwise ``data['lq']`` (and ``data['gt']`` if present)
        are used as they are.

        Args:
            data (dict): Batch from the dataloader. The synthesis path reads the keys ``'gt'``,
                ``'kernel1'``, ``'kernel2'`` and ``'sinc_kernel'``; the other path reads ``'lq'`` and,
                when present, ``'gt'``.
        """
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            # USM sharpen the GT images
            if self.opt['gt_usm'] is True:
                self.gt = self.usm_sharpener(self.gt)

            self.kernel1 = data['kernel1'].to(self.device)
            self.kernel2 = data['kernel2'].to(self.device)
            self.sinc_kernel = data['sinc_kernel'].to(self.device)

            ori_h, ori_w = self.gt.size()[2:4]

            # NOTE: the helpers below draw from the `random`, NumPy and torch RNGs; the call order here is
            # kept identical to the original inline code so seeded runs stay reproducible.

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt, self.kernel1)
            # random resize
            scale = self._sample_resize_scale('resize_prob', 'resize_range')
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise
            out = self._add_random_noise(out, suffix='')
            # JPEG compression
            out = self._jpeg_compress(out, 'jpeg_range')

            # ----------------------- The second degradation process ----------------------- #
            # blur
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            scale = self._sample_resize_scale('resize_prob2', 'resize_range2')
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            out = self._add_random_noise(out, suffix='2')

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            #   1. [resize back + sinc filter] + JPEG compression
            #   2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                out = self._resize_back_and_sinc(out, ori_h, ori_w)
                out = self._jpeg_compress(out, 'jpeg_range2')
            else:
                out = self._jpeg_compress(out, 'jpeg_range2')
                out = self._resize_back_and_sinc(out, ori_h, ori_w)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])

            # training pair pool
            self._dequeue_and_enqueue()
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def _sample_resize_scale(self, prob_key, range_key):
        """Draw a resize factor: pick 'up'/'down'/'keep' with weights ``opt[prob_key]``, then a scale in range."""
        updown_type = random.choices(['up', 'down', 'keep'], self.opt[prob_key])[0]
        if updown_type == 'up':
            return np.random.uniform(1, self.opt[range_key][1])
        if updown_type == 'down':
            return np.random.uniform(self.opt[range_key][0], 1)
        return 1

    def _add_random_noise(self, img, suffix=''):
        """Add Gaussian or Poisson noise to ``img``, using the option keys that end with ``suffix``."""
        gray_noise_prob = self.opt['gray_noise_prob' + suffix]
        if np.random.uniform() < self.opt['gaussian_noise_prob' + suffix]:
            return random_add_gaussian_noise_pt(
                img, sigma_range=self.opt['noise_range' + suffix], clip=True, rounds=False, gray_prob=gray_noise_prob)
        return random_add_poisson_noise_pt(
            img,
            scale_range=self.opt['poisson_scale_range' + suffix],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False)

    def _jpeg_compress(self, img, range_key):
        """JPEG-compress ``img`` with one quality per sample drawn uniformly from ``opt[range_key]``."""
        jpeg_p = img.new_zeros(img.size(0)).uniform_(*self.opt[range_key])
        img = torch.clamp(img, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
        return self.jpeger(img, quality=jpeg_p)

    def _resize_back_and_sinc(self, img, ori_h, ori_w):
        """Resize ``img`` to the GT size divided by ``opt['scale']`` and apply the final sinc filter."""
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        img = F.interpolate(img, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
        return filter2D(img, self.sinc_kernel)
184

185
    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run the parent's non-distributed validation with the synthetic degradation switched off.

        ``self.is_train`` is cleared for the duration of the call so that ``feed_data`` takes its
        paired-data branch, and is put back to its previous value afterwards, even if validation raises.

        Args:
            dataloader: Validation dataloader, forwarded to the parent implementation.
            current_iter: Current iteration, forwarded to the parent implementation.
            tb_logger: Logger object, forwarded to the parent implementation.
            save_img: Forwarded to the parent implementation.
        """
        # do not use the synthetic process during validation
        was_training = self.is_train
        self.is_train = False
        try:
            super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        finally:
            # Restore the caller's mode instead of forcing True, and do so on errors too.
            self.is_train = was_training
190

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.