StyleFeatureEditor

model_utils.py 
import torch
import torch.nn.functional as F

# Mapping from torchvision ResNet block names to the sequential "body.*" indices
# used by the encoder trunk when loading pretrained ResNet weights.
RESNET_MAPPING = {
    "layer1.0": "body.0",
    "layer1.1": "body.1",
    "layer1.2": "body.2",
    "layer2.0": "body.3",
    "layer2.1": "body.4",
    "layer2.2": "body.5",
    "layer2.3": "body.6",
    "layer3.0": "body.7",
    "layer3.1": "body.8",
    "layer3.2": "body.9",
    "layer3.3": "body.10",
    "layer3.4": "body.11",
    "layer3.5": "body.12",
    "layer4.0": "body.13",
    "layer4.1": "body.14",
    "layer4.2": "body.15",
}
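

# Illustrative sketch (not part of the original module): RESNET_MAPPING is the kind of
# key-remapping table applied when copying torchvision ResNet weights into an encoder
# whose residual trunk is registered as "body.*". The helper below is hypothetical and
# only shows one way such a remapping could be written.
def _remap_resnet_keys_example(resnet_state_dict):
    remapped = {}
    for key, value in resnet_state_dict.items():
        for torchvision_prefix, body_prefix in RESNET_MAPPING.items():
            if key.startswith(torchvision_prefix):
                # Rewrite only the leading layer prefix, e.g. "layer1.0.conv1.weight"
                # becomes "body.0.conv1.weight".
                key = key.replace(torchvision_prefix, body_prefix, 1)
                break
        remapped[key] = value
    return remapped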


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def toogle_grad(model, flag=True):
    """Enable (flag=True) or disable (flag=False) gradients for all parameters of `model`."""
    for p in model.parameters():
        p.requires_grad = flag
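

# Illustrative sketch (not part of the original module): a typical use of toogle_grad
# is freezing one network while training another. `encoder` and `generator` here are
# hypothetical nn.Module arguments, not objects defined in this file.
def _freeze_generator_example(encoder, generator):
    toogle_grad(generator, False)  # keep StyleGAN weights fixed
    toogle_grad(encoder, True)     # train only the encoder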


def stylegan_to_classifier(x, out_size=(224, 224)):
    """Convert generator output in [-1, 1] to an ImageNet-normalized classifier input."""
    img_tmp = x.clone()
    # Map from [-1, 1] to [0, 1] and clamp.
    img_tmp = torch.clamp(0.5 * img_tmp + 0.5, 0, 1)
    # Resize to the classifier resolution.
    img_tmp = F.interpolate(img_tmp, size=out_size, mode="bilinear")
    # Normalize with ImageNet per-channel mean and std.
    img_tmp[:, 0] = (img_tmp[:, 0] - 0.485) / 0.229
    img_tmp[:, 1] = (img_tmp[:, 1] - 0.456) / 0.224
    img_tmp[:, 2] = (img_tmp[:, 2] - 0.406) / 0.225
    return img_tmp
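

# Illustrative sketch (not part of the original module): feeding generator output to an
# ImageNet classifier. `generator_output` is assumed to be an NCHW tensor in [-1, 1];
# `classifier` is any model expecting ImageNet-normalized 224x224 input.
def _classify_stylegan_output_example(generator_output, classifier):
    inputs = stylegan_to_classifier(generator_output, out_size=(224, 224))
    with torch.no_grad():
        logits = classifier(inputs)
    return logits.argmax(dim=1)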


def get_stylespace_from_w(w, G):
    """Extract StyleSpace (S) vectors from W+ latents for a rosinality-style StyleGAN2 generator."""
    style_space = []
    to_rgb_stylespaces = []

    noise = [getattr(G.noises, "noise_{}".format(i)) for i in range(G.num_layers)]
    latent = w
    # The first convolution and its to_rgb layer have dedicated modulation layers.
    style_space.append(G.conv1.conv.modulation(latent[:, 0]))
    to_rgb_stylespaces.append(G.to_rgb1.conv.modulation(latent[:, 1]))

    i = 1
    for conv1, conv2, noise1, noise2, to_rgb in zip(
        G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs
    ):
        style_space.append(conv1.conv.modulation(latent[:, i]))
        style_space.append(conv2.conv.modulation(latent[:, i + 1]))
        to_rgb_stylespaces.append(to_rgb.conv.modulation(latent[:, i + 2]))
        i += 2
    return style_space, to_rgb_stylespaces
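

# Illustrative sketch (not part of the original module): extracting StyleSpace codes from
# a W+ latent. Shapes are assumptions: for a 1024x1024 rosinality-style generator `G`,
# `w_plus` would have shape (N, 18, 512), and each returned entry is an (N, C) style vector
# for one convolution or to_rgb layer.
def _extract_stylespace_example(w_plus, G):
    style_space, to_rgb_styles = get_stylespace_from_w(w_plus, G)
    per_layer_dims = [s.shape[1] for s in style_space]
    return style_space, to_rgb_styles, per_layer_dims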


def get_stylespace_from_w_hyperinv(w, G):
    """Extract StyleSpace vectors from W+ latents for a StyleGAN2-ADA-style generator."""
    with torch.no_grad():
        style_space = []
        to_rgb_stylespaces = []
        G = G.synthesis

        # Split w into per-block chunks, mirroring SynthesisNetwork.forward.
        block_ws = []
        w_idx = 0
        for res in G.block_resolutions:
            block = getattr(G, f"b{res}")
            block_ws.append(w.narrow(1, w_idx, block.num_conv + block.num_torgb))
            w_idx += block.num_conv

        i = 0
        for res, cur_ws in zip(G.block_resolutions, block_ws):
            block = getattr(G, f"b{res}")
            # The lowest-resolution block has no conv0.
            if i != 0:
                style_space.append(block.conv0.affine(w[:, i]))
                i += 1
            style_space.append(block.conv1.affine(w[:, i]))
            i += 1
            to_rgb_stylespaces.append(block.torgb.affine(w[:, i]))

    return style_space, to_rgb_stylespaces
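

# Illustrative sketch (not part of the original module): both extraction functions return
# lists of per-layer (N, C_i) style vectors; concatenating them yields a single flat
# StyleSpace code, which can be convenient for editing or computing statistics.
def _flatten_stylespace_example(style_space, to_rgb_stylespaces):
    return torch.cat(list(style_space) + list(to_rgb_stylespaces), dim=1)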
