HairFastGAN
from dataclasses import dataclass

import torch  # added: torch itself is referenced throughout (l2_norm, IDLoss, ...)
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T

from utils.bicubic import BicubicDownSample

normalize = T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
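# Note (added): with mean = std = 0.5 this maps image tensors from [0, 1] to
# [-1, 1], the input range the ID and LPIPS losses below work in.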

@dataclass
class DefaultPaths:
    psp_path: str = "pretrained_models/psp_ffhq_encode.pt"
    ir_se50_path: str = "pretrained_models/ArcFace/ir_se50.pth"
    stylegan_weights: str = "pretrained_models/stylegan2-ffhq-config-f.pt"
    stylegan_car_weights: str = "pretrained_models/stylegan2-car-config-f-new.pkl"
    stylegan_weights_pkl: str = "pretrained_models/stylegan2-ffhq-config-f.pkl"
    arcface_model_path: str = "pretrained_models/ArcFace/backbone_ir50.pth"
    moco: str = "pretrained_models/moco_v2_800ep_pretrain.pt"


from collections import namedtuple
from torch.nn import (
    Conv2d,
    BatchNorm2d,
    PReLU,
    ReLU,
    Sigmoid,
    MaxPool2d,
    AdaptiveAvgPool2d,
    Sequential,
    Module,
    Dropout,
    Linear,
    BatchNorm1d,
)

"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])):
    """A named tuple describing a ResNet block."""


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + [
        Bottleneck(depth, depth, 1) for i in range(num_units - 1)
    ]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
    else:
        raise ValueError(
            "Invalid number of layers: {}. Must be one of [50, 100, 152]".format(
                num_layers
            )
        )
    return blocks
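
# Sketch (added, illustration only): counting units per configuration confirms
# the ResNet-style layout, e.g. 50 layers -> 3 + 4 + 14 + 3 = 24 bottlenecks.
def _count_bottlenecks():  # hypothetical helper, not used by the pipeline
    return {n: sum(len(stage) for stage in get_blocks(n)) for n in (50, 100, 152)}
    # -> {50: 24, 100: 49, 152: 50}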

class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(
            channels, channels // reduction, kernel_size=1, padding=0, bias=False
        )
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(
            channels // reduction, channels, kernel_size=1, padding=0, bias=False
        )
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16),
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Backbone(Module):
    def __init__(self, input_size, num_layers, mode="ir", drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ["ir", "ir_se"], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == "ir":
            unit_module = bottleneck_IR
        elif mode == "ir_se":
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)
        )
        if input_size == 112:
            self.output_layer = Sequential(
                BatchNorm2d(512),
                Dropout(drop_ratio),
                Flatten(),
                Linear(512 * 7 * 7, 512),
                BatchNorm1d(512, affine=affine),
            )
        else:
            self.output_layer = Sequential(
                BatchNorm2d(512),
                Dropout(drop_ratio),
                Flatten(),
                Linear(512 * 14 * 14, 512),
                BatchNorm1d(512, affine=affine),
            )

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(
                        bottleneck.in_channel, bottleneck.depth, bottleneck.stride
                    )
                )
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


def IR_50(input_size):
    """Constructs an ir-50 model."""
    model = Backbone(input_size, num_layers=50, mode="ir", drop_ratio=0.4, affine=False)
    return model


def IR_101(input_size):
    """Constructs an ir-101 model."""
    model = Backbone(
        input_size, num_layers=100, mode="ir", drop_ratio=0.4, affine=False
    )
    return model


def IR_152(input_size):
    """Constructs an ir-152 model."""
    model = Backbone(
        input_size, num_layers=152, mode="ir", drop_ratio=0.4, affine=False
    )
    return model


def IR_SE_50(input_size):
    """Constructs an ir_se-50 model."""
    model = Backbone(
        input_size, num_layers=50, mode="ir_se", drop_ratio=0.4, affine=False
    )
    return model


def IR_SE_101(input_size):
    """Constructs an ir_se-101 model."""
    model = Backbone(
        input_size, num_layers=100, mode="ir_se", drop_ratio=0.4, affine=False
    )
    return model


def IR_SE_152(input_size):
    """Constructs an ir_se-152 model."""
    model = Backbone(
        input_size, num_layers=152, mode="ir_se", drop_ratio=0.4, affine=False
    )
    return model
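
# Usage sketch (added, illustration only): the constructors return a randomly
# initialized backbone; real use loads a checkpoint, as IDLoss does below.
def _demo_backbone_embedding():  # hypothetical helper, not used by the pipeline
    model = IR_SE_50(112).eval()
    with torch.no_grad():
        emb = model(torch.randn(2, 3, 112, 112))
    return emb.shape, emb.norm(dim=1)  # torch.Size([2, 512]); rows have unit norm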

class IDLoss(nn.Module):
    def __init__(self):
        super(IDLoss, self).__init__()
        print("Loading ResNet ArcFace")
        self.facenet = Backbone(
            input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
        )
        self.facenet.load_state_dict(torch.load(DefaultPaths.ir_se50_path))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def extract_feats(self, x):
        x = x[:, :, 35:223, 32:220]  # Crop interesting region
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y):
        n_samples = y.shape[0]
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            loss += 1 - diff_target
            count += 1

        return loss / count
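
# Note (added): Backbone.forward returns l2-normalized embeddings, so each dot
# product above is a cosine similarity and the loss is mean(1 - cos). A
# vectorized equivalent of the loop, as a sketch:
def _id_loss_vectorized(y_hat_feats, y_feats):  # hypothetical, illustration only
    return (1 - (y_hat_feats * y_feats.detach()).sum(dim=1)).mean()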

class FeatReconLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.MSELoss()

    def forward(self, recon_1, recon_2):
        return self.loss_fn(recon_1, recon_2).mean()


class EncoderAdvLoss:
    def __call__(self, fake_preds):
        loss_G_adv = F.softplus(-fake_preds).mean()
        return loss_G_adv


class AdvLoss:
    def __init__(self, coef=0.0):
        self.coef = coef

    def __call__(self, disc, real_images, generated_images):
        fake_preds = disc(generated_images, None)
        real_preds = disc(real_images, None)
        loss = self.d_logistic_loss(real_preds, fake_preds)

        return {'disc adv': loss}

    def d_logistic_loss(self, real_preds, fake_preds):
        real_loss = F.softplus(-real_preds)
        fake_loss = F.softplus(fake_preds)

        return (real_loss.mean() + fake_loss.mean()) / 2
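
# Note (added): these are the logistic GAN losses used by StyleGAN2:
# softplus(-D(fake)) is the non-saturating generator term, while the
# discriminator term averages softplus(-D(real)) and softplus(D(fake)).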

from models.face_parsing.model import BiSeNet, seg_mean, seg_std


class DiceLoss(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
        self.seg = BiSeNet(n_classes=16)
        self.seg.to('cuda')
        self.seg.load_state_dict(torch.load('pretrained_models/BiSeNet/seg.pth'))
        for param in self.seg.parameters():
            param.requires_grad = False
        self.seg.eval()
        self.downsample_512 = BicubicDownSample(factor=2)

    def calc_landmark(self, x):
        IM = (self.downsample_512(x) - seg_mean) / seg_std
        out, _, _ = self.seg(IM)
        return out

    def dice_loss(self, input, target):
        smooth = 1.

        iflat = input.view(input.size(0), -1)
        tflat = target.view(target.size(0), -1)
        intersection = (iflat * tflat).sum(dim=1)

        fn = torch.sum((tflat * (1 - iflat)) ** self.gamma, dim=1)
        fp = torch.sum(((1 - tflat) * iflat) ** self.gamma, dim=1)

        return 1 - ((2. * intersection + smooth) /
                    (iflat.sum(dim=1) + tflat.sum(dim=1) + fn + fp + smooth))

    def __call__(self, in_logit, tg_logit):
        probs1 = F.softmax(in_logit, dim=1)
        probs2 = F.softmax(tg_logit, dim=1)
        return self.dice_loss(probs1, probs2).mean()
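
# Note (added): dice_loss is a soft Dice on per-class probability maps,
# flattened per batch element, with extra false-negative / false-positive
# terms raised to the power gamma; __call__ softmaxes both logit maps first.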

from typing import Sequence
from itertools import chain

import torch
import torch.nn as nn
from torchvision import models


def get_network(net_type: str):
    if net_type == "alex":
        return AlexNet()
    elif net_type == "squeeze":
        return SqueezeNet()
    elif net_type == "vgg":
        return VGG16()
    else:
        raise NotImplementedError("choose net_type from [alex, squeeze, vgg].")


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__(
            [
                nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False))
                for nc in n_channels_list
            ]
        )

        for param in self.parameters():
            param.requires_grad = False


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # register buffer
        self.register_buffer(
            "mean", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "std", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(True).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)


from collections import OrderedDict


def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = "alex", version: str = "0.1"):
    # build url
    url = (
        "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/"
        + f"master/lpips/weights/v{version}/{net_type}.pth"
    )

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url,
        progress=True,
        map_location=None if torch.cuda.is_available() else torch.device("cpu"),
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace("lin", "")
        new_key = new_key.replace("model.", "")
        new_state_dict[new_key] = val

    return new_state_dict

class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        net_type (str): the network type to compare the features:
            'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """

    def __init__(self, net_type: str = "alex", version: str = "0.1"):

        assert version in ["0.1"], "only v0.1 is supported for now"

        super(LPIPS, self).__init__()

        # pretrained network
        self.net = get_network(net_type).to("cuda")

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]


class LPIPSLoss(LPIPS):
    pass
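
# Usage sketch (added; assumes a CUDA device, torchvision's pretrained weights,
# and the one-off download of the LPIPS linear weights; inputs in [-1, 1]):
def _demo_lpips():  # hypothetical helper, not used by the pipeline
    lpips = LPIPSLoss(net_type="alex")
    x = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1
    return lpips(x, x.clone())  # identical inputs -> distance of 0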

class LPIPSScaleLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = LPIPSLoss()

    def forward(self, x, y):
        out = 0
        for res in [256, 128, 64]:
            x_scale = F.interpolate(x, size=(res, res), mode="bilinear", align_corners=False)
            y_scale = F.interpolate(y, size=(res, res), mode="bilinear", align_corners=False)
            out += self.loss_fn.forward(x_scale, y_scale).mean()
        return out


class SyntMSELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.MSELoss()

    def forward(self, im1, im2):
        return self.loss_fn(im1, im2).mean()


class R1Loss:
    def __init__(self, coef=10.0):
        self.coef = coef

    def __call__(self, disc, real_images):
        real_images.requires_grad = True

        real_preds = disc(real_images, None)
        real_preds = real_preds.view(real_images.size(0), -1)
        real_preds = real_preds.mean(dim=1).unsqueeze(1)
        r1_loss = self.d_r1_loss(real_preds, real_images)

        loss_D_R1 = self.coef / 2 * r1_loss * 16 + 0 * real_preds[0]
        return {'disc r1 loss': loss_D_R1}

    def d_r1_loss(self, real_pred, real_img):
        (grad_real,) = torch.autograd.grad(
            outputs=real_pred.sum(), inputs=real_img, create_graph=True
        )
        grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

        return grad_penalty
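
# Note (added): d_r1_loss is the R1 gradient penalty E[||grad_x D(x)||^2], and
# loss_D_R1 scales it by coef / 2. The "* 16" looks like compensation for lazy
# regularization (applying the penalty once every 16 discriminator steps) and
# "+ 0 * real_preds[0]" keeps the prediction in the autograd graph; both are
# assumptions read off the surrounding code, not documented in this file.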

class DilatedMask:
    def __init__(self, kernel_size=5):
        self.kernel_size = kernel_size

        cords_x = torch.arange(0, kernel_size).view(1, -1).expand(kernel_size, -1) - kernel_size // 2
        cords_y = cords_x.clone().permute(1, 0)
        self.kernel = torch.as_tensor(
            (cords_x ** 2 + cords_y ** 2) <= (kernel_size // 2) ** 2, dtype=torch.float
        ).view(1, 1, kernel_size, kernel_size).cuda()
        self.kernel /= self.kernel.sum()

    def __call__(self, mask):
        smooth_mask = F.conv2d(mask, self.kernel, padding=self.kernel_size // 2)
        return smooth_mask ** 0.25
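
# Demo sketch (added, illustration only; requires CUDA since the kernel is
# created on the GPU): a single "on" pixel spreads into a soft disk of radius
# kernel_size // 2, and the ** 0.25 pulls the blurred values back toward 1.
def _demo_dilated_mask():  # hypothetical helper, not used by the pipeline
    dilate = DilatedMask(kernel_size=5)
    mask = torch.zeros(1, 1, 9, 9, device='cuda')
    mask[0, 0, 4, 4] = 1.0
    return dilate(mask)  # nonzero inside a radius-2 disk around the center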

class LossBuilder:
    def __init__(self, losses_dict, device='cuda'):
        self.losses_dict = losses_dict
        self.device = device

        self.EncoderAdvLoss = EncoderAdvLoss()
        self.AdvLoss = AdvLoss()
        self.R1Loss = R1Loss()
        self.FeatReconLoss = FeatReconLoss().to(device).eval()
        self.IDLoss = IDLoss().to(device).eval()
        self.LPIPS = LPIPSScaleLoss().to(device).eval()
        self.SyntMSELoss = SyntMSELoss().to(device).eval()
        self.downsample_256 = BicubicDownSample(factor=4)

    def CalcAdvLoss(self, disc, gen_F):
        fake_preds_F = disc(gen_F, None)

        return {'adv': self.losses_dict['adv'] * self.EncoderAdvLoss(fake_preds_F)}

    def CalcDisLoss(self, disc, real_images, generated_images):
        return self.AdvLoss(disc, real_images, generated_images)

    def CalcR1Loss(self, disc, real_images):
        return self.R1Loss(disc, real_images)

    def __call__(self, source, target, target_mask, HT_E, gen_w, F_w, gen_F, F_gen, **kwargs):
        losses = {}

        gen_w_256 = self.downsample_256(gen_w)
        gen_F_256 = self.downsample_256(gen_F)

        # ID loss
        losses['rec id'] = self.losses_dict['id'] * (
            self.IDLoss(normalize(source), gen_w_256)
            + self.IDLoss(normalize(source), gen_F_256)
        )

        # Feat Recons Loss
        losses['rec feat_rec'] = self.losses_dict['feat_rec'] * self.FeatReconLoss(F_w.detach(), F_gen)

        # LPIPS loss
        losses['rec lpips_scale'] = self.losses_dict['lpips_scale'] * (
            self.LPIPS(normalize(source), gen_w_256)
            + self.LPIPS(normalize(source), gen_F_256)
        )

        # Synt loss
        # losses['l2_synt'] = self.losses_dict['l2_synt'] * self.SyntMSELoss(target * HT_E, (gen_F_256 + 1) / 2 * HT_E)

        return losses
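
# Usage sketch (added; the coefficient values are illustrative, not the
# project's defaults, and building the losses loads the pretrained networks):
def _demo_loss_builder(source, target, target_mask, HT_E, gen_w, F_w, gen_F, F_gen):
    builder = LossBuilder({'adv': 0.1, 'id': 1.0, 'feat_rec': 1.0, 'lpips_scale': 1.0})
    losses = builder(source, target, target_mask, HT_E, gen_w, F_w, gen_F, F_gen)
    return sum(losses.values())  # total reconstruction objective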

class LossBuilderMulti(LossBuilder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.DiceLoss = DiceLoss().to(kwargs.get('device', 'cuda')).eval()
        self.dilated = DilatedMask(25)

    def __call__(self, source, target, target_mask, HT_E, gen_w, F_w, gen_F, F_gen, **kwargs):
        losses = {}

        gen_w_256 = self.downsample_256(gen_w)
        gen_F_256 = self.downsample_256(gen_F)

        # Dice loss
        with torch.no_grad():
            target_512 = F.interpolate(target, size=(512, 512), mode='bilinear').clip(0, 1)
            seg_target = self.DiceLoss.calc_landmark(target_512)
            seg_target = F.interpolate(seg_target, size=(256, 256), mode='nearest')
        # seg_gen stays outside no_grad so the dice term can backpropagate to gen_F
        seg_gen = F.interpolate(self.DiceLoss.calc_landmark((gen_F + 1) / 2), size=(256, 256), mode='nearest')

        losses['DiceLoss'] = self.losses_dict['landmark'] * self.DiceLoss(seg_gen, seg_target)

        # ID loss
        losses['id'] = self.losses_dict['id'] * (
            self.IDLoss(normalize(source) * target_mask, gen_w_256 * target_mask)
            + self.IDLoss(normalize(source) * target_mask, gen_F_256 * target_mask)
        )

        # Feat Recons loss
        losses['feat_rec'] = self.losses_dict['feat_rec'] * self.FeatReconLoss(F_w.detach(), F_gen)

        # LPIPS loss
        losses['lpips_face'] = 0.5 * self.losses_dict['lpips_scale'] * (
            self.LPIPS(normalize(source) * target_mask, gen_w_256 * target_mask)
            + self.LPIPS(normalize(source) * target_mask, gen_F_256 * target_mask)
        )
        losses['lpips_hair'] = 0.5 * self.losses_dict['lpips_scale'] * (
            self.LPIPS(normalize(target) * HT_E, gen_w_256 * HT_E)
            + self.LPIPS(normalize(target) * HT_E, gen_F_256 * HT_E)
        )

        # Inpaint loss
        if self.losses_dict['inpaint'] != 0.:
            M_Inp = (1 - target_mask) * (1 - HT_E)
            Smooth_M = self.dilated(M_Inp)
            losses['inpaint'] = 0.5 * self.losses_dict['inpaint'] * self.LPIPS(
                normalize(target) * Smooth_M, gen_F_256 * Smooth_M
            )
            losses['inpaint'] += 0.5 * self.losses_dict['inpaint'] * self.LPIPS(
                gen_w_256.detach() * Smooth_M * (1 - HT_E), gen_F_256 * Smooth_M * (1 - HT_E)
            )

        return losses
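
# Note (added): the value-range conventions above are easy to misread: target
# is kept in [0, 1] (hence the clip and the (gen_F + 1) / 2 rescaling before
# segmentation), while the ID / LPIPS terms compare tensors in [-1, 1]
# (normalize(...) for real images, raw generator output for gen_w / gen_F).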