lama

Форк
0
669 строк · 27.3 Кб
1
# original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py
2
import collections
3
from functools import partial
4
import functools
5
import logging
6
from collections import defaultdict
7

8
import numpy as np
9
import torch.nn as nn
10

11
from saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
12
from saicinpainting.training.modules.ffc import FFCResnetBlock
13
from saicinpainting.training.modules.multidilated_conv import MultidilatedConv
14

15
class DotDict(defaultdict):
    """A ``defaultdict`` whose keys can also be read, set and deleted as attributes.

    Attribute reads use ``dict.get`` semantics: a missing key yields ``None`` and is
    not inserted (``default_factory`` is only consulted by ``d[key]``). Attribute
    deletion of a missing key raises ``KeyError``.

    Adapted from
    https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
    """

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails (e.g. not a dict method).
        return defaultdict.get(self, key)

    def __setattr__(self, key, value):
        defaultdict.__setitem__(self, key, value)

    def __delattr__(self, key):
        defaultdict.__delitem__(self, key)
21

22
class Identity(nn.Module):
    """Parameter-free module whose ``forward`` hands back its input object unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        """Return ``x`` itself (no copy)."""
        return x
28

29

30
class ResnetBlock(nn.Module):
    """Residual block of two 3x3 convolutions (pix2pixHD style).

    ``forward`` returns ``shortcut(x) + conv_block(x)``, where ``conv_block`` is::

        [pad] conv(in_dim -> dim, dilation) -> norm -> activation [-> Dropout(0.5)]
        [pad] conv(dim -> dim, second_dilation, groups) -> norm

    and ``shortcut`` is the identity, or a 1x1 ``nn.Conv2d(in_dim, dim)`` when
    ``in_dim`` is given.

    Args:
        dim: output channels of the block (and input channels of the second conv).
        padding_type: 'reflect', 'replicate' or 'zero'; any other value raises
            ``NotImplementedError``. The pad amount equals the conv's dilation.
        norm_layer: callable mapping a channel count to a normalization module.
        activation: module placed after the first norm. The default instance is
            created once at definition time and shared by every block using it.
        use_dropout: insert ``nn.Dropout(0.5)`` after the activation.
        conv_kind: key passed to ``get_conv_block_ctor`` to choose the conv class.
        dilation: dilation of the first conv.
        in_dim: input channels when they differ from ``dim``.
        groups: ``groups`` of the second conv only.
        second_dilation: dilation of the second conv; defaults to ``dilation``.
    """

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
                 dilation=1, in_dim=None, groups=1, second_dilation=None):
        super(ResnetBlock, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        if second_dilation is None:
            second_dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
                                                second_dilation=second_dilation)

        if self.in_dim is not None:
            # 1x1 projection so the shortcut has ``dim`` channels like the conv branch.
            self.input_conv = nn.Conv2d(in_dim, dim, 1)

        self.out_channels = dim
        # Historical misspelling, kept so existing code reading it keeps working.
        self.out_channnels = dim

    @staticmethod
    def _padding_for(padding_type, amount):
        """Return ``(pad_layers, conv_padding)`` for a 3x3 conv with dilation ``amount``.

        Explicit pad layers are used for 'reflect'/'replicate'; 'zero' uses the
        conv's own ``padding`` argument instead.
        """
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(amount)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(amount)], 0
        if padding_type == 'zero':
            return [], amount
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
                         dilation=1, in_dim=None, groups=1, second_dilation=1):
        """Build the residual branch as an ``nn.Sequential`` (layer order is checkpoint-relevant)."""
        conv_layer = get_conv_block_ctor(conv_kind)
        if in_dim is None:
            in_dim = dim

        pad_layers, p = self._padding_for(padding_type, dilation)
        conv_block = pad_layers + [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation),
                                   norm_layer(dim),
                                   activation]
        if use_dropout:
            conv_block.append(nn.Dropout(0.5))

        pad_layers, p = self._padding_for(padding_type, second_dilation)
        conv_block += pad_layers + [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation,
                                               groups=groups),
                                    norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        x_before = x
        if self.in_dim is not None:
            x = self.input_conv(x)
        # The conv branch always sees the un-projected input.
        out = x + self.conv_block(x_before)
        return out
91

92
class ResnetBlock5x5(nn.Module):
    """Residual block of two 5x5 convolutions; same layout as ``ResnetBlock``.

    ``forward`` returns ``shortcut(x) + conv_block(x)``, where ``conv_block`` is::

        [pad] conv5x5(in_dim -> dim, dilation) -> norm -> activation [-> Dropout(0.5)]
        [pad] conv5x5(dim -> dim, second_dilation, groups) -> norm

    The pad amount for each conv is ``2 * dilation``. ``shortcut`` is the identity,
    or a 1x1 ``nn.Conv2d(in_dim, dim)`` when ``in_dim`` is given.

    Args:
        dim: output channels of the block.
        padding_type: 'reflect', 'replicate' or 'zero'; anything else raises
            ``NotImplementedError``.
        norm_layer: callable mapping a channel count to a normalization module.
        activation: module after the first norm; the default instance is created
            once at definition time and shared.
        use_dropout: insert ``nn.Dropout(0.5)`` after the activation.
        conv_kind: key passed to ``get_conv_block_ctor``.
        dilation: dilation of the first conv.
        in_dim: input channels when they differ from ``dim``.
        groups: ``groups`` of the second conv only.
        second_dilation: dilation of the second conv; defaults to ``dilation``.
    """

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
                 dilation=1, in_dim=None, groups=1, second_dilation=None):
        super(ResnetBlock5x5, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        if second_dilation is None:
            second_dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
                                                second_dilation=second_dilation)

        if self.in_dim is not None:
            # 1x1 projection so the shortcut has ``dim`` channels like the conv branch.
            self.input_conv = nn.Conv2d(in_dim, dim, 1)

        self.out_channels = dim
        # Historical misspelling, kept so existing code reading it keeps working.
        self.out_channnels = dim

    @staticmethod
    def _padding_for(padding_type, amount):
        """Return ``(pad_layers, conv_padding)`` padding each side by ``amount``.

        Explicit pad layers are used for 'reflect'/'replicate'; 'zero' uses the
        conv's own ``padding`` argument instead.
        """
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(amount)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(amount)], 0
        if padding_type == 'zero':
            return [], amount
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
                         dilation=1, in_dim=None, groups=1, second_dilation=1):
        """Build the residual branch as an ``nn.Sequential`` (layer order is checkpoint-relevant)."""
        conv_layer = get_conv_block_ctor(conv_kind)
        if in_dim is None:
            in_dim = dim

        # A 5x5 kernel with dilation d spans 4*d extra pixels, hence 2*d per side.
        pad_layers, p = self._padding_for(padding_type, dilation * 2)
        conv_block = pad_layers + [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
                                   norm_layer(dim),
                                   activation]
        if use_dropout:
            conv_block.append(nn.Dropout(0.5))

        pad_layers, p = self._padding_for(padding_type, second_dilation * 2)
        conv_block += pad_layers + [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation,
                                               groups=groups),
                                    norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        x_before = x
        if self.in_dim is not None:
            x = self.input_conv(x)
        # The conv branch always sees the un-projected input.
        out = x + self.conv_block(x_before)
        return out
153

154

155
class MultidilatedResnetBlock(nn.Module):
    """Residual block whose two convolutions come from a caller-supplied ``conv_layer``.

    Branch layout: conv -> norm -> activation [-> Dropout(0.5)] -> conv -> norm.
    Each conv is created as ``conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type)``.
    ``forward`` adds the branch output to the input unchanged, so ``conv_layer`` must
    produce outputs shape-compatible with ``x``.

    The default ``activation`` instance is created once at definition time and shared.
    """

    def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super().__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
        """Return the residual branch as ``nn.Sequential``. ``dilation`` is accepted but not used."""
        def make_conv():
            return conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type)

        layers = [make_conv(), norm_layer(dim), activation]
        layers.extend([nn.Dropout(0.5)] if use_dropout else [])
        layers.extend((make_conv(), norm_layer(dim)))
        return nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
176

177

178
class MultiDilatedGlobalGenerator(nn.Module):
    """Encoder / multi-dilated residual bottleneck / decoder generator.

    ``self.model`` is an ``nn.Sequential`` of, in order:
      * reflection pad 3 + 7x7 conv (``conv_kind``) to ``ngf`` channels, norm, activation;
      * ``n_downsampling`` stride-2 3x3 convs, each doubling the channel count
        (capped at ``max_features``), with norm and activation;
      * ``n_blocks`` ``MultidilatedResnetBlock``s built with the ``'multidilated'`` conv
        constructor and ``multidilation_kwargs``; for every block index contained in
        ``ffc_positions`` an ``FFCResnetBlock`` (with ``ffc_kwargs``) is inserted just
        before that block;
      * ``n_downsampling`` upsampling stages from ``deconv_factory(deconv_kind, ...)``;
      * reflection pad 3 + 7x7 ``nn.Conv2d`` to ``output_nc``;
      * if ``add_out_act`` is truthy, ``get_activation('tanh')`` when it is exactly
        ``True``, otherwise ``get_activation(add_out_act)``.

    ``affine``, when not None, is forwarded to both ``norm_layer`` and ``up_norm_layer``.
    The default ``activation`` / ``up_activation`` module instances are created once at
    definition time and reused at every position where they appear.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default',
                 deconv_kind='convtranspose', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
                 add_out_act=True, max_features=1024, multidilation_kwargs=None,
                 ffc_positions=None, ffc_kwargs=None):
        assert (n_blocks >= 0)
        super().__init__()
        # ``None`` instead of a shared mutable ``{}`` default; both are only unpacked below.
        multidilation_kwargs = {} if multidilation_kwargs is None else multidilation_kwargs
        ffc_kwargs = {} if ffc_kwargs is None else ffc_kwargs

        conv_layer = get_conv_block_ctor(conv_kind)
        resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        # stem
        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        # downsample: ngf * 2**i -> ngf * 2**(i + 1), both capped at max_features
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        feats_num_bottleneck = min(max_features, ngf * 2 ** n_downsampling)

        # resnet blocks (module order determines state_dict keys)
        for i in range(n_blocks):
            if ffc_positions is not None and i in ffc_positions:
                model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
                                         inline=True, **ffc_kwargs)]
            model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                              conv_layer=resnet_conv_layer, activation=activation,
                                              norm_layer=norm_layer)]

        # upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
