HairFastGAN
pp_losses.py · 677 lines · 21.8 KB
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T

from utils.bicubic import BicubicDownSample

normalize = T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

@dataclass
class DefaultPaths:
    psp_path: str = "pretrained_models/psp_ffhq_encode.pt"
    ir_se50_path: str = "pretrained_models/ArcFace/ir_se50.pth"
    stylegan_weights: str = "pretrained_models/stylegan2-ffhq-config-f.pt"
    stylegan_car_weights: str = "pretrained_models/stylegan2-car-config-f-new.pkl"
    stylegan_weights_pkl: str = "pretrained_models/stylegan2-ffhq-config-f.pkl"
    arcface_model_path: str = "pretrained_models/ArcFace/backbone_ir50.pth"
    moco: str = "pretrained_models/moco_v2_800ep_pretrain.pt"


from collections import namedtuple
from torch.nn import (
    Conv2d,
    BatchNorm2d,
    PReLU,
    ReLU,
    Sigmoid,
    MaxPool2d,
    AdaptiveAvgPool2d,
    Sequential,
    Module,
    Dropout,
    Linear,
    BatchNorm1d,
)

"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    # Normalize feature vectors to unit L2 norm along the given axis.
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])):
    """A named tuple describing a ResNet block."""


def get_block(in_channel, depth, num_units, stride=2):
    # One strided unit followed by (num_units - 1) stride-1 units.
    return [Bottleneck(in_channel, depth, stride)] + [
        Bottleneck(depth, depth, 1) for _ in range(num_units - 1)
    ]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
    else:
        raise ValueError(
            "Invalid number of layers: {}. Must be one of [50, 100, 152]".format(
                num_layers
            )
        )
    return blocks
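
# For reference, get_blocks(50) returns four stages whose unit counts sum to
# 3 + 4 + 14 + 3 = 24 bottlenecks; the first stage expands to
# [Bottleneck(64, 64, 2), Bottleneck(64, 64, 1), Bottleneck(64, 64, 1)].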


class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(
            channels, channels // reduction, kernel_size=1, padding=0, bias=False
        )
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(
            channels // reduction, channels, kernel_size=1, padding=0, bias=False
        )
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x
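
# SEModule implements squeeze-and-excitation: global average pooling squeezes
# each channel to a scalar, the two 1x1 convolutions form the excitation
# bottleneck, and the sigmoid gate rescales the input channel-wise. A minimal
# shape check (illustrative only):
#
#   se = SEModule(channels=64, reduction=16)
#   out = se(torch.randn(2, 64, 32, 32))  # -> torch.Size([2, 64, 32, 32])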


class bottleneck_IR(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16),
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Backbone(Module):
    def __init__(self, input_size, num_layers, mode="ir", drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ["ir", "ir_se"], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == "ir":
            unit_module = bottleneck_IR
        elif mode == "ir_se":
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)
        )
        if input_size == 112:
            self.output_layer = Sequential(
                BatchNorm2d(512),
                Dropout(drop_ratio),
                Flatten(),
                Linear(512 * 7 * 7, 512),
                BatchNorm1d(512, affine=affine),
            )
        else:
            self.output_layer = Sequential(
                BatchNorm2d(512),
                Dropout(drop_ratio),
                Flatten(),
                Linear(512 * 14 * 14, 512),
                BatchNorm1d(512, affine=affine),
            )

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(
                        bottleneck.in_channel, bottleneck.depth, bottleneck.stride
                    )
                )
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


def IR_50(input_size):
    """Constructs an ir-50 model."""
    model = Backbone(input_size, num_layers=50, mode="ir", drop_ratio=0.4, affine=False)
    return model


def IR_101(input_size):
    """Constructs an ir-101 model."""
    model = Backbone(
        input_size, num_layers=100, mode="ir", drop_ratio=0.4, affine=False
    )
    return model


def IR_152(input_size):
    """Constructs an ir-152 model."""
    model = Backbone(
        input_size, num_layers=152, mode="ir", drop_ratio=0.4, affine=False
    )
    return model


def IR_SE_50(input_size):
    """Constructs an ir_se-50 model."""
    model = Backbone(
        input_size, num_layers=50, mode="ir_se", drop_ratio=0.4, affine=False
    )
    return model


def IR_SE_101(input_size):
    """Constructs an ir_se-101 model."""
    model = Backbone(
        input_size, num_layers=100, mode="ir_se", drop_ratio=0.4, affine=False
    )
    return model


def IR_SE_152(input_size):
    """Constructs an ir_se-152 model."""
    model = Backbone(
        input_size, num_layers=152, mode="ir_se", drop_ratio=0.4, affine=False
    )
    return model
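
# Illustrative usage (a sketch; the constructors above only differ in depth and
# block type):
#
#   net = IR_SE_50(input_size=112)
#   feats = net(torch.randn(4, 3, 112, 112))  # (4, 512) L2-normalized embeddings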


class IDLoss(nn.Module):
    def __init__(self):
        super(IDLoss, self).__init__()
        print("Loading ResNet ArcFace")
        self.facenet = Backbone(
            input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
        )
        self.facenet.load_state_dict(torch.load(DefaultPaths.ir_se50_path))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def extract_feats(self, x):
        x = x[:, :, 35:223, 32:220]  # Crop the face region of a 256x256 aligned image
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y):
        n_samples = y.shape[0]
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        count = 0
        for i in range(n_samples):
            # Embeddings are L2-normalized, so the dot product is cosine similarity.
            diff_target = y_hat_feats[i].dot(y_feats[i])
            loss += 1 - diff_target
            count += 1

        return loss / count
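
# The identity loss is mean_i(1 - cos(e(y_hat_i), e(y_i))): one minus the cosine
# similarity of ArcFace embeddings, averaged over the batch. A sketch of a call
# (inputs are 256x256 crops, matching the slicing in extract_feats):
#
#   id_loss = IDLoss().cuda().eval()
#   loss = id_loss(generated_256, real_256)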


class FeatReconLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.MSELoss()

    def forward(self, recon_1, recon_2):
        return self.loss_fn(recon_1, recon_2).mean()


class EncoderAdvLoss:
    def __call__(self, fake_preds):
        loss_G_adv = F.softplus(-fake_preds).mean()
        return loss_G_adv


class AdvLoss:
    def __init__(self, coef=0.0):
        self.coef = coef

    def __call__(self, disc, real_images, generated_images):
        fake_preds = disc(generated_images, None)
        real_preds = disc(real_images, None)
        loss = self.d_logistic_loss(real_preds, fake_preds)

        return {'disc adv': loss}

    def d_logistic_loss(self, real_preds, fake_preds):
        real_loss = F.softplus(-real_preds)
        fake_loss = F.softplus(fake_preds)

        return (real_loss.mean() + fake_loss.mean()) / 2
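
# Both adversarial terms use the non-saturating logistic GAN loss written with
# softplus: the encoder minimizes E[softplus(-D(fake))], while the discriminator
# minimizes (E[softplus(-D(real))] + E[softplus(D(fake))]) / 2, i.e. the
# StyleGAN2 d_logistic_loss up to the factor 1/2.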


from models.face_parsing.model import BiSeNet, seg_mean, seg_std


class DiceLoss(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
        self.seg = BiSeNet(n_classes=16)
        self.seg.to('cuda')
        self.seg.load_state_dict(torch.load('pretrained_models/BiSeNet/seg.pth'))
        for param in self.seg.parameters():
            param.requires_grad = False
        self.seg.eval()
        self.downsample_512 = BicubicDownSample(factor=2)

    def calc_landmark(self, x):
        # Downsample by a factor of 2, normalize, and return BiSeNet segmentation
        # logits (despite the name, these are face-parsing maps, not landmarks).
        IM = (self.downsample_512(x) - seg_mean) / seg_std
        out, _, _ = self.seg(IM)
        return out

    def dice_loss(self, input, target):
        smooth = 1.

        iflat = input.view(input.size(0), -1)
        tflat = target.view(target.size(0), -1)
        intersection = (iflat * tflat).sum(dim=1)

        fn = torch.sum((tflat * (1 - iflat)) ** self.gamma, dim=1)
        fp = torch.sum(((1 - tflat) * iflat) ** self.gamma, dim=1)

        return 1 - ((2. * intersection + smooth) /
                    (iflat.sum(dim=1) + tflat.sum(dim=1) + fn + fp + smooth))

    def __call__(self, in_logit, tg_logit):
        probs1 = F.softmax(in_logit, dim=1)
        probs2 = F.softmax(tg_logit, dim=1)
        return self.dice_loss(probs1, probs2).mean()
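
# dice_loss is a soft dice with focal-style false-negative / false-positive
# penalties raised to the power gamma:
#
#   1 - (2 * sum(p * t) + s) / (sum(p) + sum(t) + FN_gamma + FP_gamma + s),  s = 1
#
# computed per sample over softmax class probabilities and averaged in __call__.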


from typing import Sequence

from itertools import chain

from torchvision import models


def get_network(net_type: str):
    if net_type == "alex":
        return AlexNet()
    elif net_type == "squeeze":
        return SqueezeNet()
    elif net_type == "vgg":
        return VGG16()
    else:
        raise NotImplementedError("choose net_type from [alex, squeeze, vgg].")


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__(
            [
                nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False))
                for nc in n_channels_list
            ]
        )

        for param in self.parameters():
            param.requires_grad = False


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # Input normalization statistics for LPIPS (images expected in [-1, 1]),
        # registered as buffers so they move with .to(device).
        self.register_buffer(
            "mean", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "std", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(pretrained=True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(pretrained=True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(pretrained=True).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)


from collections import OrderedDict


def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = "alex", version: str = "0.1"):
    # build url
    url = (
        "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/"
        + f"master/lpips/weights/v{version}/{net_type}.pth"
    )

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url,
        progress=True,
        map_location=None if torch.cuda.is_available() else torch.device("cpu"),
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace("lin", "")
        new_key = new_key.replace("model.", "")
        new_state_dict[new_key] = val

    return new_state_dict
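
# The renaming maps the upstream LPIPS checkpoint layout onto LinLayers, e.g.
# "lin0.model.1.weight" -> "0.1.weight" (drop "lin", drop "model."), which
# matches ModuleList indexing followed by the Conv2d inside each Sequential.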


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).
    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """

    def __init__(self, net_type: str = "alex", version: str = "0.1"):

        assert version in ["0.1"], "only v0.1 is supported for now"

        super(LPIPS, self).__init__()

        # pretrained feature network
        self.net = get_network(net_type).to("cuda")

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [lin(d).mean((2, 3), True) for d, lin in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]
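
# LPIPS sums, over the selected layers, linearly weighted squared differences of
# unit-normalized activations, spatially averaged and then averaged over the
# batch. A sketch of a call (inputs in [-1, 1], shape (N, 3, H, W); the linear
# weights are downloaded on first use):
#
#   lpips = LPIPS(net_type="alex")
#   d = lpips(img_a, img_b)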


class LPIPSLoss(LPIPS):
    pass


class LPIPSScaleLoss(nn.Module):
    # Multi-scale LPIPS: compare the inputs at 256, 128 and 64 resolution.
    def __init__(self):
        super().__init__()
        self.loss_fn = LPIPSLoss()

    def forward(self, x, y):
        out = 0
        for res in [256, 128, 64]:
            x_scale = F.interpolate(x, size=(res, res), mode="bilinear", align_corners=False)
            y_scale = F.interpolate(y, size=(res, res), mode="bilinear", align_corners=False)
            out += self.loss_fn(x_scale, y_scale).mean()
        return out


class SyntMSELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.MSELoss()

    def forward(self, im1, im2):
        return self.loss_fn(im1, im2).mean()


class R1Loss:
    def __init__(self, coef=10.0):
        self.coef = coef

    def __call__(self, disc, real_images):
        real_images.requires_grad = True

        real_preds = disc(real_images, None)
        real_preds = real_preds.view(real_images.size(0), -1)
        real_preds = real_preds.mean(dim=1).unsqueeze(1)
        r1_loss = self.d_r1_loss(real_preds, real_images)

        # The 0 * real_preds[0] term keeps the discriminator output in the graph.
        loss_D_R1 = self.coef / 2 * r1_loss * 16 + 0 * real_preds[0]
        return {'disc r1 loss': loss_D_R1}

    def d_r1_loss(self, real_pred, real_img):
        (grad_real,) = torch.autograd.grad(
            outputs=real_pred.sum(), inputs=real_img, create_graph=True
        )
        grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

        return grad_penalty
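
# R1 penalty: coef/2 * E[ ||grad_x D(x)||^2 ] on real images only. With lazy
# regularization (applying the penalty every k-th step) the coefficient is
# scaled by k; the hard-coded 16 presumably corresponds to a d_reg_every=16
# schedule, as in StyleGAN2.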


class DilatedMask:
    def __init__(self, kernel_size=5):
        self.kernel_size = kernel_size

        # Build a normalized circular (disk) averaging kernel of the given size.
        cords_x = torch.arange(0, kernel_size).view(1, -1).expand(kernel_size, -1) - kernel_size // 2
        cords_y = cords_x.clone().permute(1, 0)
        self.kernel = torch.as_tensor(
            (cords_x ** 2 + cords_y ** 2) <= (kernel_size // 2) ** 2, dtype=torch.float
        ).view(1, 1, kernel_size, kernel_size).cuda()
        self.kernel /= self.kernel.sum()

    def __call__(self, mask):
        smooth_mask = F.conv2d(mask, self.kernel, padding=self.kernel_size // 2)
        return smooth_mask ** 0.25
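
# Averaging a binary mask with the disk kernel yields values in [0, 1] near the
# mask boundary; raising them to the 0.25 power pushes them toward 1, so the
# overall effect is a soft dilation. Sketch:
#
#   dilate = DilatedMask(kernel_size=25)
#   soft = dilate(hard_mask)  # hard_mask: (N, 1, H, W) float tensor on CUDA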


class LossBuilder:
    def __init__(self, losses_dict, device='cuda'):
        self.losses_dict = losses_dict
        self.device = device

        self.EncoderAdvLoss = EncoderAdvLoss()
        self.AdvLoss = AdvLoss()
        self.R1Loss = R1Loss()
        self.FeatReconLoss = FeatReconLoss().to(device).eval()
        self.IDLoss = IDLoss().to(device).eval()
        self.LPIPS = LPIPSScaleLoss().to(device).eval()
        self.SyntMSELoss = SyntMSELoss().to(device).eval()
        self.downsample_256 = BicubicDownSample(factor=4)

    def CalcAdvLoss(self, disc, gen_F):
        fake_preds_F = disc(gen_F, None)

        return {'adv': self.losses_dict['adv'] * self.EncoderAdvLoss(fake_preds_F)}

    def CalcDisLoss(self, disc, real_images, generated_images):
        return self.AdvLoss(disc, real_images, generated_images)

    def CalcR1Loss(self, disc, real_images):
        return self.R1Loss(disc, real_images)

    def __call__(self, source, target, target_mask, HT_E, gen_w, F_w, gen_F, F_gen, **kwargs):
        losses = {}

        gen_w_256 = self.downsample_256(gen_w)
        gen_F_256 = self.downsample_256(gen_F)

        # ID loss
        losses['rec id'] = self.losses_dict['id'] * (self.IDLoss(normalize(source), gen_w_256) + self.IDLoss(normalize(source), gen_F_256))

        # Feature reconstruction loss
        losses['rec feat_rec'] = self.losses_dict['feat_rec'] * self.FeatReconLoss(F_w.detach(), F_gen)

        # LPIPS loss
        losses['rec lpips_scale'] = self.losses_dict['lpips_scale'] * (self.LPIPS(normalize(source), gen_w_256) + self.LPIPS(normalize(source), gen_F_256))

        # Synt loss (disabled)
        # losses['l2_synt'] = self.losses_dict['l2_synt'] * self.SyntMSELoss(target * HT_E, (gen_F_256 + 1) / 2 * HT_E)

        return losses


class LossBuilderMulti(LossBuilder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.DiceLoss = DiceLoss().to(kwargs.get('device', 'cuda')).eval()
        self.dilated = DilatedMask(25)

    def __call__(self, source, target, target_mask, HT_E, gen_w, F_w, gen_F, F_gen, **kwargs):
        losses = {}

        gen_w_256 = self.downsample_256(gen_w)
        gen_F_256 = self.downsample_256(gen_F)

        # Dice loss
        with torch.no_grad():
            target_512 = F.interpolate(target, size=(512, 512), mode='bilinear').clip(0, 1)
            seg_target = self.DiceLoss.calc_landmark(target_512)
            seg_target = F.interpolate(seg_target, size=(256, 256), mode='nearest')
        seg_gen = F.interpolate(self.DiceLoss.calc_landmark((gen_F + 1) / 2), size=(256, 256), mode='nearest')

        losses['DiceLoss'] = self.losses_dict['landmark'] * self.DiceLoss(seg_gen, seg_target)

        # ID loss
        losses['id'] = self.losses_dict['id'] * (self.IDLoss(normalize(source) * target_mask, gen_w_256 * target_mask) +
                                                 self.IDLoss(normalize(source) * target_mask, gen_F_256 * target_mask))

        # Feature reconstruction loss
        losses['feat_rec'] = self.losses_dict['feat_rec'] * self.FeatReconLoss(F_w.detach(), F_gen)

        # LPIPS loss
        losses['lpips_face'] = 0.5 * self.losses_dict['lpips_scale'] * (self.LPIPS(normalize(source) * target_mask, gen_w_256 * target_mask) +
                                                                        self.LPIPS(normalize(source) * target_mask, gen_F_256 * target_mask))
        losses['lpips_hair'] = 0.5 * self.losses_dict['lpips_scale'] * (self.LPIPS(normalize(target) * HT_E, gen_w_256 * HT_E) +
                                                                        self.LPIPS(normalize(target) * HT_E, gen_F_256 * HT_E))

        # Inpaint loss
        if self.losses_dict['inpaint'] != 0.:
            M_Inp = (1 - target_mask) * (1 - HT_E)
            Smooth_M = self.dilated(M_Inp)
            losses['inpaint'] = 0.5 * self.losses_dict['inpaint'] * self.LPIPS(normalize(target) * Smooth_M, gen_F_256 * Smooth_M)
            losses['inpaint'] += 0.5 * self.losses_dict['inpaint'] * self.LPIPS(gen_w_256.detach() * Smooth_M * (1 - HT_E), gen_F_256 * Smooth_M * (1 - HT_E))

        return losses
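
# Illustrative wiring (a sketch; the input tensors are hypothetical and the
# numeric weights are placeholders, but the keys follow the dictionaries used
# above):
#
#   weights = {'adv': 0.1, 'id': 1.0, 'feat_rec': 1.0, 'lpips_scale': 1.0,
#              'landmark': 1.0, 'inpaint': 1.0}
#   builder = LossBuilderMulti(weights, device='cuda')
#   losses = builder(source, target, target_mask, HT_E, gen_w, F_w, gen_F, F_gen)
#   total = sum(losses.values())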
