# StyleFeatureEditor
# 84 lines · 2.5 KB
# Maps per-stage block names ("layer<stage>.<block>") to a flat, consecutively
# numbered "body.<index>" name. Each tuple is (stage number, flat index of the
# stage's first block, number of blocks in the stage); the depths 3/4/6/3 give
# 16 entries in total.
# NOTE(review): presumably used to rename ResNet checkpoint keys — confirm
# against the code that consumes this table.
RESNET_MAPPING = {
    f"layer{stage}.{block}": f"body.{first_index + block}"
    for stage, first_index, depth in ((1, 0, 3), (2, 3, 4), (3, 7, 6), (4, 13, 3))
    for block in range(depth)
}
19
20
def count_parameters(model):
    """Return the number of trainable scalar values in ``model``.

    Iterates over ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is set; parameters with the flag
    cleared are not counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total
23
24
def toogle_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``.

    Args:
        model: Object whose ``parameters()`` yields the parameters to update.
        flag: Value assigned to each parameter's ``requires_grad``
            (``True`` by default, i.e. gradients enabled).

    The parameters are modified in place; nothing is returned.
    """
    for param in model.parameters():
        param.requires_grad = flag
28
29
def stylegan_to_classifier(x, out_size=(224, 224)):
    """Convert a batch of generator images into classifier-ready input.

    Three steps are applied, in this order:

    1. ``0.5 * x + 0.5`` followed by clipping to ``[0, 1]`` (this maps a
       ``[-1, 1]`` image range onto ``[0, 1]``).
    2. Bilinear resize to ``out_size``.
    3. Per-channel standardisation ``(value - mean) / std`` of channels 0-2
       with means ``(0.485, 0.456, 0.406)`` and stds
       ``(0.229, 0.224, 0.225)`` — the ImageNet statistics documented for
       torchvision's pretrained models.

    Args:
        x: Image batch. The indexing below requires the channels on dim 1
            and at least three of them; presumably ``(N, 3, H, W)`` with
            values in ``[-1, 1]`` — confirm against callers.
        out_size: Target spatial size passed to ``F.interpolate``.

    Returns:
        A new tensor holding the resized, standardised images. ``x`` itself
        is not modified.
    """
    # Imported locally: the visible module has no top-level torch imports,
    # so relying on module-level ``torch`` / ``F`` would raise NameError.
    import torch
    import torch.nn.functional as F

    # ``0.5 * x + 0.5`` already allocates a fresh tensor, so the caller's
    # input is safe without an explicit (full-resolution) clone.
    img = torch.clamp(0.5 * x + 0.5, 0, 1)
    # align_corners=False is what the default resolves to for bilinear mode;
    # spelling it out keeps the behaviour explicit.
    img = F.interpolate(img, size=out_size, mode='bilinear', align_corners=False)

    channel_stats = ((0.485, 0.229), (0.456, 0.224), (0.406, 0.225))
    for channel, (mean, std) in enumerate(channel_stats):
        img[:, channel] = (img[:, channel] - mean) / std
    return img
39
40
def get_stylespace_from_w(w, G):
    """Project per-layer latent codes through the generator's modulation layers.

    Every style vector ``w[:, k]`` is passed through the ``conv.modulation``
    layer of the generator layer that consumes it; results are collected
    separately for the convolution layers and for the to-RGB layers.

    Index layout, as implemented below:

    * ``w[:, 0]`` feeds ``G.conv1`` and ``w[:, 1]`` feeds ``G.to_rgb1``.
    * The k-th (0-based) triple ``(G.convs[2k], G.convs[2k + 1], G.to_rgbs[k])``
      reads indices ``2k + 1``, ``2k + 2`` and ``2k + 3``. A to-RGB layer
      therefore shares its index with the first convolution of the next
      triple.

    Args:
        w: Latent codes, indexed as ``w[:, k]``. Presumably a tensor of shape
            ``(batch, n_styles, dim)`` — confirm against callers.
        G: Generator exposing ``conv1``, ``to_rgb1``, ``convs`` and
            ``to_rgbs``, where each layer has a ``conv.modulation`` callable.
            NOTE(review): this attribute layout looks like the rosinality
            StyleGAN2 generator — confirm.

    Returns:
        Tuple ``(style_space, to_rgb_stylespaces)``: two lists with the
        modulation outputs of the convolution layers and of the to-RGB
        layers, each in generator order.
    """
    # Only the modulation layers are evaluated here, so the generator's noise
    # buffers are deliberately not touched.
    style_space = [G.conv1.conv.modulation(w[:, 0])]
    to_rgb_stylespaces = [G.to_rgb1.conv.modulation(w[:, 1])]

    triples = zip(G.convs[::2], G.convs[1::2], G.to_rgbs)
    for block_idx, (conv1, conv2, to_rgb) in enumerate(triples):
        first = 2 * block_idx + 1  # latent index of this triple's first conv
        style_space.append(conv1.conv.modulation(w[:, first]))
        style_space.append(conv2.conv.modulation(w[:, first + 1]))
        to_rgb_stylespaces.append(to_rgb.conv.modulation(w[:, first + 2]))
    return style_space, to_rgb_stylespaces
59
60
def get_stylespace_from_w_hyperinv(w, G):
    """Project per-layer latent codes through ``G.synthesis``'s affine layers.

    Walks the synthesis blocks ``b<res>`` in the order given by
    ``G.synthesis.block_resolutions`` and feeds ``w[:, k]`` through the
    ``affine`` layer of each convolution / to-RGB layer. All affine calls run
    under ``torch.no_grad()``.

    Index layout, as implemented below:

    * First block: ``conv1`` reads ``w[:, 0]``, ``torgb`` reads ``w[:, 1]``
      (``conv0`` is not used for the first block).
    * Every later block: ``conv0``, ``conv1`` and ``torgb`` read three
      consecutive indices, starting at the index the previous block's
      ``torgb`` used — so a ``torgb`` shares its index with the next block's
      ``conv0``.

    Args:
        w: Latent codes, indexed as ``w[:, k]``. Presumably a tensor of shape
            ``(batch, n_styles, dim)`` — confirm against callers.
        G: Generator with a ``synthesis`` attribute exposing
            ``block_resolutions`` and one ``b<res>`` block per resolution.
            NOTE(review): this attribute layout looks like NVIDIA's
            StyleGAN2-ADA synthesis network — confirm.

    Returns:
        Tuple ``(style_space, to_rgb_stylespaces)``: two lists with the
        affine outputs for the convolution layers and for the to-RGB layers,
        each in block order.
    """
    # Imported locally: the visible module has no top-level torch import,
    # so relying on a module-level ``torch`` would raise NameError.
    import torch

    synthesis = G.synthesis
    style_space = []
    to_rgb_stylespaces = []

    with torch.no_grad():
        w_idx = 0
        for res in synthesis.block_resolutions:
            block = getattr(synthesis, f"b{res}")
            # w_idx is 0 only for the first block, which contributes no
            # conv0 style.
            if w_idx != 0:
                style_space.append(block.conv0.affine(w[:, w_idx]))
                w_idx += 1
            style_space.append(block.conv1.affine(w[:, w_idx]))
            w_idx += 1
            # Not advanced after torgb: the next block's conv0 reuses it.
            to_rgb_stylespaces.append(block.torgb.affine(w[:, w_idx]))

    return style_space, to_rgb_stylespaces
85