import argparse
import math
import torch

from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean


def _linear_scale(weight):
    """Return 1/sqrt(c_in) for a linear weight of shape (c_out, c_in)."""
    _, c_in = weight.size()
    return 1 / math.sqrt(c_in)


def _conv_scale(weight):
    """Return 1/sqrt(c_in * k1 * k2) for a conv weight of shape (c_out, c_in, k1, k2)."""
    _, c_in, k1, k2 = weight.size()
    return 1 / math.sqrt(c_in * k1 * k2)


def _modulated_conv_scale(weight):
    """Return 1/sqrt(c_in * k1 * k2) for a modulated-conv weight of shape (1, c_out, c_in, k1, k2)."""
    _, _, c_in, k1, k2 = weight.size()
    return 1 / math.sqrt(c_in * k1 * k2)


def modify_checkpoint(checkpoint_bilinear, checkpoint_clean):
    """Rename and rescale a GFPGANv1 (bilinear) state dict into the GFPGANv1Clean layout.

    Each recognized key of ``checkpoint_bilinear`` is mapped to its clean-arch key, and its
    tensor is multiplied by constant factors (1/sqrt(fan_in), ``lr_mul``, sqrt(2)).
    NOTE(review): these factors presumably fold in scaling the source arch applies at
    runtime; confirm against both arch definitions.

    Args:
        checkpoint_bilinear (dict): Source state dict (key -> tensor).
        checkpoint_clean (dict): Destination state dict, updated in place. Entries that the
            conversion does not produce keep their existing values.

    Returns:
        dict: ``checkpoint_clean`` (the same object that was passed in).

    Raises:
        ValueError: If a ``style_conv`` ``activate`` key does not have 4 or 5
            dot-separated parts.

    Source keys that match none of the branches below are ignored.
    """
    for ori_k, ori_v in checkpoint_bilinear.items():
        if 'stylegan_decoder' in ori_k:
            if 'style_mlp' in ori_k:
                # Style MLP linear layers use lr_mul = 0.01.
                lr_mul = 0.01
                prefix, name, idx, var = ori_k.split('.')
                # Source layer index idx lands at index 2 * idx - 1 in the clean arch.
                idx = (int(idx) * 2) - 1
                crt_k = f'{prefix}.{name}.{idx}.{var}'
                if var == 'weight':
                    scale = _linear_scale(ori_v) * lr_mul
                    crt_v = ori_v * scale * 2**0.5
                else:
                    crt_v = ori_v * lr_mul * 2**0.5
                checkpoint_clean[crt_k] = crt_v
            elif 'modulation' in ori_k:
                # Modulation linear layer of a StyleConv; lr_mul = 1 here.
                lr_mul = 1
                crt_k = ori_k
                var = ori_k.split('.')[-1]
                if var == 'weight':
                    scale = _linear_scale(ori_v) * lr_mul
                    crt_v = ori_v * scale
                else:
                    crt_v = ori_v * lr_mul
                checkpoint_clean[crt_k] = crt_v
            elif 'style_conv' in ori_k:
                if 'activate' in ori_k:
                    # e.g. stylegan_decoder.style_conv1.activate.bias
                    #        -> stylegan_decoder.style_conv1.bias
                    #      stylegan_decoder.style_convs.13.activate.bias
                    #        -> stylegan_decoder.style_convs.13.bias
                    split_rlt = ori_k.split('.')
                    if len(split_rlt) == 4:
                        prefix, name, _, var = split_rlt
                        crt_k = f'{prefix}.{name}.{var}'
                    elif len(split_rlt) == 5:
                        prefix, name, idx, _, var = split_rlt
                        crt_k = f'{prefix}.{name}.{idx}.{var}'
                    else:
                        # Without this, crt_k would silently be left over from an
                        # earlier key.
                        raise ValueError(f'Unexpected activate key: {ori_k}')
                    # sqrt(2): presumably the gain of the source activation - TODO confirm.
                    crt_v = ori_v * 2**0.5
                    c = crt_v.size(0)
                    # The 1-D bias is stored as a broadcastable (1, C, 1, 1) tensor.
                    checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1)
                elif 'modulated_conv' in ori_k:
                    # e.g. stylegan_decoder.style_convs.13.modulated_conv.weight
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * _modulated_conv_scale(ori_v)
                elif 'weight' in ori_k:
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * 2**0.5
            elif 'to_rgb' in ori_k:
                crt_k = ori_k
                if 'modulated_conv' in ori_k:
                    # e.g. stylegan_decoder.to_rgbs.5.modulated_conv.weight
                    checkpoint_clean[crt_k] = ori_v * _modulated_conv_scale(ori_v)
                else:
                    checkpoint_clean[crt_k] = ori_v
            else:
                # Any other decoder entry is copied unchanged under the same key.
                crt_k = ori_k
                checkpoint_clean[crt_k] = ori_v
        elif 'conv_body_first' in ori_k or 'final_conv' in ori_k:
            # e.g. conv_body_first.0.weight -> conv_body_first.weight
            name, _, var = ori_k.split('.')
            crt_k = f'{name}.{var}'
            if var == 'weight':
                checkpoint_clean[crt_k] = ori_v * _conv_scale(ori_v) * 2**0.5
            else:
                checkpoint_clean[crt_k] = ori_v * 2**0.5
        elif 'conv_body' in ori_k:
            if 'conv_body_up' in ori_k:
                # Insert an extra path component so these keys also split into 5 parts.
                ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight')
                ori_k = ori_k.replace('skip.weight', 'skip.1.weight')
            name1, idx1, name2, _, var = ori_k.split('.')
            crt_k = f'{name1}.{idx1}.{name2}.{var}'
            if name2 == 'skip':
                checkpoint_clean[crt_k] = ori_v * _conv_scale(ori_v) / 2**0.5
            else:
                if var == 'weight':
                    checkpoint_clean[crt_k] = ori_v * _conv_scale(ori_v)
                else:
                    checkpoint_clean[crt_k] = ori_v
                # Only the first conv of each block carries the extra sqrt(2).
                if 'conv1' in ori_k:
                    checkpoint_clean[crt_k] *= 2**0.5
        elif 'toRGB' in ori_k:
            crt_k = ori_k
            if 'weight' in ori_k:
                checkpoint_clean[crt_k] = ori_v * _conv_scale(ori_v)
            else:
                checkpoint_clean[crt_k] = ori_v
        elif 'final_linear' in ori_k:
            crt_k = ori_k
            if 'weight' in ori_k:
                checkpoint_clean[crt_k] = ori_v * _linear_scale(ori_v)
            else:
                checkpoint_clean[crt_k] = ori_v
        elif 'condition' in ori_k:
            crt_k = ori_k
            if '0.weight' in ori_k:
                checkpoint_clean[crt_k] = ori_v * _conv_scale(ori_v) * 2**0.5
            elif '0.bias' in ori_k:
                checkpoint_clean[crt_k] = ori_v * 2**0.5
            elif '2.weight' in ori_k:
                checkpoint_clean[crt_k] = ori_v * _conv_scale(ori_v)
            elif '2.bias' in ori_k:
                checkpoint_clean[crt_k] = ori_v

    return checkpoint_clean