238

239
class ConfigGlobalGenerator(nn.Module):
    """Generator whose residual bottleneck is assembled from a list of block specs.

    ``self.model`` is an ``nn.Sequential`` of: reflection pad 3 + 7x7 stem conv, norm,
    activation; ``n_downsampling`` stride-2 convs (channels doubling, capped at
    ``max_features``); the residual blocks described by ``manual_block_spec``;
    ``n_downsampling`` upsampling stages from ``deconv_factory``; reflection pad 3 +
    7x7 ``nn.Conv2d`` to ``output_nc``; and an optional output activation.

    Each ``manual_block_spec`` entry is read through ``DotDict`` (missing keys read as
    ``None``):
        n_blocks: number of blocks this entry adds.
        use_default: when truthy, use the constructor's ``resnet_block_kind``,
            ``resnet_conv_kind``, ``resnet_dilation`` and ``multidilation_kwargs``.
        resnet_block_kind, resnet_conv_kind, multidilation_kwargs, resnet_dilation:
            per-entry settings used when ``use_default`` is falsy. A missing
            ``multidilation_kwargs`` means no extra kwargs; a missing
            ``resnet_dilation`` falls back to the constructor's ``resnet_dilation``.
    An empty (or ``None``) ``manual_block_spec`` means ``n_blocks`` default blocks.
    Supported block kinds: 'multidilatedresnetblock', 'resnetblock', 'resnetblock5x5',
    'resnetblockdwdil'; any other kind raises ``ValueError``.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default',
                 deconv_kind='convtranspose', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
                 add_out_act=True, max_features=1024,
                 manual_block_spec=None,
                 resnet_block_kind='multidilatedresnetblock',
                 resnet_conv_kind='multidilated',
                 resnet_dilation=1,
                 multidilation_kwargs=None):
        assert (n_blocks >= 0)
        super().__init__()
        # ``None`` instead of a shared mutable ``{}`` default; only unpacked below.
        multidilation_kwargs = {} if multidilation_kwargs is None else multidilation_kwargs

        conv_layer = get_conv_block_ctor(conv_kind)
        resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        # stem
        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        # downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        if not manual_block_spec:
            manual_block_spec = [
                DotDict(lambda: None, {
                    'n_blocks': n_blocks,
                    'use_default': True})
            ]

        # resnet blocks
        for block_spec in manual_block_spec:
            block_spec = DotDict(lambda: None, block_spec)
            # Resolve the settings into fresh names. Re-binding the constructor's
            # ``resnet_*`` names inside a nested function would make them local to it,
            # leaving them unbound for ``use_default`` specs (UnboundLocalError).
            if block_spec.use_default:
                spec_conv_layer = resnet_conv_layer
                spec_conv_kind = resnet_conv_kind
                spec_block_kind = resnet_block_kind
                spec_dilation = resnet_dilation
            else:
                spec_conv_kind = block_spec.resnet_conv_kind
                spec_conv_layer = functools.partial(get_conv_block_ctor(spec_conv_kind),
                                                    **(block_spec.multidilation_kwargs or {}))
                spec_block_kind = block_spec.resnet_block_kind
                spec_dilation = (resnet_dilation if block_spec.resnet_dilation is None
                                 else block_spec.resnet_dilation)

            for _ in range(block_spec.n_blocks):
                # Every block uses the capped bottleneck width so it matches the
                # channels produced by the last downsampling layer.
                if spec_block_kind == "multidilatedresnetblock":
                    block = MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                                    conv_layer=spec_conv_layer, activation=activation,
                                                    norm_layer=norm_layer)
                elif spec_block_kind == "resnetblock":
                    block = ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                        norm_layer=norm_layer, conv_kind=spec_conv_kind)
                elif spec_block_kind == "resnetblock5x5":
                    block = ResnetBlock5x5(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                           norm_layer=norm_layer, conv_kind=spec_conv_kind)
                elif spec_block_kind == "resnetblockdwdil":
                    block = ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                        norm_layer=norm_layer, conv_kind=spec_conv_kind,
                                        dilation=spec_dilation, second_dilation=spec_dilation)
                else:
                    raise ValueError(f'Unknown resnet_block_kind "{spec_block_kind}"')
                model.append(block)

        # upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
327

328

329
def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
    """Build ``dilated_blocks_n`` residual blocks of the requested kind.

    Args:
        dilated_blocks_n: how many blocks to create (non-positive gives ``[]``).
        dilation_block_kind: 'simple' -> ``ResnetBlock`` whose dilation grows as
            2, 4, 8, ... with the block index; 'multi' -> ``MultidilatedResnetBlock``.
        dilated_block_kwargs: keyword arguments forwarded to every block.

    Returns:
        A list of the created modules.

    Raises:
        ValueError: for an unknown kind, but only when at least one block is requested.
    """
    def build(index):
        if dilation_block_kind == 'simple':
            return ResnetBlock(**dilated_block_kwargs, dilation=2 ** (index + 1))
        if dilation_block_kind == 'multi':
            return MultidilatedResnetBlock(**dilated_block_kwargs)
        raise ValueError(f'dilation_block_kind could not be "{dilation_block_kind}"')

    return [build(index) for index in range(dilated_blocks_n)]
339

340

341
class GlobalGenerator(nn.Module):
    """pix2pixHD-style encoder / residual bottleneck / decoder generator.

    ``self.model`` is an ``nn.Sequential`` of, in order:
      * reflection pad 3 + 7x7 conv (``conv_kind``) to ``ngf`` channels, norm, activation;
      * ``n_downsampling`` stride-2 3x3 convs (channels doubling, capped at ``max_features``);
      * ``dilated_blocks_n_start`` dilated blocks (see ``make_dil_blocks``);
      * ``n_blocks`` ``ResnetBlock``s; ``dilated_blocks_n_middle`` dilated blocks are
        inserted before block ``n_blocks // 2``, and ``FFCResnetBlock``s are inserted
        before block ``i`` once per occurrence of ``i`` in ``ffc_positions``;
      * ``dilated_blocks_n`` dilated blocks;
      * ``n_downsampling`` ``nn.ConvTranspose2d`` stride-2 upsampling stages with
        ``up_norm_layer`` / ``up_activation``;
      * reflection pad 3 + 7x7 ``nn.Conv2d`` to ``output_nc``, then an optional output
        activation (``'tanh'`` when ``add_out_act is True``).

    ``is_resblock_depthwise`` sets ``groups`` of each ResnetBlock's second conv to the
    bottleneck width. ``dilation`` / ``second_dilation`` are passed to every ResnetBlock.
    ``affine``, when not None, is forwarded to both norm layers. The default
    ``activation`` / ``up_activation`` instances are created once and reused everywhere.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None,
                 up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
                 dilated_blocks_n_middle=0,
                 add_out_act=True,
                 max_features=1024, is_resblock_depthwise=False,
                 ffc_positions=None, ffc_kwargs=None, dilation=1, second_dilation=None,
                 dilation_block_kind='simple', multidilation_kwargs=None):
        assert (n_blocks >= 0)
        super().__init__()
        # ``None`` instead of shared mutable ``{}`` defaults; both are only unpacked below.
        ffc_kwargs = {} if ffc_kwargs is None else ffc_kwargs
        multidilation_kwargs = {} if multidilation_kwargs is None else multidilation_kwargs

        conv_layer = get_conv_block_ctor(conv_kind)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        if ffc_positions is not None:
            # Counter so the same position may request several FFC blocks.
            ffc_positions = collections.Counter(ffc_positions)

        # stem
        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        # downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
                                    activation=activation, norm_layer=norm_layer)
        if dilation_block_kind == 'simple':
            dilated_block_kwargs['conv_kind'] = conv_kind
        elif dilation_block_kind == 'multi':
            dilated_block_kwargs['conv_layer'] = functools.partial(
                get_conv_block_ctor('multidilated'), **multidilation_kwargs)

        # dilated blocks at the start of the bottleneck sausage
        if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
            model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)

        resblock_groups = feats_num_bottleneck if is_resblock_depthwise else 1

        # resnet blocks
        # NOTE(review): with n_blocks == 0 this loop never runs, so the "middle" dilated
        # blocks are silently omitted. Kept as-is because changing the module list would
        # break existing checkpoints for such configs.
        for i in range(n_blocks):
            # dilated blocks at the middle of the bottleneck sausage
            if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
                model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)

            if ffc_positions is not None and i in ffc_positions:
                for _ in range(ffc_positions[i]):  # same position can occur more than once
                    model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
                                             inline=True, **ffc_kwargs)]

            model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                  norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
                                  dilation=dilation, second_dilation=second_dilation)]

        # dilated blocks at the end of the bottleneck sausage
        if dilated_blocks_n is not None and dilated_blocks_n > 0:
            model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)

        # upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
                                         min(max_features, int(ngf * mult / 2)),
                                         kernel_size=3, stride=2, padding=1, output_padding=1),
                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
                      up_activation]
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
437

