import torch

from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT
def test_stylegan2generatorsft():
    """Test arch: StyleGAN2GeneratorSFT."""

    # model init and forward (gpu); the whole test is a no-op when CUDA is unavailable
    if torch.cuda.is_available():
        net = StyleGAN2GeneratorSFT(
            out_size=32,
            num_style_feat=512,
            num_mlp=8,
            channel_multiplier=1,
            resample_kernel=(1, 3, 3, 1),
            narrow=1,
            sft_half=False).cuda().eval()
        style = torch.rand((1, 512), dtype=torch.float32).cuda()
        condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
        condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
        condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
        # the generator receives a list with two condition tensors per resolution (8, 16, 32)
        conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
        output = net([style], conditions)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with return_latents ----------------------- #
        output = net([style], conditions, return_latents=True)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 1
        # check latent
        assert output[1][0].shape == (8, 512)

        # -------------------- with randomize_noise = False ----------------------- #
        output = net([style], conditions, randomize_noise=False)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with truncation = 0.5 and mixing----------------------- #
        output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None
def test_gfpganv1():
    """Test arch: GFPGANv1."""

    # model init and forward (gpu); the whole test is a no-op when CUDA is unavailable
    if torch.cuda.is_available():
        # run the same checks with different_w = False and with different_w = True
        for different_w in (False, True):
            net = GFPGANv1(
                out_size=32,
                num_style_feat=512,
                channel_multiplier=1,
                resample_kernel=(1, 3, 3, 1),
                decoder_load_path=None,
                fix_decoder=True,
                # for stylegan decoder
                num_mlp=8,
                input_is_latent=False,
                different_w=different_w,
                narrow=1,
                sft_half=True).cuda().eval()
            img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
            output = net(img)
            assert output[0].shape == (1, 3, 32, 32)
            assert len(output[1]) == 3
            # check out_rgbs for intermediate loss
            assert output[1][0].shape == (1, 3, 8, 8)
            assert output[1][1].shape == (1, 3, 16, 16)
            assert output[1][2].shape == (1, 3, 32, 32)
def test_facialcomponentdiscriminator():
    """Test arch: FacialComponentDiscriminator."""

    # model init and forward (gpu); the whole test is a no-op when CUDA is unavailable
    if torch.cuda.is_available():
        net = FacialComponentDiscriminator().cuda().eval()
        img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
        output = net(img)
        # by default the second element (intermediate features) is None
        assert len(output) == 2
        assert output[0].shape == (1, 1, 8, 8)
        assert output[1] is None

        # -------------------- return intermediate features ----------------------- #
        output = net(img, return_feats=True)
        assert len(output) == 2
        assert output[0].shape == (1, 1, 8, 8)
        assert len(output[1]) == 2
        assert output[1][0].shape == (1, 128, 16, 16)
        assert output[1][1].shape == (1, 256, 8, 8)
def test_stylegan2generatorcsft():
    """Test arch: StyleGAN2GeneratorCSFT."""

    # model init and forward (gpu); the whole test is a no-op when CUDA is unavailable
    if torch.cuda.is_available():
        net = StyleGAN2GeneratorCSFT(
            out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).cuda().eval()
        style = torch.rand((1, 512), dtype=torch.float32).cuda()
        condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
        condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
        condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
        # the generator receives a list with two condition tensors per resolution (8, 16, 32)
        conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
        output = net([style], conditions)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with return_latents ----------------------- #
        output = net([style], conditions, return_latents=True)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 1
        # check latent
        assert output[1][0].shape == (8, 512)

        # -------------------- with randomize_noise = False ----------------------- #
        output = net([style], conditions, randomize_noise=False)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with truncation = 0.5 and mixing----------------------- #
        output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None
def test_gfpganv1clean():
    """Test arch: GFPGANv1Clean."""

    # model init and forward (gpu); the whole test is a no-op when CUDA is unavailable
    if torch.cuda.is_available():
        # run the same checks with different_w = False and with different_w = True
        for different_w in (False, True):
            net = GFPGANv1Clean(
                out_size=32,
                num_style_feat=512,
                channel_multiplier=1,
                decoder_load_path=None,
                fix_decoder=True,
                # for stylegan decoder
                num_mlp=8,
                input_is_latent=False,
                different_w=different_w,
                narrow=1,
                sft_half=True).cuda().eval()

            img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
            output = net(img)
            assert output[0].shape == (1, 3, 32, 32)
            assert len(output[1]) == 3
            # check out_rgbs for intermediate loss
            assert output[1][0].shape == (1, 3, 8, 8)
            assert output[1][1].shape == (1, 3, 16, 16)
            assert output[1][2].shape == (1, 3, 32, 32)