HairFastGAN

import math
import random
import functools
import operator

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function

from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
import torchvision
toPIL = torchvision.transforms.ToPILImage()
import numpy as np

class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
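
# Illustrative usage sketch (not part of the original file): PixelNorm rescales
# each feature vector to unit RMS across channels, as applied before the
# mapping network. Kept as a comment so importing this module has no side effects.
#
#     x = torch.randn(4, 512)
#     y = PixelNorm()(x)
#     # (y ** 2).mean(dim=1) is ~1 for every row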


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k
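
# Sketch (illustrative): make_kernel turns a 1D tap list into a normalized 2D
# separable filter via an outer product.
#
#     k = make_kernel([1, 3, 3, 1])
#     # k.shape == (4, 4); entries sum to 1 (separable binomial blur)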


class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

        return out


class Downsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)

        return out
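
# Sketch (illustrative): Upsample/Downsample wrap the fused upfirdn2d op
# (requires the compiled models.stylegan2.op extension) to resample by the
# given factor through the anti-aliasing blur kernel.
#
#     x = torch.randn(1, 3, 64, 64)
#     Upsample([1, 3, 3, 1])(x).shape    # -> (1, 3, 128, 128)
#     Downsample([1, 3, 3, 1])(x).shape  # -> (1, 3, 32, 32)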


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer('kernel', kernel)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out


class EqualConv2d(nn.Module):
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )
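
# Sketch (illustrative): EqualConv2d stores unit-variance weights and applies
# the equalized-learning-rate scale 1/sqrt(fan_in) at runtime, so layers of
# different sizes see comparable effective gradient magnitudes.
#
#     conv = EqualConv2d(3, 64, kernel_size=3, padding=1)
#     conv(torch.randn(1, 3, 32, 32)).shape  # -> (1, 64, 32, 32)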


class EqualLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)

        else:
            # self.bias may be None when built with bias=False; F.linear
            # accepts bias=None, so guard the rescaling.
            bias = self.bias * self.lr_mul if self.bias is not None else None
            out = F.linear(input, self.weight * self.scale, bias=bias)

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
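
# Sketch (illustrative): with lr_mul < 1 the stored weights are enlarged and
# the runtime scale shrunk, which lowers the effective learning rate of the
# mapping network; activation='fused_lrelu' applies the fused bias + LeakyReLU.
#
#     fc = EqualLinear(512, 512, lr_mul=0.01, activation='fused_lrelu')
#     w = fc(torch.randn(8, 512))  # one mapping-network layer, z -> w space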


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        out = F.leaky_relu(input, negative_slope=self.negative_slope)

        return out * math.sqrt(2)


class ModulatedConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out
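
# Sketch (illustrative): each sample's style vector scales the input channels
# of the shared weight (modulation); demodulation then renormalizes so output
# feature maps keep roughly unit variance. The batch is folded into a grouped
# convolution so every sample gets its own kernel.
#
#     conv = ModulatedConv2d(512, 512, 3, style_dim=512)
#     x, s = torch.randn(2, 512, 8, 8), torch.randn(2, 512)
#     conv(x, s).shape  # -> (2, 512, 8, 8)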


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()

        return image + self.weight * noise
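
# Sketch (illustrative): per-pixel noise with a learned per-layer strength;
# pass an explicit tensor to make generation deterministic.
#
#     ni = NoiseInjection()
#     x = torch.randn(2, 512, 8, 8)
#     ni(x, noise=torch.zeros(2, 1, 8, 8))  # equals x while weight is zero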


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out


class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        # self.activate = ScaledLeakyReLU(0.2)
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        # out = out + self.bias
        out = self.activate(out)

        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)

            out = out + skip

        return out
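
# Sketch (illustrative): ToRGB projects features to 3 channels without
# demodulation and accumulates the progressively upsampled RGB skip branch.
# The names below are hypothetical placeholders.
#
#     to_rgb = ToRGB(512, style_dim=512)
#     # feat_16: (B, 512, 16, 16) features; rgb_8: (B, 3, 8, 8) running skip
#     rgb_16 = to_rgb(feat_16, style, skip=rgb_8)  # -> (B, 3, 16, 16)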


class Generator(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
                )
            )

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
            self,
            styles,
            return_latents=False,
            inject_index=None,
            truncation=1,
            truncation_latent=None,
            input_is_latent=False,
            noise=None,
            randomize_noise=True,
            layer_in=None,
            skip=None,
            start_layer=0,
            end_layer=8,
            return_rgb=False,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)
        out = self.input(latent)

        if start_layer == 0:
            out = self.conv1(out, latent[:, 0], noise=noise[0])  # 0th layer
            skip = self.to_rgb1(out, latent[:, 1])
        if end_layer == 0:
            return out, skip
        i = 1
        current_layer = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            if current_layer < start_layer:
                pass
            elif current_layer == start_layer:
                out = conv1(layer_in, latent[:, i], noise=noise1)
                out = conv2(out, latent[:, i + 1], noise=noise2)
                skip = to_rgb(out, latent[:, i + 2], skip)
            elif current_layer > end_layer:
                return out, skip
            else:
                out = conv1(out, latent[:, i], noise=noise1)
                out = conv2(out, latent[:, i + 1], noise=noise2)
                skip = to_rgb(out, latent[:, i + 2], skip)
            current_layer += 1
            i += 2

        image = skip

        if return_latents:
            return image, latent

        else:
            return image, None

    def generate_im_from_w_space(self, code, noises=None):
        latent = torch.from_numpy(code).cuda()
        I_G, _ = self([latent], input_is_latent=True, return_latents=False, noise=noises, start_layer=0,
                           end_layer=8)
        I_G_0_1 = (I_G + 1) / 2
        im = np.array(toPIL(I_G_0_1[0].cpu().detach().clamp(0, 1)))
        return im

    def generate_initial_intermediate(self, code, noises=None):
        latent = torch.from_numpy(code).cuda()
        intermediate, _ = self([latent], input_is_latent=True, return_latents=False, noise=noises,
                                            start_layer=0, end_layer=3)
        return intermediate


    def update_on_FS(self, code, initial_intermediate, initial_F, initial_S, noises=None):
        latent = torch.from_numpy(code).cuda()
        intermediate, _ = self([latent], input_is_latent=True, return_latents=False, noise=noises,
                               start_layer=0, end_layer=3)

        difference = initial_F - initial_intermediate
        new_intermediate = intermediate + difference

        I_G, _ = self([initial_S], input_is_latent=True, return_latents=False, noise=noises, start_layer=4,
                      end_layer=8, layer_in=new_intermediate)
        I_G_0_1 = (I_G + 1) / 2
        im = np.array(toPIL(I_G_0_1[0].cpu().detach().clamp(0, 1)))
        return im
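
# Sketch (illustrative): end-to-end synthesis from z. start_layer/end_layer run
# only a slice of the synthesis network, which is how the FS-space helpers
# above rebuild images from an edited feature tensor F and style codes S.
# Assumes the compiled CUDA ops are available.
#
#     g = Generator(size=1024, style_dim=512, n_mlp=8)
#     z = torch.randn(4, 512)
#     img, latent = g([z], return_latents=True)
#     img.shape     # -> (4, 3, 1024, 1024), values roughly in [-1, 1]
#     latent.shape  # -> (4, 18, 512); n_latent = 2 * log2(size) - 2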


class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))

            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out
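
# Sketch (illustrative): a downsampling residual block; dividing the sum by
# sqrt(2) keeps the output variance comparable to a single branch.
#
#     block = ResBlock(64, 128)
#     block(torch.randn(1, 64, 32, 32)).shape  # -> (1, 128, 16, 16)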


class Discriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out
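
# Sketch (illustrative): the discriminator maps an image to one logit per
# sample; the minibatch-stddev feature (groups of up to 4 samples) appends a
# map of cross-sample variation before the final layers.
#
#     d = Discriminator(size=1024)
#     d(torch.randn(4, 3, 1024, 1024)).shape  # -> (4, 1)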