438

439
class GlobalGeneratorGated(GlobalGenerator):
    """``GlobalGenerator`` preset for gated convolutions.

    Defaults ``conv_kind='gated_bn_relu'``, ``activation=nn.Identity()`` and
    ``norm_layer=nn.Identity``; any of these passed as keyword arguments override
    the preset. Positional arguments are forwarded unchanged.
    """

    def __init__(self, *args, **kwargs):
        preset = {
            'conv_kind': 'gated_bn_relu',
            'activation': nn.Identity(),
            'norm_layer': nn.Identity,
        }
        super().__init__(*args, **{**preset, **kwargs})
448

449

450
class GlobalGeneratorFromSuperChannels(nn.Module):
    """Encoder / residual bottleneck / decoder generator with per-stage channel widths.

    ``convert_super_channels`` expands ``super_channels`` into ``self.channels``:
    ``n_downsampling + 1`` encoder widths, then three bottleneck widths, then three
    decoder widths. The bottleneck's ``n_blocks`` ResnetBlocks are split into three
    groups (``n_blocks // 3``, ``n_blocks // 3``, remainder); the first block of the
    second and third group changes width via ``in_dim``. Conv biases are enabled only
    when the resolved norm layer is ``nn.InstanceNorm2d``.

    Only ``n_downsampling`` in {2, 3} is supported.
    """

    def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True):
        super().__init__()
        self.n_downsampling = n_downsampling
        norm_layer = get_norm_layer(norm_layer)
        base_norm = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = (base_norm == nn.InstanceNorm2d)

        channels = self.convert_super_channels(super_channels)
        self.channels = channels

        # stem
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(channels[0]),
                 nn.ReLU(True)]

        # encoder: stage k maps channels[k] -> channels[k + 1]
        for k in range(n_downsampling):
            model.extend([nn.Conv2d(channels[k], channels[k + 1], kernel_size=3, stride=2, padding=1, bias=use_bias),
                          norm_layer(channels[k + 1]),
                          nn.ReLU(True)])

        # bottleneck: three groups of blocks at channels[n_downsampling + 0/1/2]
        group_len = n_blocks // 3
        group_sizes = (group_len, group_len, n_blocks - 2 * group_len)
        for offset, group_size in enumerate(group_sizes):
            dim_index = n_downsampling + offset
            for j in range(group_size):
                # The first block of groups 2 and 3 adapts from the previous group's width.
                extra = {"in_dim": channels[dim_index - 1]} if (offset > 0 and j == 0) else {}
                model.append(ResnetBlock(channels[dim_index], padding_type=padding_type, norm_layer=norm_layer,
                                         **extra))

        # decoder: stage k maps channels[n + 3 + k] -> channels[n + 4 + k]
        for k in range(n_downsampling):
            src = channels[n_downsampling + 3 + k]
            dst = channels[n_downsampling + 3 + k + 1]
            model.extend([nn.ConvTranspose2d(src, dst,
                                             kernel_size=3, stride=2,
                                             padding=1, output_padding=1,
                                             bias=use_bias),
                          norm_layer(dst),
                          nn.ReLU(True)])
        model.append(nn.ReflectionPad2d(3))
        model.append(nn.Conv2d(channels[2 * n_downsampling + 3], output_nc, kernel_size=7, padding=0))

        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def convert_super_channels(self, super_channels):
        """Expand ``super_channels`` into the flat per-stage channel list.

        Returns ``n_downsampling + 1`` encoder widths (``super_channels[k] * 2**k``),
        three bottleneck widths (``super_channels[3] * 4`` each when
        ``len(super_channels) == 6``, else ``super_channels[3 + i] * 4``), and three
        decoder widths (``super_channels[base - c] * 2**c`` for c = 1, 0, -1, cast to
        int, with base 5 for length-6 input and 7 otherwise).

        NOTE(review): for length-6 input the last decoder entry reads
        ``super_channels[6]`` and raises IndexError — TODO confirm the intended layout.

        Raises:
            NotImplementedError: if ``n_downsampling`` is not 2 or 3.
        """
        n_downsampling = self.n_downsampling
        if n_downsampling not in (2, 3):
            raise NotImplementedError

        result = []
        # encoder: one width per downsampling level plus the stem
        for level in range(n_downsampling + 1):
            result.append(super_channels[level] * (2 ** level))
            logging.info(f"Downsample channels {result[-1]}")

        # bottleneck: one width per block group
        for group in range(3):
            source = super_channels[3] if len(super_channels) == 6 else super_channels[group + 3]
            result.append(source * 4)
            logging.info(f"Bottleneck channels {result[-1]}")

        # decoder: scale factors 2, 1, 1/2
        for power in (1, 0, -1):
            base = 5 if len(super_channels) == 6 else 7
            result.append(int(super_channels[base - power] * (2 ** power)))
            logging.info(f"Upsample channels {result[-1]}")
        return result

    def forward(self, input):
        return self.model(input)
