HairFastGAN
725 lines · 20.1 KB
1import math2import random3import functools4import operator5
6import torch7from torch import nn8from torch.nn import functional as F9from torch.autograd import Function10
11from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d12import torchvision13toPIL = torchvision.transforms.ToPILImage()14import numpy as np15
class PixelNorm(nn.Module):
    """Normalize each feature vector along dim 1 to (approximately) unit length."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # Mean square over the channel dimension, stabilized by a small epsilon.
        mean_sq = torch.mean(input ** 2, dim=1, keepdim=True)
        return input * torch.rsqrt(mean_sq + 1e-8)
23
def make_kernel(k):
    """Build a normalized 2D FIR filter kernel.

    A 1D input is expanded to its 2D outer product; the result is scaled so
    that its entries sum to one.
    """
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = kernel.unsqueeze(0) * kernel.unsqueeze(1)
    return kernel / kernel.sum()
34
class Upsample(nn.Module):
    """Upsample by an integer factor, then smooth with a FIR blur kernel."""

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        # Scale by factor**2 so overall signal magnitude is preserved after
        # zero-insertion upsampling.
        blur = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', blur)
        overlap = blur.shape[0] - factor
        self.pad = ((overlap + 1) // 2 + factor - 1, overlap // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
55
class Downsample(nn.Module):
    """Blur with a FIR kernel, then downsample by an integer factor."""

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        blur = make_kernel(kernel)
        self.register_buffer('kernel', blur)
        overlap = blur.shape[0] - factor
        self.pad = ((overlap + 1) // 2, overlap // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
76
class Blur(nn.Module):
    """Apply a fixed FIR blur; compensates filter gain when placed after an
    upsampling step."""

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        fir = make_kernel(kernel)
        if upsample_factor > 1:
            # Restore the magnitude lost by zero-insertion upsampling.
            fir = fir * (upsample_factor ** 2)
        self.register_buffer('kernel', fir)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)
95
class EqualConv2d(nn.Module):
    """Conv2d with equalized learning rate (StyleGAN2): weights are stored as
    N(0, 1) and rescaled at runtime by 1/sqrt(fan_in)."""

    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        # He-style constant applied on the fly instead of baked into the init.
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        scaled_weight = self.weight * self.scale
        return F.conv2d(
            input,
            scaled_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )
133
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate (StyleGAN2).

    Weights are stored as N(0, 1/lr_mul) and rescaled in the forward pass by
    (1/sqrt(in_dim)) * lr_mul; the bias is scaled by lr_mul. Optionally applies
    a fused leaky-ReLU activation.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        bias: whether to learn an additive bias.
        bias_init: constant the bias is initialized to.
        lr_mul: learning-rate multiplier for this layer.
        activation: truthy (e.g. 'fused_lrelu') to apply fused_leaky_relu,
            or None for a plain linear map.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # Fix: the original multiplied self.bias by lr_mul unconditionally,
        # which raised TypeError whenever the layer was built with bias=False.
        bias = self.bias * self.lr_mul if self.bias is not None else None

        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)

        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
170
class ScaledLeakyReLU(nn.Module):
    """Leaky ReLU followed by a sqrt(2) gain to preserve output variance."""

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        activated = F.leaky_relu(input, negative_slope=self.negative_slope)
        return math.sqrt(2) * activated
182
class ModulatedConv2d(nn.Module):
    # Style-modulated convolution (StyleGAN2): the shared weight is scaled
    # per-sample by a style-derived vector, optionally demodulated, and applied
    # via a grouped convolution so each sample uses its own kernel. Supports
    # optional 2x up/downsampling with a FIR blur.
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            # Padding chosen so output resolution is exactly factor * input.
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        # Equalized-learning-rate scale applied to the weight at runtime.
        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        # Maps a style vector to per-input-channel scales; bias_init=1 keeps
        # the modulation near identity at the start of training.
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        # Modulate: scale each input channel of the shared weight per sample.
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            # Demodulate: renormalize each output filter to unit L2 norm.
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        # Fold the batch into the output-channel dim for the grouped conv.
        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            # conv_transpose2d expects (in, out, kH, kW) per group, hence the
            # transpose before the reshape.
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            # Anti-alias first, then strided grouped conv.
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            # Same-resolution grouped conv: one weight group per sample.
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out
281
class NoiseInjection(nn.Module):
    """Add per-pixel noise scaled by a single learned weight (initialized 0)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            # Fresh standard-normal noise, one channel broadcast over all.
            b, _, h, w = image.shape
            noise = torch.randn(b, 1, h, w, device=image.device, dtype=image.dtype)
        return image + self.weight * noise
