GFPGAN

Форк
0
/
test_gfpgan_arch.py 
203 строки · 7.7 Кб
1
import unittest

import torch

from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT
5

6

7
def test_stylegan2generatorsft():
    """Test arch: StyleGAN2GeneratorSFT.

    Checks output shapes of the SFT-conditioned StyleGAN2 generator for the
    plain forward pass and for the ``return_latents``, ``randomize_noise`` and
    truncation + style-mixing options.

    The architecture is only exercised on GPU. Without CUDA the test raises
    ``unittest.SkipTest`` so the runner reports it as skipped instead of
    recording a pass that checked nothing.
    """
    if not torch.cuda.is_available():
        raise unittest.SkipTest('CUDA is required for StyleGAN2GeneratorSFT')

    device = torch.device('cuda')
    net = StyleGAN2GeneratorSFT(
        out_size=32,
        num_style_feat=512,
        num_mlp=8,
        channel_multiplier=1,
        resample_kernel=(1, 3, 3, 1),
        lr_mlp=0.01,
        narrow=1,
        sft_half=False).to(device).eval()
    style = torch.rand((1, 512), dtype=torch.float32, device=device)
    # One random condition per spatial scale, each passed twice in a row:
    # [c8, c8, c16, c16, c32, c32].
    per_scale = [torch.rand((1, 512, size, size), dtype=torch.float32, device=device) for size in (8, 16, 32)]
    conditions = [cond for cond in per_scale for _ in range(2)]

    # Shape checks only: skip building an autograd graph.
    with torch.no_grad():
        output = net([style], conditions)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with return_latents ----------------------- #
        output = net([style], conditions, return_latents=True)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 1
        # check latent: expected 8 style vectors of width 512 for out_size=32
        assert output[1][0].shape == (8, 512)

        # -------------------- with randomize_noise = False ----------------------- #
        output = net([style], conditions, randomize_noise=False)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with truncation = 0.5 and mixing----------------------- #
        # NOTE: the random style doubles as the truncation latent; this is a
        # shape smoke test, not a quality check.
        output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None
46

47

48
def test_gfpganv1():
    """Test arch: GFPGANv1.

    Builds the model with ``different_w`` disabled and then enabled. For each
    model it checks the restored image shape and the three intermediate RGB
    outputs (8x8, 16x16, 32x32) used for the intermediate loss.

    Without CUDA the test raises ``unittest.SkipTest`` so it is reported as
    skipped rather than silently passing.
    """
    if not torch.cuda.is_available():
        raise unittest.SkipTest('CUDA is required for GFPGANv1')

    device = torch.device('cuda')
    # The two configurations differ only in `different_w`; loop instead of
    # duplicating the construction and the assertions.
    for different_w in (False, True):
        net = GFPGANv1(
            out_size=32,
            num_style_feat=512,
            channel_multiplier=1,
            resample_kernel=(1, 3, 3, 1),
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=different_w,
            narrow=1,
            sft_half=True).to(device).eval()
        img = torch.rand((1, 3, 32, 32), dtype=torch.float32, device=device)

        # Shape checks only: skip building an autograd graph.
        with torch.no_grad():
            output = net(img)

        assert output[0].shape == (1, 3, 32, 32)
        # check out_rgbs for intermediate loss
        out_rgbs = output[1]
        assert len(out_rgbs) == 3
        for rgb, size in zip(out_rgbs, (8, 16, 32)):
            assert rgb.shape == (1, 3, size, size), f'different_w={different_w}: got {tuple(rgb.shape)}'
99

100

