# original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py
import collections
import functools
import logging
from collections import defaultdict
from functools import partial

import numpy as np
import torch.nn as nn

from saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
from saicinpainting.training.modules.ffc import FFCResnetBlock
from saicinpainting.training.modules.multidilated_conv import MultidilatedConv

class DotDict(defaultdict):
    # https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
    """dot.notation access to dictionary attributes"""
    __getattr__ = defaultdict.get
    __setattr__ = defaultdict.__setitem__
    __delattr__ = defaultdict.__delitem__

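
def _demo_dotdict():
    # Illustrative sketch, not part of the original pix2pixHD/LaMa file:
    # DotDict routes attribute reads through dict.get, so missing keys come
    # back as None instead of raising AttributeError. ConfigGlobalGenerator
    # below consumes manual_block_spec entries exactly this way.
    spec = DotDict(lambda: None, {'n_blocks': 2, 'use_default': True})
    assert spec.n_blocks == 2
    assert spec.resnet_dilation is None  # absent key reads as None
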
class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
                 dilation=1, in_dim=None, groups=1, second_dilation=None):
        super(ResnetBlock, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        if second_dilation is None:
            second_dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
                                                second_dilation=second_dilation)

        if self.in_dim is not None:
            self.input_conv = nn.Conv2d(in_dim, dim, 1)

        self.out_channels = dim

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
                         dilation=1, in_dim=None, groups=1, second_dilation=1):
        conv_layer = get_conv_block_ctor(conv_kind)

        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(dilation)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(dilation)]
        elif padding_type == 'zero':
            p = dilation
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        if in_dim is None:
            in_dim = dim

        conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(second_dilation)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(second_dilation)]
        elif padding_type == 'zero':
            p = second_dilation
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        x_before = x
        if self.in_dim is not None:
            x = self.input_conv(x)
        out = x + self.conv_block(x_before)
        return out

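
def _demo_resnet_block():
    # Illustrative sketch, not part of the original file. It checks two
    # properties the blocks above rely on: the padded conv stack preserves
    # spatial size for any dilation (so the residual add is valid), and
    # passing in_dim adds a 1x1 projection mapping the skip branch to `dim`.
    import torch
    block = ResnetBlock(dim=64, padding_type='reflect', norm_layer=nn.BatchNorm2d, dilation=2)
    x = torch.randn(1, 64, 32, 32)
    assert block(x).shape == x.shape

    proj = ResnetBlock(dim=64, padding_type='reflect', norm_layer=nn.BatchNorm2d, in_dim=96)
    y = proj(torch.randn(1, 96, 32, 32))
    assert y.shape == (1, 64, 32, 32)
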
class ResnetBlock5x5(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
                 dilation=1, in_dim=None, groups=1, second_dilation=None):
        super(ResnetBlock5x5, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        if second_dilation is None:
            second_dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
                                                second_dilation=second_dilation)

        if self.in_dim is not None:
            self.input_conv = nn.Conv2d(in_dim, dim, 1)

        self.out_channels = dim

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
                         dilation=1, in_dim=None, groups=1, second_dilation=1):
        conv_layer = get_conv_block_ctor(conv_kind)

        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(dilation * 2)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(dilation * 2)]
        elif padding_type == 'zero':
            p = dilation * 2
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        if in_dim is None:
            in_dim = dim

        conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(second_dilation * 2)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(second_dilation * 2)]
        elif padding_type == 'zero':
            p = second_dilation * 2
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        x_before = x
        if self.in_dim is not None:
            x = self.input_conv(x)
        out = x + self.conv_block(x_before)
        return out


class MultidilatedResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super().__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
        conv_block = []
        conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out

class MultiDilatedGlobalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default',
                 deconv_kind='convtranspose', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
                 add_out_act=True, max_features=1024, multidilation_kwargs={},
                 ffc_positions=None, ffc_kwargs={}):
        assert n_blocks >= 0
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)
        resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        identity = Identity()
        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i

            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        ### resnet blocks
        for i in range(n_blocks):
            if ffc_positions is not None and i in ffc_positions:
                model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
                                         inline=True, **ffc_kwargs)]
            model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                              conv_layer=resnet_conv_layer, activation=activation,
                                              norm_layer=norm_layer)]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

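
def _demo_multidilated_generator():
    # Hedged usage sketch (an assumption, not from the original file): LaMa
    # configs typically feed image+mask, hence input_nc=4 and output_nc=3
    # here. With the default n_downsampling=3, any input whose sides are
    # divisible by 8 should be reconstructed at the same resolution.
    import torch
    net = MultiDilatedGlobalGenerator(input_nc=4, output_nc=3, n_blocks=2)
    out = net(torch.randn(1, 4, 64, 64))
    assert out.shape == (1, 3, 64, 64)
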
class ConfigGlobalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default',
                 deconv_kind='convtranspose', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
                 add_out_act=True, max_features=1024,
                 manual_block_spec=[],
                 resnet_block_kind='multidilatedresnetblock',
                 resnet_conv_kind='multidilated',
                 resnet_dilation=1,
                 multidilation_kwargs={}):
        assert n_blocks >= 0
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)
        resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        identity = Identity()

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        if len(manual_block_spec) == 0:
            manual_block_spec = [
                DotDict(lambda: None, {'n_blocks': n_blocks, 'use_default': True})
            ]

        ### resnet blocks
        for block_spec in manual_block_spec:
            # Bind the constructor-level settings as keyword defaults so the
            # closure does not hit UnboundLocalError when use_default is set.
            def make_and_add_blocks(model, block_spec,
                                    resnet_conv_layer=resnet_conv_layer,
                                    resnet_conv_kind=resnet_conv_kind,
                                    resnet_block_kind=resnet_block_kind,
                                    resnet_dilation=resnet_dilation):
                block_spec = DotDict(lambda: None, block_spec)
                if not block_spec.use_default:
                    resnet_conv_layer = functools.partial(get_conv_block_ctor(block_spec.resnet_conv_kind),
                                                          **block_spec.multidilation_kwargs)
                    resnet_conv_kind = block_spec.resnet_conv_kind
                    resnet_block_kind = block_spec.resnet_block_kind
                    if block_spec.resnet_dilation is not None:
                        resnet_dilation = block_spec.resnet_dilation
                for i in range(block_spec.n_blocks):
                    if resnet_block_kind == "multidilatedresnetblock":
                        model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                                          conv_layer=resnet_conv_layer, activation=activation,
                                                          norm_layer=norm_layer)]
                    if resnet_block_kind == "resnetblock":
                        model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation,
                                              norm_layer=norm_layer, conv_kind=resnet_conv_kind)]
                    if resnet_block_kind == "resnetblock5x5":
                        model += [ResnetBlock5x5(ngf * mult, padding_type=padding_type, activation=activation,
                                                 norm_layer=norm_layer, conv_kind=resnet_conv_kind)]
                    if resnet_block_kind == "resnetblockdwdil":
                        model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation,
                                              norm_layer=norm_layer, conv_kind=resnet_conv_kind,
                                              dilation=resnet_dilation, second_dilation=resnet_dilation)]
            make_and_add_blocks(model, block_spec)

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

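
def _demo_config_generator():
    # Hedged sketch (assumption, not from the original file): each entry of
    # manual_block_spec is a mapping; with use_default=True the constructor's
    # resnet_* settings apply, otherwise an entry may override resnet_conv_kind,
    # resnet_block_kind, resnet_dilation and multidilation_kwargs per group.
    import torch
    spec = [{'n_blocks': 2, 'use_default': True}]
    net = ConfigGlobalGenerator(input_nc=4, output_nc=3, manual_block_spec=spec)
    out = net(torch.randn(1, 4, 64, 64))
    assert out.shape == (1, 3, 64, 64)
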

def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
    blocks = []
    for i in range(dilated_blocks_n):
        if dilation_block_kind == 'simple':
            blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1)))
        elif dilation_block_kind == 'multi':
            blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs))
        else:
            raise ValueError(f'unexpected dilation_block_kind "{dilation_block_kind}"')
    return blocks

class GlobalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None,
                 up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
                 dilated_blocks_n_middle=0,
                 add_out_act=True,
                 max_features=1024, is_resblock_depthwise=False,
                 ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None,
                 dilation_block_kind='simple', multidilation_kwargs={}):
        assert n_blocks >= 0
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        if ffc_positions is not None:
            ffc_positions = collections.Counter(ffc_positions)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        identity = Identity()
        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i

            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
                                    activation=activation, norm_layer=norm_layer)
        if dilation_block_kind == 'simple':
            dilated_block_kwargs['conv_kind'] = conv_kind
        elif dilation_block_kind == 'multi':
            dilated_block_kwargs['conv_layer'] = functools.partial(
                get_conv_block_ctor('multidilated'), **multidilation_kwargs)

        # dilated blocks at the start of the bottleneck sausage
        if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
            model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)

        # resnet blocks
        for i in range(n_blocks):
            # dilated blocks at the middle of the bottleneck sausage
            if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
                model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)

            if ffc_positions is not None and i in ffc_positions:
                for _ in range(ffc_positions[i]):  # same position can occur more than once
                    model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
                                             inline=True, **ffc_kwargs)]

            if is_resblock_depthwise:
                resblock_groups = feats_num_bottleneck
            else:
                resblock_groups = 1

            model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                  norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
                                  dilation=dilation, second_dilation=second_dilation)]

        # dilated blocks at the end of the bottleneck sausage
        if dilated_blocks_n is not None and dilated_blocks_n > 0:
            model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)

        # upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
                                         min(max_features, int(ngf * mult / 2)),
                                         kernel_size=3, stride=2, padding=1, output_padding=1),
                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
                      up_activation]
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

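
def _demo_global_generator():
    # Hedged sketch (assumption: 4-channel masked input as in LaMa configs).
    # dilated_blocks_n appends ResnetBlocks with dilations 2, 4, ... after the
    # plain bottleneck blocks; reflect padding keeps the resolution intact.
    import torch
    net = GlobalGenerator(input_nc=4, output_nc=3, n_blocks=2, dilated_blocks_n=2)
    out = net(torch.randn(1, 4, 64, 64))
    assert out.shape == (1, 3, 64, 64)
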

class GlobalGeneratorGated(GlobalGenerator):
    def __init__(self, *args, **kwargs):
        real_kwargs = dict(
            conv_kind='gated_bn_relu',
            activation=nn.Identity(),
            norm_layer=nn.Identity
        )
        real_kwargs.update(kwargs)
        super().__init__(*args, **real_kwargs)

class GlobalGeneratorFromSuperChannels(nn.Module):
    def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn",
                 padding_type='reflect', add_out_act=True):
        super().__init__()
        self.n_downsampling = n_downsampling
        norm_layer = get_norm_layer(norm_layer)
        if type(norm_layer) == functools.partial:
            use_bias = (norm_layer.func == nn.InstanceNorm2d)
        else:
            use_bias = (norm_layer == nn.InstanceNorm2d)

        channels = self.convert_super_channels(super_channels)
        self.channels = channels

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(channels[0]),
                 nn.ReLU(True)]

        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(channels[0 + i], channels[1 + i], kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(channels[1 + i]),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling

        n_blocks1 = n_blocks // 3
        n_blocks2 = n_blocks1
        n_blocks3 = n_blocks - n_blocks1 - n_blocks2

        for i in range(n_blocks1):
            c = n_downsampling
            dim = channels[c]
            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)]

        for i in range(n_blocks2):
            c = n_downsampling + 1
            dim = channels[c]
            kwargs = {}
            if i == 0:
                kwargs = {"in_dim": channels[c - 1]}
            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]

        for i in range(n_blocks3):
            c = n_downsampling + 2
            dim = channels[c]
            kwargs = {}
            if i == 0:
                kwargs = {"in_dim": channels[c - 1]}
            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(channels[n_downsampling + 3 + i],
                                         channels[n_downsampling + 3 + i + 1],
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(channels[n_downsampling + 3 + i + 1]),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(channels[2 * n_downsampling + 3], output_nc, kernel_size=7, padding=0)]

        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def convert_super_channels(self, super_channels):
        n_downsampling = self.n_downsampling
        result = []
        cnt = 0

        if n_downsampling == 2:
            N1 = 10
        elif n_downsampling == 3:
            N1 = 13
        else:
            raise NotImplementedError

        for i in range(0, N1):
            if i in [1, 4, 7, 10]:
                channel = super_channels[cnt] * (2 ** cnt)
                config = {'channel': channel}
                result.append(channel)
                logging.info(f"Downsample channels {result[-1]}")
                cnt += 1

        for i in range(3):
            for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)):
                if len(super_channels) == 6:
                    channel = super_channels[3] * 4
                else:
                    channel = super_channels[i + 3] * 4
                config = {'channel': channel}
                if counter == 0:
                    result.append(channel)
                    logging.info(f"Bottleneck channels {result[-1]}")
        cnt = 2

        for i in range(N1 + 9, N1 + 21):
            if i in [22, 25, 28]:
                cnt -= 1
                if len(super_channels) == 6:
                    channel = super_channels[5 - cnt] * (2 ** cnt)
                else:
                    channel = super_channels[7 - cnt] * (2 ** cnt)
                result.append(int(channel))
                logging.info(f"Upsample channels {result[-1]}")
        return result

    def forward(self, input):
        return self.model(input)

# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(BaseDiscriminator):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)

            cur_model = [
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True)
            ]
            sequence.append(cur_model)

        nf_prev = nf
        nf = min(nf * 2, 512)

        cur_model = [
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]
        sequence.append(cur_model)

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        for n in range(len(sequence)):
            setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))

    def get_all_activations(self, x):
        res = [x]
        for n in range(self.n_layers + 2):
            model = getattr(self, 'model' + str(n))
            res.append(model(res[-1]))
        return res[1:]

    def forward(self, x):
        act = self.get_all_activations(x)
        return act[-1], act[:-1]

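
def _demo_nlayer_discriminator():
    # Hedged sketch (input_nc=3 is an assumption): forward returns the final
    # 1-channel patch logits together with the intermediate activations that
    # feature-matching losses typically consume.
    import torch
    disc = NLayerDiscriminator(input_nc=3)
    logits, feats = disc(torch.randn(1, 3, 64, 64))
    assert logits.shape[1] == 1
    assert len(feats) == disc.n_layers + 1  # one activation per sub-model except the head
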

class MultidilatedNLayerDiscriminator(BaseDiscriminator):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}):
        super().__init__()
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)

            cur_model = [
                MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True)
            ]
            sequence.append(cur_model)

        nf_prev = nf
        nf = min(nf * 2, 512)

        cur_model = [
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]
        sequence.append(cur_model)

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        for n in range(len(sequence)):
            setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))

    def get_all_activations(self, x):
        res = [x]
        for n in range(self.n_layers + 2):
            model = getattr(self, 'model' + str(n))
            res.append(model(res[-1]))
        return res[1:]

    def forward(self, x):
        act = self.get_all_activations(x)
        return act[-1], act[:-1]


class NLayerDiscriminatorAsGen(NLayerDiscriminator):
    def forward(self, x):
        return super().forward(x)[0]