295
class ConstantInput(nn.Module):
    """Learned constant tensor, broadcast to the batch size of its input."""

    def __init__(self, channel, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        # Only the batch size of `input` is used; its values are ignored.
        n = input.shape[0]
        return self.input.repeat(n, 1, 1, 1)
308
class StyledConv(nn.Module):
    """ModulatedConv2d followed by noise injection and a fused leaky ReLU."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        # FusedLeakyReLU applies its own learned bias before the activation.
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        features = self.conv(input, style)
        noisy = self.noise(features, noise=noise)
        return self.activate(noisy)
345
class ToRGB(nn.Module):
    """Project features to a 3-channel RGB image and optionally add an
    upsampled skip connection from the previous resolution."""

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        # 1x1 modulated conv without demodulation, plus a learned bias.
        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        rgb = self.conv(input, style) + self.bias
        if skip is not None:
            rgb = rgb + self.upsample(skip)
        return rgb
367
class Generator(nn.Module):
    # StyleGAN2 synthesis network, extended (HairFastGAN) with start_layer /
    # end_layer / layer_in arguments so intermediate feature maps ("F space")
    # can be extracted and re-injected mid-network.
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        # Mapping network: PixelNorm followed by n_mlp equalized FC layers.
        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
                )
            )

        self.style = nn.Sequential(*layers)

        # Feature-map widths per output resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # One fixed noise buffer per styled conv, at that layer's resolution.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))

        # Two StyledConvs (one upsampling) plus one ToRGB per resolution step.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        # Fresh per-layer noise tensors matching each layer's resolution.
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        # Average of n_latent mapped W vectors; used for the truncation trick.
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        # Map a Z-space code through the mapping network to W space.
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
        layer_in=None,
        skip=None,
        start_layer=0,
        end_layer=8,
        return_rgb=False,
    ):
        # Map Z -> W unless the caller already supplies W-space latents.
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]

        # Truncation trick: pull each style toward the mean latent.
        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        # Build the per-layer latent tensor (style mixing when two styles given).
        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)
        out = self.input(latent)

        if start_layer == 0:
            out = self.conv1(out, latent[:, 0], noise=noise[0])  # 0th layer
            skip = self.to_rgb1(out, latent[:, 1])
            if end_layer == 0:
                return out, skip
        i = 1
        current_layer = 1
        # Walk the remaining layers in (upsampling conv, conv, to_rgb) triples;
        # the start_layer/end_layer window runs only part of the network, and
        # layer_in seeds the first executed layer with external features.
        for conv1, conv2, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            if current_layer < start_layer:
                pass
            elif current_layer == start_layer:
                # Resume from the injected intermediate features (layer_in).
                out = conv1(layer_in, latent[:, i], noise=noise1)
                out = conv2(out, latent[:, i + 1], noise=noise2)
                skip = to_rgb(out, latent[:, i + 2], skip)
            elif current_layer > end_layer:
                # Early exit: return features and the partial RGB skip.
                return out, skip
            else:
                out = conv1(out, latent[:, i], noise=noise1)
                out = conv2(out, latent[:, i + 1], noise=noise2)
                skip = to_rgb(out, latent[:, i + 2], skip)
            current_layer += 1
            i += 2

        image = skip

        if return_latents:
            return image, latent

        else:
            return image, None

    def generate_im_from_w_space(self, code, noises=None):
        # Decode a full image from a numpy W+ code; returns a numpy image
        # (via PIL) with values mapped from [-1, 1] to [0, 1] and clamped.
        latent = torch.from_numpy(code).cuda()
        I_G, _ = self([latent], input_is_latent=True, return_latents=False, noise=noises, start_layer=0,
                      end_layer=8)
        I_G_0_1 = (I_G + 1) / 2
        im = np.array(toPIL(I_G_0_1[0].cpu().detach().clamp(0, 1)))
        return im

    def generate_initial_intermediate(self, code, noises=None):
        # Run only layers 0..3 to obtain the intermediate feature map (F space).
        latent = torch.from_numpy(code).cuda()
        intermediate, _ = self([latent], input_is_latent=True, return_latents=False, noise=noises,
                               start_layer=0, end_layer=3)
        return intermediate

    def update_on_FS(self, code, initial_intermediate, initial_F, initial_S, noises=None):
        # Re-run layers 0..3 for the new code, shift the result by the stored
        # F-space offset, then finish layers 4..8 with the stored S latent.
        latent = torch.from_numpy(code).cuda()
        intermediate, _ = self([latent], input_is_latent=True, return_latents=False, noise=noises,
                               start_layer=0, end_layer=3)

        difference = initial_F - initial_intermediate
        new_intermediate = intermediate + difference

        I_G, _ = self([initial_S], input_is_latent=True, return_latents=False, noise=noises, start_layer=4,
                      end_layer=8, layer_in=new_intermediate)
        I_G_0_1 = (I_G + 1) / 2
        im = np.array(toPIL(I_G_0_1[0].cpu().detach().clamp(0, 1)))
        return im
596
class ConvLayer(nn.Sequential):
    """Conv block: optional FIR blur + (strided) EqualConv2d + optional
    leaky-ReLU activation (fused when a bias is learned)."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            # Anti-alias before the stride-2 convolution.
            p = (len(blur_kernel) - 2) + (kernel_size - 1)
            layers.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))
            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2

        # When a fused activation follows, the bias is applied there instead.
        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            layers.append(FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2))

        super().__init__(*layers)
645
class ResBlock(nn.Module):
    """Residual block with a downsampling main path and a 1x1 downsampling
    skip path; the sum is scaled by 1/sqrt(2) to keep variance stable."""

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        main = self.conv2(self.conv1(input))
        shortcut = self.skip(input)
        return (main + shortcut) / math.sqrt(2)
666
class Discriminator(nn.Module):
    # StyleGAN2 discriminator: 1x1 stem, residual downsampling blocks to 4x4,
    # minibatch standard deviation, then a final conv + two FC layers.
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        # Feature-map widths per resolution (mirrors the generator's table).
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        # One ResBlock per halving of resolution, down to 4x4.
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        # Minibatch-stddev settings: group size and feature split.
        self.stddev_group = 4
        self.stddev_feat = 1

        # +1 input channel for the appended stddev map.
        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        # Minibatch standard deviation: compute per-group stddev across the
        # batch, average it to a single value per group, and append it as an
        # extra constant feature map.
        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        # Flatten the 4x4 map and score each sample with the FC head.
        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out
727