StyleFeatureEditor
32 строки · 1.1 Кб
1import torch
2
3
def edit(latents, pca, edit_directions):
    """Apply every edit direction to every latent and stack the results.

    Each (latent, direction) pair produces its own edited latent, so the
    output holds ``len(latents) * len(edit_directions)`` entries, ordered by
    latent first and by direction second.

    Args:
        latents: Iterable of latent tensors. Presumably each one is a 2-D
            (layers, dim) code, because rows ``start:end`` are sliced along
            the first axis -- confirm against the caller.
        pca: Mapping holding the "mean", "comp" and "std" tensors; it is
            forwarded unchanged to ``get_delta``.
        edit_directions: Iterable of ``(pca_idx, start, end, strength)``
            tuples. The shift computed for component ``pca_idx`` is added to
            rows ``start:end`` of the latent.

    Returns:
        A tensor stacking all edited latents along a new leading dimension.

    Raises:
        RuntimeError: Raised by ``torch.stack`` when no edited latent was
            produced (empty ``latents`` or empty ``edit_directions``).
    """
    # NOTE(review): the listing this was recovered from had lost its
    # indentation; the append is assumed to belong to the inner loop so that
    # every direction yields an output -- confirm against the upstream file.
    edit_latents = []
    for latent in latents:
        for pca_idx, start, end, strength in edit_directions:
            delta = get_delta(pca, latent, pca_idx, strength)
            # Allocate next to the latent instead of on a hard-coded "cuda"
            # device, so CPU and non-default-GPU latents are handled too.
            delta_padded = torch.zeros(latent.shape, device=latent.device)
            # Only rows start:end receive the shift; the rest stay zero.
            delta_padded[start:end] += delta.repeat(end - start, 1)
            edit_latents.append(latent + delta_padded)
    return torch.stack(edit_latents)
13
14
def get_delta(pca, latent, idx, strength):
    """Return the shift that moves one PCA coordinate of ``latent`` to ``strength``.

    The first row of the mean-centered latent is projected onto component
    ``idx`` and divided by that component's standard deviation, giving the
    current coordinate. The returned delta is the component direction scaled
    by ``(strength - coordinate) * std``.

    Args:
        pca: Mapping with tensors under the keys "mean", "comp" and "std".
            "comp" and "std" are indexed by ``idx`` on their first axis.
        latent: Latent tensor; only ``latent[0]`` takes part in the
            projection. Presumably a (layers, dim) code -- confirm against
            the caller.
        idx: Index of the PCA component to edit.
        strength: Target value for the normalized coordinate along the
            component.

    Returns:
        A tensor shaped like ``pca["comp"][idx]``, placed on the device of
        ``latent``.
    """
    # Follow the latent's device instead of forcing "cuda": the result is the
    # same for CUDA latents and the function now also works on CPU.
    device = latent.device
    w_centered = latent - pca["mean"].to(device)
    lat_comp = pca["comp"].to(device)
    lat_std = pca["std"].to(device)
    # Current coordinate along the component, in units of its std.
    w_coord = (
        torch.sum(w_centered[0].reshape(-1) * lat_comp[idx].reshape(-1))
        / lat_std[idx]
    )
    return (strength - w_coord) * lat_comp[idx] * lat_std[idx]
24
25
def edit_latent(latent, pca, edit_direction):
    """Apply a single edit direction to a single latent.

    Args:
        latent: Latent tensor to edit. Presumably a 2-D (layers, dim) code,
            because rows ``start:end`` are sliced along the first axis --
            confirm against the caller.
        pca: Mapping holding the "mean", "comp" and "std" tensors; it is
            forwarded unchanged to ``get_delta``.
        edit_direction: A ``(pca_idx, start, end, strength)`` tuple. The
            shift computed for component ``pca_idx`` is added to rows
            ``start:end`` of the latent.

    Returns:
        A new tensor equal to ``latent`` plus the shift on rows
        ``start:end``; ``latent`` itself is not modified.
    """
    pca_idx, start, end, strength = edit_direction
    delta = get_delta(pca, latent, pca_idx, strength)
    # Allocate next to the latent instead of on a hard-coded "cuda" device,
    # so CPU and non-default-GPU latents are handled too.
    delta_padded = torch.zeros(latent.shape, device=latent.device)
    # Only rows start:end receive the shift; the rest stay zero.
    delta_padded[start:end] += delta.repeat(end - start, 1)
    # Returned directly: a local named ``edit_latent`` would shadow this
    # function's own name.
    return latent + delta_padded
33