# StyleFeatureEditor: latent_editor.py

import os
import sys
import torch
import pickle
import argparse
import numpy as np

from editings import ganspace
from editings.styleclip.mapper.styleclip_mapper import StyleCLIPMapper
from editings.styleclip.mapper.gloabl_mapper import StyleCLIPGlobalDirection
from editings.deltaedit.editor import DeltaEditor
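

# Maps a layer index in the StyleSpace-paper numbering (26 style inputs of a
# 1024px StyleGAN2, tRGB layers included) onto that layer's position in the
# conv-only style list edited by get_stylespace_edits() below. This reading is
# an assumption inferred from how the table is indexed there.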
STYLESPACE_IDX = [
    0, 1, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8,
    8, 9, 10, 10, 11, 12, 12, 13, 14, 14, 15, 16, 16,
]


class LatentEditor:
    def __init__(self, domain="human_faces"):

        self.domain = domain

        if self.domain == "human_faces":
            # InterfaceGAN: a single W-space direction tensor per attribute.
            self.interfacegan_directions = {
                "age": "editings/interfacegan_directions/age.pt",
                "smile": "editings/interfacegan_directions/smile.pt",
                "rotation": "editings/interfacegan_directions/rotation.pt",
            }
            self.interfacegan_tensors = {
                name: torch.load(path).cuda()
                for name, path in self.interfacegan_directions.items()
            }

            # GANSpace directions: (pca_component, layer_start, layer_end,
            # strength); the last element is overwritten by the requested
            # factor in get_ganspace_edits().
            self.ganspace_pca = torch.load("editings/ganspace_pca/ffhq_pca.pt")
            self.ganspace_directions = {
                "eye_openness": (54, 7, 8, 5),
                "trimmed_beard": (58, 7, 9, 7),
                "lipstick": (34, 10, 11, 20),
                "face_roundness": (37, 0, 5, 20.0),
                "nose_length": (51, 4, 5, -30.0),
                "eyebrow_thickness": (37, 8, 9, 20.0),
                "head_angle_up": (11, 1, 4, -10.5),
                "displeased": (36, 4, 7, 10.0),
            }

            # StyleCLIP mapper flags per direction: (no_coarse_mapper,
            # no_medium_mapper, no_fine_mapper), consumed in
            # get_styleclip_mapper_edits().
            self.styleclip_directions = {
                "afro": [False, False, True],
                "angry": [False, False, True],
                "beyonce": [False, False, False],
                "bobcut": [False, False, True],
                "bowlcut": [False, False, True],
                "curly_hair": [False, False, True],
                "hilary_clinton": [False, False, False],
                "depp": [False, False, False],
                "mohawk": [False, False, True],
                "purple_hair": [False, False, False],
                "surprised": [False, False, True],
                "taylor_swift": [False, False, False],
                "trump": [False, False, False],
                "zuckerberg": [False, False, False],
            }
            self.styleclip_global_editor = self.load_styleclip_global()

            # StyleSpace edits: lists of (layer, channel) pairs in the
            # StyleSpace-paper layer numbering (see STYLESPACE_IDX above).
            self.stylespace_directions = {
                "black hair": [(12, 479)],
                "blond hair": [(12, 479), (12, 266)],
                "grey hair": [(11, 286)],
                "wavy hair": [(6, 500), (8, 128), (5, 92), (6, 394), (6, 323)],
                "bangs": [
                    (3, 259),
                    (6, 285),
                    (5, 414),
                    (6, 128),
                    (9, 295),
                    (6, 322),
                    (6, 487),
                    (6, 504),
                ],
                "receding hairline": [(5, 414), (6, 322), (6, 497), (6, 504)],
                "smiling": [(6, 501)],
                "sslipstick": [(15, 45)],
                "sideburns": [(12, 237)],
                "goatee": [(9, 421)],
                "earrings": [(8, 81)],
                "glasses": [(3, 288), (2, 175), (3, 120), (2, 97)],
                "wear suit": [(9, 441), (8, 292), (11, 358), (6, 223)],
                "gender": [(9, 6)],
            }

            # Feature-space ("fs") edits: attribute boundary hyperplanes
            # stored as .npy files.
            self.fs_directions = {
                "fs_glasses": "editings/bound/Eyeglasses_boundary.npy",
                "fs_smiling": "editings/bound/Smiling_boundary.npy",
                "fs_makeup": "editings/bound/Heavy_Makeup_boundary.npy",
            }

            self.deltaedit_editor = DeltaEditor()

        elif self.domain == "car":

            self.stylespace_directions = {
                "front": [(8, 411)],
                "headlights": [(8, 441), (9, 355)],
                "grill": [(9, 191)],
                "trees": [(9, 108)],
                "grass_ss": [(12, 107)],
                "sky": [(12, 76)],
                "hubcap": [(12, 113), (12, 439)],
                "car color": [(12, 142), (15, 227)],
                "logo": [(9, 185)],
                "wheel angle": [(8, 420)],
            }

            self.ganspace_pca = torch.load("editings/ganspace_pca/cars_pca.pt")
            self.ganspace_directions = {
                "pose_1": (0, 0, 5, 2),
                "pose_2": (0, 0, 5, -2),
                "cube": (16, 3, 6, 25),
                "color": (22, 9, 11, -8),
                "grass": (41, 9, 11, -18),
            }

    def load_styleclip_global(self):
        # Precomputed data for StyleCLIP's global-direction method: delta_i_c
        # relates CLIP-space deltas to StyleSpace channels, and s_std holds
        # the per-channel standard deviations of S.
        delta_i_c = torch.from_numpy(np.load("editings/styleclip/global_mapper_data/delta_i_c.npy")).float().cuda()
        with open("editings/styleclip/global_mapper_data/S_mean_std", "rb") as channels_statistics:
            _, s_std = pickle.load(channels_statistics)
            s_std = [torch.from_numpy(s_i).float().cuda() for s_i in s_std]
        with open("editings/styleclip/global_mapper_data/templates.txt", "r") as templates:
            text_prompt_templates = templates.readlines()
        global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates)
        return global_direction_calculator

    def get_styleclip_mapper_edits(self, start_w, factors, direction):
        latents_to_display = []
        mapper_checkpoint_path = os.path.join(
            "pretrained_models/styleclip_mappers",
            f"{direction}.pt",
        )
        ckpt = torch.load(mapper_checkpoint_path, map_location="cpu")
        opts = ckpt["opts"]
        styleclip_opts = argparse.Namespace(
            **{
                "mapper_type": "LevelsMapper",
                "no_coarse_mapper": self.styleclip_directions[direction][0],
                "no_medium_mapper": self.styleclip_directions[direction][1],
                "no_fine_mapper": self.styleclip_directions[direction][2],
                "stylegan_size": 1024,
                "checkpoint_path": mapper_checkpoint_path,
            }
        )
        opts.update(vars(styleclip_opts))
        opts = argparse.Namespace(**opts)
        style_clip_net = StyleCLIPMapper(opts)
        style_clip_net.eval()
        style_clip_net.cuda()
        # The mapper predicts a per-latent edit direction; each factor scales it.
        mapper_direction = style_clip_net.mapper(start_w)
        for factor in factors:
            edited_latent = start_w + factor * mapper_direction
            latents_to_display.append(edited_latent)

        return latents_to_display
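
    # Note on the S-space methods below: start_s is a (conv_styles, rgb_styles)
    # pair of per-layer style tensor lists, as produced by the encoder. The
    # text-driven editors additionally pack their prompts into the direction
    # string as "<neutral>_<target>_<beta>", e.g. "face_face with glasses_0.14"
    # (a made-up example; prompts must not contain underscores, since the name
    # is split on "_").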

    def get_styleclip_global_edits(self, start_s, factors, direction):
        neutral_text, target_text, disentanglement = direction.split("_")
        disentanglement = float(disentanglement)

        directions = self.styleclip_global_editor.get_delta_s(neutral_text, target_text, disentanglement)
        factors = torch.tensor(factors).cuda().view(-1, 1)
        start_ss, start_rgb = start_s

        # get_delta_s returns 26 per-layer directions; indices 1, 4, 7, ...
        # are the interleaved tRGB layers, the rest are conv layers.
        edits_rgb = []
        edits_ss = []
        for i in range(26):
            if i in [1, 4, 7, 10, 13, 16, 19, 22, 25]:
                edits_rgb.append(directions[i].view(1, -1).repeat(len(factors), 1))
            else:
                edits_ss.append(directions[i].view(1, -1).repeat(len(factors), 1))

        edited_rgb = []
        edited_ss = []
        for orig, edit in zip(start_ss, edits_ss):
            edited_ss.append(orig.repeat(len(factors), 1) + edit * factors.repeat(1, orig.size(1)) / 1.5)
        for orig, edit in zip(start_rgb, edits_rgb):
            edited_rgb.append(orig.repeat(len(factors), 1) + edit * factors.repeat(1, orig.size(1)) / 1.5)

        return edited_ss, edited_rgb

    def get_deltaedit_edits(self, start_s, factors, direction, original_image):
        neutral_text, target_text, disentanglement = direction.split("_")
        disentanglement = float(disentanglement)

        factors = torch.tensor(factors).cuda().view(-1, 1)
        start_ss, start_rgb = start_s
        edits_ss = self.deltaedit_editor.get_delta_s(neutral_text, target_text, disentanglement, original_image, start_ss)

        # DeltaEdit only shifts the conv style channels; the tRGB styles are
        # repeated unchanged for every factor.
        edited_rgb = [latent.repeat(len(factors), 1) for latent in start_rgb]
        edited_ss = []
        for orig, edit in zip(start_ss, edits_ss):
            edited_ss.append(orig.repeat(len(factors), 1) + edit * factors.repeat(1, orig.size(1)))

        return edited_ss, edited_rgb

    def get_ganspace_edits(self, start_w, factors, direction):
        latents_to_display = []
        for factor in factors:
            # Replace the stored strength with the requested factor.
            ganspace_direction = self.ganspace_directions[direction]
            edit_direction = list(ganspace_direction)
            edit_direction[-1] = factor
            edit_direction = tuple(edit_direction)

            new_w = ganspace.edit(start_w, self.ganspace_pca, [edit_direction])
            latents_to_display.append(new_w)
        return latents_to_display

    def get_interface_gan_edits(self, start_w, factors, direction):
        latents_to_display = []
        for factor in factors:
            tensor_direction = self.interfacegan_tensors[direction]
            edited_latent = start_w + factor / 2 * tensor_direction
            latents_to_display.append(edited_latent)
        return latents_to_display

    def get_stylespace_edits(self, start_s, factors, direction):
        edits = self.stylespace_directions[direction]
        start_stylespaces, start_stylespaces_rgb = start_s
        device = start_stylespaces[0].device

        # Copy each conv style vector once per factor, then shift only the
        # channels named by the chosen direction.
        edited_latent = [
            s.clone().repeat(len(factors), 1)
            for s in start_stylespaces
        ]
        factors = torch.tensor(factors).to(device)
        for layer_num, feat_num in edits:
            edited_latent[STYLESPACE_IDX[layer_num]][:, feat_num] += factors * 3
        edited_stylespaces_rgb = [
            rgb.repeat(len(factors), 1) for rgb in start_stylespaces_rgb
        ]

        return edited_latent, edited_stylespaces_rgb

    def get_fs_edits(self, w, factors, direction):
        # Move the flattened latent along a precomputed attribute boundary
        # (hyperplane normal), one edited latent per factor.
        path = self.fs_directions[direction]
        boundary = np.load(path)
        device = w.device
        bs = w.size(0)
        w_0 = w.cpu().numpy().reshape(bs, -1)
        boundary = boundary.reshape(1, -1).repeat(bs, 0)

        edits = [torch.tensor(w_0 + factor * boundary).view(bs, -1, 512).to(device) for factor in factors]

        return edits
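

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file. It assumes the
    # pretrained direction files referenced above are present, a CUDA device
    # is available, and latents come from an inverter producing W+ codes of
    # shape (batch, 18, 512) for a 1024px FFHQ StyleGAN2; the random tensor
    # below only stands in for a real inverted latent.
    editor = LatentEditor(domain="human_faces")
    w_plus = torch.randn(1, 18, 512).cuda()

    # InterfaceGAN: one edited W+ latent per factor, moved along "smile".
    smile_latents = editor.get_interface_gan_edits(
        w_plus, factors=[-3, 0, 3], direction="smile"
    )

    # GANSpace: the stored strength is replaced by each requested factor.
    eye_latents = editor.get_ganspace_edits(
        w_plus, factors=[-5, 5], direction="eye_openness"
    )
    print(len(smile_latents), len(eye_latents))  # 3 2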