if __name__ == '__main__':
    # Convert a GFPGANv1 checkpoint (its 'params_ema' entry) into GFPGANv1Clean format and save
    # it, again under 'params_ema'.
    parser = argparse.ArgumentParser()
    parser.add_argument('--ori_path', type=str, required=True, help='Path to the original model')
    parser.add_argument('--narrow', type=float, default=1)
    parser.add_argument('--channel_multiplier', type=float, default=2)
    parser.add_argument('--save_path', type=str, required=True)
    args = parser.parse_args()

    # Only tensor arithmetic follows, so load onto CPU; this also lets a GPU-saved checkpoint
    # be converted on a machine without CUDA.
    ori_ckpt = torch.load(args.ori_path, map_location='cpu')['params_ema']

    # The freshly built clean network only provides the target key layout; every tensor it
    # shares with the source is overwritten by modify_checkpoint.
    net = GFPGANv1Clean(
        512,
        num_style_feat=512,
        channel_multiplier=args.channel_multiplier,
        decoder_load_path=None,
        fix_decoder=False,
        # for stylegan decoder
        num_mlp=8,
        input_is_latent=True,
        different_w=True,
        narrow=args.narrow,
        sphere_net=False)
    crt_ckpt = net.state_dict()

    crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt)
    print(f'Save to {args.save_path}.')
    # Legacy (non-zipfile) serialization is kept for compatibility with older torch loaders.
    torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)