GFPGAN

Форк
0
/
gfpgan_model.py 
579 строк · 25.8 Кб
1
import math
import os.path as osp
from collections import OrderedDict

import torch
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.losses.gan_loss import r1_penalty
from basicsr.metrics import calculate_metric
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from torch.nn import functional as F
from torchvision.ops import roi_align
from tqdm import tqdm
15

16

17
@MODEL_REGISTRY.register()
class GFPGANModel(BaseModel):
    """The GFPGAN model for Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Trains a generator ``net_g`` (plus an EMA copy ``net_g_ema``) against a global discriminator
    ``net_d`` and, when configured, three facial-component discriminators (left eye, right eye,
    mouth) and a frozen identity network.
    """

    def __init__(self, opt):
        super(GFPGANModel, self).__init__(opt)
        self.idx = 0  # counter reserved for dumping training samples when debugging the data pipeline

        # define network
        self.net_g = build_network(opt['network_g'])
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)

        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_g', 'params')
            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

        # log2 of the output resolution (9 for 512); math.log2 is exact for powers of two,
        # whereas math.log(x, 2) can land just below the integer and be truncated by int()
        self.log_size = int(math.log2(self.opt['network_g']['out_size']))

        if self.is_train:
            self.init_training_settings()

    def init_training_settings(self):
        """Build discriminators, the EMA generator, losses, optimizers and schedulers for training."""
        train_opt = self.opt['train']

        # ----------- define net_d ----------- #
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))

        # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
        # net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
        self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
        else:
            self.model_ema(0)  # copy net_g weight

        self.net_g.train()
        self.net_d.train()
        self.net_g_ema.eval()

        # ----------- facial component networks ----------- #
        # enabled only when all three component discriminators are configured
        self.use_facial_disc = all(
            key in self.opt for key in ('network_d_left_eye', 'network_d_right_eye', 'network_d_mouth'))

        if self.use_facial_disc:
            self.net_d_left_eye = self._build_component_disc('left_eye')
            self.net_d_right_eye = self._build_component_disc('right_eye')
            self.net_d_mouth = self._build_component_disc('mouth')

            # ----------- define facial component gan loss ----------- #
            self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)

        # ----------- define losses ----------- #
        # pixel loss
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        # perceptual loss
        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        # L1 loss is used in pyramid loss, component style loss and identity loss
        self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)

        # gan loss (wgan)
        self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

        # ----------- define identity loss ----------- #
        self.use_identity = 'network_identity' in self.opt

        if self.use_identity:
            # define identity network; it is frozen and only provides target embeddings
            self.network_identity = build_network(self.opt['network_identity'])
            self.network_identity = self.model_to_device(self.network_identity)
            self.print_network(self.network_identity)
            load_path = self.opt['path'].get('pretrain_network_identity')
            if load_path is not None:
                self.load_network(self.network_identity, load_path, True, None)
            self.network_identity.eval()
            for param in self.network_identity.parameters():
                param.requires_grad = False

        # regularization weights
        self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
        self.net_d_reg_every = train_opt['net_d_reg_every']

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def _build_component_disc(self, name):
        """Build the ``network_d_<name>`` discriminator, load ``pretrain_network_d_<name>`` if set, train mode.

        Args:
            name (str): Component name: 'left_eye', 'right_eye' or 'mouth'.

        Returns:
            The discriminator after ``model_to_device``.
        """
        net = build_network(self.opt[f'network_d_{name}'])
        net = self.model_to_device(net)
        self.print_network(net)
        load_path = self.opt['path'].get(f'pretrain_network_d_{name}')
        if load_path is not None:
            self.load_network(net, load_path, True, 'params')
        net.train()
        return net

    def setup_optimizers(self):
        """Create optimizers for net_g, net_d and, if enabled, the facial component discriminators.

        Optimizers are appended to ``self.optimizers`` in the order g, d, left eye, right eye, mouth.
        Note that the ``type`` key is popped from each optimizer config.
        """
        train_opt = self.opt['train']

        # ----------- optimizer g ----------- #
        net_g_reg_ratio = 1  # the generator uses no lazy regularization
        optim_params_g = [{  # add normal params first
            'params': list(self.net_g.parameters()),
            'lr': train_opt['optim_g']['lr']
        }]
        optim_type = train_opt['optim_g'].pop('type')
        lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
        betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
        self.optimizers.append(self.optimizer_g)

        # ----------- optimizer d ----------- #
        # R1 is applied only every net_d_reg_every iterations, so the hyper-parameters are rescaled
        # by reg_every / (reg_every + 1) (looks like StyleGAN2's lazy regularization)
        net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
        optim_params_d = [{  # add normal params first
            'params': list(self.net_d.parameters()),
            'lr': train_opt['optim_d']['lr']
        }]
        optim_type = train_opt['optim_d'].pop('type')
        # NOTE(review): assuming get_optimizer forwards `lr` as the torch.optim default, the per-group
        # 'lr' above takes precedence, so effectively only the betas are rescaled. Kept unchanged.
        lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
        betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
        self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
        self.optimizers.append(self.optimizer_d)

        # ----------- optimizers for facial component networks ----------- #
        if self.use_facial_disc:
            # setup optimizers for facial component discriminators
            optim_type = train_opt['optim_component'].pop('type')
            lr = train_opt['optim_component']['lr']
            # left eye
            self.optimizer_d_left_eye = self.get_optimizer(
                optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
            self.optimizers.append(self.optimizer_d_left_eye)
            # right eye
            self.optimizer_d_right_eye = self.get_optimizer(
                optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
            self.optimizers.append(self.optimizer_d_right_eye)
            # mouth
            self.optimizer_d_mouth = self.get_optimizer(
                optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
            self.optimizers.append(self.optimizer_d_mouth)

    def feed_data(self, data):
        """Store a batch on the model.

        ``data['lq']`` is required and moved to ``self.device``; ``data['gt']`` is optional and
        moved likewise. If ``loc_left_eye`` is present, ``loc_right_eye`` and ``loc_mouth`` must be
        too; the boxes stay where they are and are moved in ``get_roi_regions``.
        """
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

        if 'loc_left_eye' in data:
            # get facial component locations, shape (batch, 4)
            self.loc_left_eyes = data['loc_left_eye']
            self.loc_right_eyes = data['loc_right_eye']
            self.loc_mouths = data['loc_mouth']

    def construct_img_pyramid(self):
        """Construct the ground-truth image pyramid for the intermediate restoration loss.

        Returns:
            list[torch.Tensor]: ``log_size - 2`` tensors, ``self.gt`` repeatedly halved with bilinear
            interpolation, ordered from the smallest scale to the full-resolution ``self.gt``.
        """
        pyramid_gt = [self.gt]
        down_img = self.gt
        for _ in range(self.log_size - 3):
            down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
            pyramid_gt.insert(0, down_img)
        return pyramid_gt

    def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
        """Crop eye and mouth regions from ``self.gt`` and ``self.output`` with ``roi_align``.

        Uses the boxes stored by ``feed_data`` and sets ``left_eyes_gt``, ``right_eyes_gt``,
        ``mouths_gt`` (from gt) and ``left_eyes``, ``right_eyes``, ``mouths`` (from output).

        Args:
            eye_out_size (int): Eye crop size at a 512 output; multiplied by ``face_ratio``.
            mouth_out_size (int): Mouth crop size at a 512 output; multiplied by ``face_ratio``.
        """
        # NOTE(review): int() truncates, so an out_size below 512 gives face_ratio == 0 and zero-sized
        # crops; confirm before training such models.
        face_ratio = int(self.opt['network_g']['out_size'] / 512)
        eye_out_size *= face_ratio
        mouth_out_size *= face_ratio

        # roi_align boxes are rows of (batch_index, x1, y1, x2, y2)
        loc_left = self.loc_left_eyes
        batch_inds = torch.arange(loc_left.size(0), dtype=loc_left.dtype, device=loc_left.device).unsqueeze(1)
        # eyes interleaved per image: [left_0, right_0, left_1, right_1, ...]
        eye_boxes = torch.stack([loc_left, self.loc_right_eyes], dim=1).reshape(-1, 4)
        rois_eyes = torch.cat([batch_inds.repeat_interleave(2, dim=0), eye_boxes], dim=-1).to(self.device)
        rois_mouths = torch.cat([batch_inds, self.loc_mouths], dim=-1).to(self.device)

        # NOTE(review): multiplying the crops by face_ratio rescales pixel values (a no-op at 512);
        # kept unchanged.
        # real images
        all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
        self.left_eyes_gt = all_eyes[0::2, :, :, :]
        self.right_eyes_gt = all_eyes[1::2, :, :, :]
        self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
        # output
        all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
        self.left_eyes = all_eyes[0::2, :, :, :]
        self.right_eyes = all_eyes[1::2, :, :, :]
        self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix of shape (n, c, c), normalized by ``c * h * w``.
        """
        n, c, h, w = x.size()
        # reshape (not view) so non-contiguous feature maps are accepted too
        features = x.reshape(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram

    def _comp_style_loss(self, feat, feat_gt):
        """Gram-matrix style loss between restored and ground-truth component features.

        Only the first two feature maps are used, weighted 0.5 and 1.0; the ground-truth
        features are detached so no gradient flows through them.
        """
        return (self.cri_l1(self._gram_mat(feat[0]), self._gram_mat(feat_gt[0].detach())) * 0.5 +
                self.cri_l1(self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach())))

    def gray_resize_for_identity(self, out, size=128):
        """Convert the first three channels of ``out`` to one gray channel and resize to ``size`` x ``size``.

        Args:
            out (torch.Tensor): Images indexed as (n, c, h, w) with c >= 3.
            size (int): Output spatial size.

        Returns:
            torch.Tensor: Shape (n, 1, size, size).
        """
        out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
        out_gray = out_gray.unsqueeze(1)
        out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
        return out_gray

    @staticmethod
    def _set_requires_grad(nets, flag):
        """Set ``requires_grad`` to ``flag`` on every parameter of every network in ``nets``."""
        for net in nets:
            for p in net.parameters():
                p.requires_grad = flag

    def _discriminators(self):
        """Return net_d followed by the facial component discriminators (when enabled)."""
        nets = [self.net_d]
        if self.use_facial_disc:
            nets += [self.net_d_left_eye, self.net_d_right_eye, self.net_d_mouth]
        return nets

    def _facial_components(self):
        """Return (name, discriminator, restored crop, ground-truth crop) for each facial component.

        Valid only after ``get_roi_regions`` has run for the current batch.
        """
        return (
            ('left_eye', self.net_d_left_eye, self.left_eyes, self.left_eyes_gt),
            ('right_eye', self.net_d_right_eye, self.right_eyes, self.right_eyes_gt),
            ('mouth', self.net_d_mouth, self.mouths, self.mouths_gt),
        )

    def optimize_parameters(self, current_iter):
        """Run one training iteration: generator step, EMA update, then discriminator steps.

        Args:
            current_iter (int): Current iteration; gates the generator update, the R1
                regularization and the removal of the pyramid loss.
        """
        loss_dict = OrderedDict()
        self._optimize_generator(current_iter, loss_dict)

        # EMA
        # NOTE(review): the constant looks like StyleGAN2's 10k-image half-life at batch 32; it does
        # not adapt to the configured batch size.
        self.model_ema(decay=0.5**(32 / (10 * 1000)))

        self._optimize_discriminators(current_iter, loss_dict)

        self.log_dict = self.reduce_loss_dict(loss_dict)

    def _optimize_generator(self, current_iter, loss_dict):
        """Forward net_g and, on generator iterations, back-propagate its losses and step optimizer_g.

        The forward pass (``self.output`` and the ROI crops) always runs because the discriminator
        update needs it. The loss/step part runs only when ``current_iter`` is a multiple of
        ``net_d_iters`` and greater than ``net_d_init_iters``. Generator losses are added to ``loss_dict``.
        """
        train_opt = self.opt['train']

        # freeze every discriminator so generator losses do not accumulate gradients in them
        self._set_requires_grad(self._discriminators(), False)
        self.optimizer_g.zero_grad()

        # image pyramid loss weight
        pyramid_loss_weight = train_opt.get('pyramid_loss_weight', 0)
        if pyramid_loss_weight > 0 and current_iter > train_opt.get('remove_pyramid_loss', float('inf')):
            pyramid_loss_weight = 1e-12  # very small weight to avoid unused param error
        if pyramid_loss_weight > 0:
            self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
            pyramid_gt = self.construct_img_pyramid()
        else:
            self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)

        # get roi-align regions
        if self.use_facial_disc:
            self.get_roi_regions(eye_out_size=80, mouth_out_size=120)

        if current_iter % self.net_d_iters != 0 or current_iter <= self.net_d_init_iters:
            return  # discriminator-only iteration

        l_g_total = 0

        # pixel loss
        if self.cri_pix:
            l_g_pix = self.cri_pix(self.output, self.gt)
            l_g_total += l_g_pix
            loss_dict['l_g_pix'] = l_g_pix

        # image pyramid loss: out_rgbs[i] vs. pyramid_gt[i] (side 2 ** (i + 3) for power-of-two sizes)
        if pyramid_loss_weight > 0:
            for i in range(self.log_size - 2):
                l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
                l_g_total += l_pyramid
                loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid

        # perceptual loss
        if self.cri_perceptual:
            l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
            if l_g_percep is not None:
                l_g_total += l_g_percep
                loss_dict['l_g_percep'] = l_g_percep
            if l_g_style is not None:
                l_g_total += l_g_style
                loss_dict['l_g_style'] = l_g_style

        # gan loss
        fake_g_pred = self.net_d(self.output)
        l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
        l_g_total += l_g_gan
        loss_dict['l_g_gan'] = l_g_gan

        # facial component gan loss, plus optional gram-matrix style loss on discriminator features
        if self.use_facial_disc:
            fake_comp_feats = []
            for name, net_d_comp, fake_crop, _ in self._facial_components():
                fake_comp_pred, fake_feats = net_d_comp(fake_crop, return_feats=True)
                l_g_comp = self.cri_component(fake_comp_pred, True, is_disc=False)
                l_g_total += l_g_comp
                loss_dict[f'l_g_gan_{name}'] = l_g_comp
                fake_comp_feats.append(fake_feats)

            comp_style_weight = train_opt.get('comp_style_weight', 0)
            if comp_style_weight > 0:
                comp_style_loss = 0
                for (_, net_d_comp, _, real_crop), fake_feats in zip(self._facial_components(), fake_comp_feats):
                    _, real_feats = net_d_comp(real_crop, return_feats=True)
                    comp_style_loss += self._comp_style_loss(fake_feats, real_feats)
                comp_style_loss = comp_style_loss * comp_style_weight
                l_g_total += comp_style_loss
                loss_dict['l_g_comp_style_loss'] = comp_style_loss

        # identity loss
        if self.use_identity:
            identity_weight = train_opt['identity_weight']
            # get gray images and resize
            out_gray = self.gray_resize_for_identity(self.output)
            gt_gray = self.gray_resize_for_identity(self.gt)

            identity_gt = self.network_identity(gt_gray).detach()
            identity_out = self.network_identity(out_gray)
            l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
            l_g_total += l_identity
            loss_dict['l_identity'] = l_identity

        l_g_total.backward()
        self.optimizer_g.step()

    def _optimize_discriminators(self, current_iter, loss_dict):
        """Update net_d (with R1 every ``net_d_reg_every`` iterations) and the component discriminators.

        Discriminator losses and scores are added to ``loss_dict``.
        """
        self._set_requires_grad(self._discriminators(), True)
        self.optimizer_d.zero_grad()
        if self.use_facial_disc:
            self.optimizer_d_left_eye.zero_grad()
            self.optimizer_d_right_eye.zero_grad()
            self.optimizer_d_mouth.zero_grad()

        fake_d_pred = self.net_d(self.output.detach())
        real_d_pred = self.net_d(self.gt)
        l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d'] = l_d
        # In WGAN, real_score should be positive and fake_score should be negative
        loss_dict['real_score'] = real_d_pred.detach().mean()
        loss_dict['fake_score'] = fake_d_pred.detach().mean()
        l_d.backward()

        # regularization loss (R1 penalty on real images)
        if current_iter % self.net_d_reg_every == 0:
            self.gt.requires_grad = True
            real_pred = self.net_d(self.gt)
            l_d_r1 = r1_penalty(real_pred, self.gt)
            # scaled by net_d_reg_every because it is applied only every net_d_reg_every iterations;
            # the `0 * real_pred[0]` term presumably keeps the output in the graph (unused-param errors)
            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
            loss_dict['l_d_r1'] = l_d_r1.detach().mean()
            l_d_r1.backward()

        self.optimizer_d.step()

        # optimize facial component discriminators
        if self.use_facial_disc:
            for name, net_d_comp, fake_crop, real_crop in self._facial_components():
                fake_d_pred, _ = net_d_comp(fake_crop.detach())
                real_d_pred, _ = net_d_comp(real_crop)
                # both terms use the component criterion, matching the generator-side component loss
                l_d_comp = (self.cri_component(real_d_pred, True, is_disc=True) +
                            self.cri_component(fake_d_pred, False, is_disc=True))
                loss_dict[f'l_d_{name}'] = l_d_comp
                l_d_comp.backward()

            self.optimizer_d_left_eye.step()
            self.optimizer_d_right_eye.step()
            self.optimizer_d_mouth.step()

    def test(self):
        """Restore ``self.lq`` into ``self.output`` without gradients, preferring ``net_g_ema``."""
        with torch.no_grad():
            if hasattr(self, 'net_g_ema'):
                self.net_g_ema.eval()
                self.output, _ = self.net_g_ema(self.lq)
            else:
                logger = get_root_logger()
                logger.warning('Do not have self.net_g_ema, use self.net_g.')
                self.net_g.eval()
                self.output, _ = self.net_g(self.lq)
                self.net_g.train()

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Validate on rank 0 only; other ranks return immediately."""
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run validation in this process, optionally saving images and computing metrics.

        Metrics are averaged over the processed images. An empty dataloader is reported with a
        warning and its metrics are skipped.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        use_pbar = self.opt['val'].get('pbar', False)

        if with_metrics:
            if not hasattr(self, 'metric_results'):  # only execute in the first run
                self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
            # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
            self._initialize_best_metric_results(dataset_name)
            # zero self.metric_results
            self.metric_results = {metric: 0 for metric in self.metric_results}

        if use_pbar:
            pbar = tqdm(total=len(dataloader), unit='image')

        num_imgs = 0
        for val_data in dataloader:
            num_imgs += 1
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
            # fresh per image so an image without gt never reuses the previous image's 'img2'
            metric_data = {'img': sr_img}
            if hasattr(self, 'gt'):
                metric_data['img2'] = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    suffix = self.opt['val'].get('suffix')
                    if suffix:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{suffix}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in self.opt['val']['metrics'].items():
                    self.metric_results[name] += calculate_metric(metric_data, opt_)
            if use_pbar:
                pbar.update(1)
                pbar.set_description(f'Test {img_name}')
        if use_pbar:
            pbar.close()

        if with_metrics:
            if num_imgs == 0:
                get_root_logger().warning(f'Validation {dataset_name}: no images, metrics are skipped.')
                return
            for metric in self.metric_results:
                self.metric_results[metric] /= num_imgs
                # update the best metric result
                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        """Log the current (and best, if tracked) metric values and send them to TensorBoard."""
        log_str = f'Validation {dataset_name}\n'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}'
            if hasattr(self, 'best_metric_results'):
                log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
            log_str += '\n'

        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)

    def save(self, epoch, current_iter):
        """Save the generator (regular and EMA weights), all discriminators and the training state."""
        # save net_g and net_d
        self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        self.save_network(self.net_d, 'net_d', current_iter)
        # save component discriminators
        if self.use_facial_disc:
            self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
            self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
            self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
        # save training state
        self.save_training_state(epoch, current_iter)
580

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.