101
def test_facialcomponentdiscriminator():
    """Test arch: FacialComponentDiscriminator.

    For a 32x32 input it checks two cases:
    - Plain call: the first output is a (1, 1, 8, 8) prediction map and the
      second output is None.
    - ``return_feats=True``: the second output holds two intermediate feature
      maps.

    Without CUDA the test raises ``unittest.SkipTest`` so it is reported as
    skipped rather than silently passing.
    """
    if not torch.cuda.is_available():
        raise unittest.SkipTest('CUDA is required for FacialComponentDiscriminator')

    device = torch.device('cuda')
    net = FacialComponentDiscriminator().to(device).eval()
    img = torch.rand((1, 3, 32, 32), dtype=torch.float32, device=device)

    # Shape checks only: skip building an autograd graph.
    with torch.no_grad():
        output = net(img)
        assert len(output) == 2
        assert output[0].shape == (1, 1, 8, 8)
        assert output[1] is None

        # -------------------- return intermediate features ----------------------- #
        output = net(img, return_feats=True)
        assert len(output) == 2
        assert output[0].shape == (1, 1, 8, 8)
        assert len(output[1]) == 2
        assert output[1][0].shape == (1, 128, 16, 16)
        assert output[1][1].shape == (1, 256, 8, 8)
120

121

122
def test_stylegan2generatorcsft():
    """Test arch: StyleGAN2GeneratorCSFT.

    Checks output shapes of the clean SFT-conditioned StyleGAN2 generator for
    the plain forward pass and for the ``return_latents``, ``randomize_noise``
    and truncation + style-mixing options.

    Without CUDA the test raises ``unittest.SkipTest`` so it is reported as
    skipped rather than silently passing.
    """
    if not torch.cuda.is_available():
        raise unittest.SkipTest('CUDA is required for StyleGAN2GeneratorCSFT')

    device = torch.device('cuda')
    net = StyleGAN2GeneratorCSFT(
        out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).to(device).eval()
    style = torch.rand((1, 512), dtype=torch.float32, device=device)
    # One random condition per spatial scale, each passed twice in a row:
    # [c8, c8, c16, c16, c32, c32].
    per_scale = [torch.rand((1, 512, size, size), dtype=torch.float32, device=device) for size in (8, 16, 32)]
    conditions = [cond for cond in per_scale for _ in range(2)]

    # Shape checks only: skip building an autograd graph.
    with torch.no_grad():
        output = net([style], conditions)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with return_latents ----------------------- #
        output = net([style], conditions, return_latents=True)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 1
        # check latent: expected 8 style vectors of width 512 for out_size=32
        assert output[1][0].shape == (8, 512)

        # -------------------- with randomize_noise = False ----------------------- #
        output = net([style], conditions, randomize_noise=False)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with truncation = 0.5 and mixing----------------------- #
        # NOTE: the random style doubles as the truncation latent; this is a
        # shape smoke test, not a quality check.
        output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None
154

155

156
def test_gfpganv1clean():
    """Test arch: GFPGANv1Clean.

    Builds the model with ``different_w`` disabled and then enabled. For each
    model it checks the restored image shape and the three intermediate RGB
    outputs (8x8, 16x16, 32x32) used for the intermediate loss.

    Without CUDA the test raises ``unittest.SkipTest`` so it is reported as
    skipped rather than silently passing.
    """
    if not torch.cuda.is_available():
        raise unittest.SkipTest('CUDA is required for GFPGANv1Clean')

    device = torch.device('cuda')
    # The two configurations differ only in `different_w`; loop instead of
    # duplicating the construction and the assertions.
    for different_w in (False, True):
        net = GFPGANv1Clean(
            out_size=32,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=different_w,
            narrow=1,
            sft_half=True).to(device).eval()
        img = torch.rand((1, 3, 32, 32), dtype=torch.float32, device=device)

        # Shape checks only: skip building an autograd graph.
        with torch.no_grad():
            output = net(img)

        assert output[0].shape == (1, 3, 32, 32)
        # check out_rgbs for intermediate loss
        out_rgbs = output[1]
        assert len(out_rgbs) == 3
        for rgb, size in zip(out_rgbs, (8, 16, 32)):
            assert rgb.shape == (1, 3, size, size), f'different_w={different_w}: got {tuple(rgb.shape)}'
204

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.