StyleFeatureEditor
289 lines · 10.5 KB
1import os
2import sys
3import torch
4import pickle
5import argparse
6import numpy as np
7
8from editings import ganspace
9from editings.styleclip.mapper.styleclip_mapper import StyleCLIPMapper
10from editings.styleclip.mapper.gloabl_mapper import StyleCLIPGlobalDirection
11from editings.deltaedit.editor import DeltaEditor
12
13
# Maps each of the 26 StyleSpace entries (style + tRGB slots) to the index of
# the generator layer it belongs to; used by get_stylespace_edits to locate
# the tensor that a (layer, channel) direction should offset.
STYLESPACE_IDX = [
    0, 1, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8,
    8, 9, 10, 10, 11, 12, 12, 13, 14, 14, 15, 16, 16,
]
42
43
class LatentEditor:
    """Apply semantic edits to StyleGAN latent codes.

    Wraps several editing back-ends behind a uniform ``get_*_edits`` API:
    InterfaceGAN and GANSpace (W/W+ space), StyleCLIP (latent mapper and
    global S-space direction), raw StyleSpace channel offsets, DeltaEdit,
    and FS hyperplane boundaries.

    Each ``get_*_edits`` method takes the starting latent(s), a list of edit
    strengths (``factors``) and a direction name, and returns the edited
    latents — either one latent per factor, or (for the S-space methods)
    per-layer tensors batched along dim 0.

    Parameters
    ----------
    domain : str
        ``"human_faces"`` or ``"car"`` — selects the direction catalogues and
        GANSpace PCA basis that are loaded. Any other value leaves the editor
        without direction tables (unchanged from the original behavior).
    """

    def __init__(self, domain="human_faces"):
        self.domain = domain

        if self.domain == "human_faces":
            # InterfaceGAN: one direction tensor per attribute, loaded to GPU.
            self.interfacegan_directions = {
                "age": "editings/interfacegan_directions/age.pt",
                "smile": "editings/interfacegan_directions/smile.pt",
                "rotation": "editings/interfacegan_directions/rotation.pt",
            }
            self.interfacegan_tensors = {
                name: torch.load(path).cuda()
                for name, path in self.interfacegan_directions.items()
            }

            # GANSpace tuples: (pca_component, layer_start, layer_end, strength);
            # the strength slot is overwritten per-factor in get_ganspace_edits.
            self.ganspace_pca = torch.load("editings/ganspace_pca/ffhq_pca.pt")
            self.ganspace_directions = {
                "eye_openness": (54, 7, 8, 5),
                "trimmed_beard": (58, 7, 9, 7),
                "lipstick": (34, 10, 11, 20),
                "face_roundness": (37, 0, 5, 20.0),
                "nose_length": (51, 4, 5, -30.0),
                "eyebrow_thickness": (37, 8, 9, 20.0),
                "head_angle_up": (11, 1, 4, -10.5),
                "displeased": (36, 4, 7, 10.0),
            }

            # StyleCLIP mapper flags per direction:
            # (no_coarse_mapper, no_medium_mapper, no_fine_mapper).
            self.styleclip_directions = {
                "afro": [False, False, True],
                "angry": [False, False, True],
                "beyonce": [False, False, False],
                "bobcut": [False, False, True],
                "bowlcut": [False, False, True],
                "curly_hair": [False, False, True],
                "hilary_clinton": [False, False, False],
                "depp": [False, False, False],
                "mohawk": [False, False, True],
                "purple_hair": [False, False, False],
                "surprised": [False, False, True],
                "taylor_swift": [False, False, False],
                "trump": [False, False, False],
                "zuckerberg": [False, False, False],
            }
            self.styleclip_global_editor = self.load_styleclip_global()

            # StyleSpace: (layer, channel) pairs whose activations are offset.
            self.stylespace_directions = {
                "black hair": [(12, 479)],
                "blond hair": [(12, 479), (12, 266)],
                "grey hair": [(11, 286)],
                "wavy hair": [(6, 500), (8, 128), (5, 92), (6, 394), (6, 323)],
                "bangs": [
                    (3, 259),
                    (6, 285),
                    (5, 414),
                    (6, 128),
                    (9, 295),
                    (6, 322),
                    (6, 487),
                    (6, 504),
                ],
                "receding hairline": [(5, 414), (6, 322), (6, 497), (6, 504)],
                "smiling": [(6, 501)],
                "sslipstick": [(15, 45)],
                "sideburns": [(12, 237)],
                "goatee": [(9, 421)],
                "earrings": [(8, 81)],
                "glasses": [(3, 288), (2, 175), (3, 120), (2, 97)],
                "wear suit": [(9, 441), (8, 292), (11, 358), (6, 223)],
                "gender": [(9, 6)],
            }

            # FS edits: attribute hyperplane boundaries saved as .npy arrays.
            self.fs_directions = {
                "fs_glasses": "editings/bound/Eyeglasses_boundary.npy",
                "fs_smiling": "editings/bound/Smiling_boundary.npy",
                "fs_makeup": "editings/bound/Heavy_Makeup_boundary.npy",
            }

            self.deltaedit_editor = DeltaEditor()

        elif self.domain == "car":
            self.stylespace_directions = {
                "front": [(8, 411)],
                "headlights": [(8, 441), (9, 355)],
                "grill": [(9, 191)],
                "trees": [(9, 108)],
                "grass_ss": [(12, 107)],
                "sky": [(12, 76)],
                "hubcap": [(12, 113), (12, 439)],
                "car color": [(12, 142), (15, 227)],
                "logo": [(9, 185)],
                "wheel angle": [(8, 420)],
            }

            self.ganspace_pca = torch.load("editings/ganspace_pca/cars_pca.pt")
            self.ganspace_directions = {
                "pose_1": (0, 0, 5, 2),
                "pose_2": (0, 0, 5, -2),
                "cube": (16, 3, 6, 25),
                "color": (22, 9, 11, -8),
                "grass": (41, 9, 11, -18),
            }

    def load_styleclip_global(self):
        """Build the StyleCLIP global-direction calculator from shipped data.

        Loads the per-channel ``delta_i_c`` matrix, the S-space per-channel
        std statistics, and the text prompt templates, all from
        ``editings/styleclip/global_mapper_data``.

        Returns
        -------
        StyleCLIPGlobalDirection
        """
        delta_i_c = torch.from_numpy(
            np.load("editings/styleclip/global_mapper_data/delta_i_c.npy")
        ).float().cuda()
        # NOTE: pickle.load on a file shipped with the repo — trusted input.
        with open("editings/styleclip/global_mapper_data/S_mean_std", "rb") as channels_statistics:
            _, s_std = pickle.load(channels_statistics)
        s_std = [torch.from_numpy(s_i).float().cuda() for s_i in s_std]
        with open("editings/styleclip/global_mapper_data/templates.txt", "r") as templates:
            text_prompt_templates = templates.readlines()
        return StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates)

    def get_styleclip_mapper_edits(self, start_w, factors, direction):
        """Edit W+ latents with a pretrained StyleCLIP levels mapper.

        Loads ``pretrained_models/styleclip_mappers/<direction>.pt`` on every
        call, maps ``start_w`` to a direction, and returns
        ``[start_w + f * mapped for f in factors]``.
        """
        mapper_checkpoint_path = os.path.join(
            "pretrained_models/styleclip_mappers",
            f"{direction}.pt",
        )
        ckpt = torch.load(mapper_checkpoint_path, map_location="cpu")
        opts = ckpt["opts"]
        # Override checkpoint opts with the mapper configuration this
        # direction requires (which mapper levels are disabled).
        styleclip_opts = argparse.Namespace(
            **{
                "mapper_type": "LevelsMapper",
                "no_coarse_mapper": self.styleclip_directions[direction][0],
                "no_medium_mapper": self.styleclip_directions[direction][1],
                "no_fine_mapper": self.styleclip_directions[direction][2],
                "stylegan_size": 1024,
                "checkpoint_path": mapper_checkpoint_path,
            }
        )
        opts.update(vars(styleclip_opts))
        opts = argparse.Namespace(**opts)
        style_clip_net = StyleCLIPMapper(opts)
        style_clip_net.eval()
        style_clip_net.cuda()
        # Use a fresh name instead of shadowing the `direction` parameter.
        mapped_direction = style_clip_net.mapper(start_w)
        latents_to_display = []
        for factor in factors:
            latents_to_display.append(start_w + factor * mapped_direction)
        return latents_to_display

    def get_styleclip_global_edits(self, start_s, factors, direction):
        """Edit S-space latents with a StyleCLIP global direction.

        ``direction`` is ``"<neutral>_<target>_<disentanglement>"``;
        ``start_s`` is ``(style_tensors, rgb_tensors)`` (per-layer lists).
        Returns ``(edited_ss, edited_rgb)`` with each layer batched over
        ``len(factors)`` along dim 0.
        """
        neutral_text, target_text, disentanglement = direction.split("_")
        disentanglement = float(disentanglement)

        directions = self.styleclip_global_editor.get_delta_s(
            neutral_text, target_text, disentanglement
        )
        factors = torch.tensor(factors).cuda().view(-1, 1)
        start_ss, start_rgb = start_s

        # Entries 1, 4, 7, ... of the 26 deltas are the tRGB layers; the rest
        # are regular style layers.
        edits_rgb = []
        edits_ss = []
        for i in range(26):
            if i in [1, 4, 7, 10, 13, 16, 19, 22, 25]:
                edits_rgb.append(directions[i].view(1, -1).repeat(len(factors), 1))
            else:
                edits_ss.append(directions[i].view(1, -1).repeat(len(factors), 1))

        edited_rgb = []
        edited_ss = []
        # The /1.5 damping matches the strength scale used elsewhere — TODO confirm.
        for orig, edit in zip(start_ss, edits_ss):
            edited_ss.append(orig.repeat(len(factors), 1) + edit * factors.repeat(1, orig.size(1)) / 1.5)
        for orig, edit in zip(start_rgb, edits_rgb):
            edited_rgb.append(orig.repeat(len(factors), 1) + edit * factors.repeat(1, orig.size(1)) / 1.5)

        return edited_ss, edited_rgb

    def get_deltaedit_edits(self, start_s, factors, direction, original_image):
        """Edit S-space latents with DeltaEdit.

        ``direction`` is ``"<neutral>_<target>_<disentanglement>"``. The RGB
        latents are passed through unchanged (repeated per factor); only the
        style latents are offset by the DeltaEdit deltas.
        """
        neutral_text, target_text, disentanglement = direction.split("_")
        disentanglement = float(disentanglement)

        factors = torch.tensor(factors).cuda().view(-1, 1)
        start_ss, edited_rgb = start_s
        edits_ss = self.deltaedit_editor.get_delta_s(
            neutral_text, target_text, disentanglement, original_image, start_ss
        )

        edited_rgb = [latent.repeat(len(factors), 1) for latent in edited_rgb]
        edited_ss = []
        for orig, edit in zip(start_ss, edits_ss):
            edited_ss.append(orig.repeat(len(factors), 1) + edit * factors.repeat(1, orig.size(1)))

        return edited_ss, edited_rgb

    def get_ganspace_edits(self, start_w, factors, direction):
        """Edit W+ latents along a GANSpace PCA direction.

        The catalogue strength (last tuple slot) is replaced by each factor;
        returns one edited latent per factor.
        """
        # Invariant across the loop — look the direction up once.
        ganspace_direction = self.ganspace_directions[direction]
        latents_to_display = []
        for factor in factors:
            edit_direction = list(ganspace_direction)
            edit_direction[-1] = factor
            edit_direction = tuple(edit_direction)
            new_w = ganspace.edit(start_w, self.ganspace_pca, [edit_direction])
            latents_to_display.append(new_w)
        return latents_to_display

    def get_interface_gan_edits(self, start_w, factors, direction):
        """Edit W+ latents along an InterfaceGAN direction.

        Returns ``[start_w + f / 2 * direction_tensor for f in factors]``;
        the /2 halves the nominal strength scale.
        """
        tensor_direction = self.interfacegan_tensors[direction]
        latents_to_display = []
        for factor in factors:
            latents_to_display.append(start_w + factor / 2 * tensor_direction)
        return latents_to_display

    def get_stylespace_edits(self, start_s, factors, direction):
        """Offset StyleSpace channels listed for ``direction``.

        ``start_s`` is ``(style_tensors, rgb_tensors)``. Each (layer, channel)
        pair in the catalogue is shifted by ``factor * 3`` for every factor;
        the batch dimension enumerates the factors.
        """
        edits = self.stylespace_directions[direction]
        start_stylespaces, start_stylespaces_rgb = start_s
        device = start_stylespaces[0].device

        edited_latent = [
            s.clone().repeat(len(factors), 1)
            for s in start_stylespaces
        ]
        factors = torch.tensor(factors).to(device)
        for layer_num, feat_num in edits:
            # STYLESPACE_IDX translates the catalogue layer number into the
            # index of the corresponding style tensor.
            edited_latent[STYLESPACE_IDX[layer_num]][:, feat_num] += factors * 3
        edited_stylespaces_rgb = [
            rgb.repeat(len(factors), 1) for rgb in start_stylespaces_rgb
        ]

        return edited_latent, edited_stylespaces_rgb

    def get_fs_edits(self, w, factors, direction):
        """Edit W+ latents along an FS attribute boundary.

        Loads the boundary from ``self.fs_directions[direction]``, flattens
        ``w`` to (batch, -1), and returns one edited latent per factor,
        reshaped to (batch, -1, 512) on the original device.
        """
        path = self.fs_directions[direction]
        boundary = np.load(path)
        device = w.device
        bs = w.size(0)
        w_0 = w.cpu().numpy().reshape(bs, -1)
        boundary = boundary.reshape(1, -1).repeat(bs, 0)

        edits = [
            torch.tensor(w_0 + factor * boundary).view(bs, -1, 512).to(device)
            for factor in factors
        ]

        return edits
290
291