562

563

564
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(BaseDiscriminator):
    """PatchGAN discriminator (pix2pixHD style) that also exposes intermediate features.

    The network is stored as consecutive stages ``model0``, ``model1``, ...:
      * 4x4 stride-2 conv (``input_nc`` -> ``ndf``) + LeakyReLU(0.2);
      * ``n_layers - 1`` stages of 4x4 stride-2 conv + norm + LeakyReLU, doubling the
        channel count each time (capped at 512);
      * one 4x4 stride-1 conv + norm + LeakyReLU (channels doubled again, capped at 512);
      * a final 4x4 stride-1 conv to 1 channel.
    That is ``max(n_layers, 1) + 2`` stages in total.

    ``forward`` returns ``(final_output, [outputs of all earlier stages])``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,):
        super().__init__()
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for _ in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence.append([
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ])

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence.append([
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True),
        ])

        sequence.append([nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)])

        # Stage attribute names (model0, model1, ...) are part of the state_dict keys.
        for n, stage in enumerate(sequence):
            setattr(self, 'model' + str(n), nn.Sequential(*stage))

    def get_all_activations(self, x):
        """Run every stage in order and return the list of their outputs."""
        # __init__ always builds the first, stride-1 and final stages, plus
        # max(n_layers - 1, 0) middle stages. Using n_layers + 2 directly skipped the
        # final 1-channel conv when n_layers == 0.
        n_stages = max(self.n_layers, 1) + 2
        res = [x]
        for n in range(n_stages):
            stage = getattr(self, 'model' + str(n))
            res.append(stage(res[-1]))
        return res[1:]

    def forward(self, x):
        act = self.get_all_activations(x)
        return act[-1], act[:-1]
614

615

616
class MultidilatedNLayerDiscriminator(BaseDiscriminator):
    """PatchGAN discriminator whose middle downsampling stages use ``MultidilatedConv``.

    Stages ``model0``, ``model1``, ...:
      * 4x4 stride-2 ``nn.Conv2d`` (``input_nc`` -> ``ndf``) + LeakyReLU(0.2);
      * ``n_layers - 1`` stages of ``MultidilatedConv`` (kernel 4, stride 2,
        ``padding=[2, 3]``, plus ``multidilation_kwargs``) + norm + LeakyReLU,
        doubling channels each time (capped at 512);
      * one 4x4 stride-1 ``nn.Conv2d`` + norm + LeakyReLU;
      * a final 4x4 stride-1 ``nn.Conv2d`` to 1 channel.
    That is ``max(n_layers, 1) + 2`` stages in total.

    ``forward`` returns ``(final_output, [outputs of all earlier stages])``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs=None):
        super().__init__()
        # ``None`` instead of a shared mutable ``{}`` default; only unpacked below.
        multidilation_kwargs = {} if multidilation_kwargs is None else multidilation_kwargs
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for _ in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence.append([
                MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ])

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence.append([
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True),
        ])

        sequence.append([nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)])

        # Stage attribute names (model0, model1, ...) are part of the state_dict keys.
        for n, stage in enumerate(sequence):
            setattr(self, 'model' + str(n), nn.Sequential(*stage))

    def get_all_activations(self, x):
        """Run every stage in order and return the list of their outputs."""
        # __init__ always builds the first, stride-1 and final stages, plus
        # max(n_layers - 1, 0) middle stages; n_layers + 2 undercounts when n_layers == 0.
        n_stages = max(self.n_layers, 1) + 2
        res = [x]
        for n in range(n_stages):
            stage = getattr(self, 'model' + str(n))
            res.append(stage(res[-1]))
        return res[1:]

    def forward(self, x):
        act = self.get_all_activations(x)
        return act[-1], act[:-1]
665

666

667
class NLayerDiscriminatorAsGen(NLayerDiscriminator):
    """``NLayerDiscriminator`` whose ``forward`` returns only the final stage output."""

    def forward(self, x):
        final_output, _intermediate = super().forward(x)
        return final_output
670

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.