StyleFeatureEditor

Форк
0
32 строки · 1.1 Кб
1
import torch
2

3

4
def edit(latents, pca, edit_directions):
5
    edit_latents = []
6
    for latent in latents:
7
        for pca_idx, start, end, strength in edit_directions:
8
            delta = get_delta(pca, latent, pca_idx, strength)
9
            delta_padded = torch.zeros(latent.shape).to("cuda")
10
            delta_padded[start:end] += delta.repeat(end - start, 1)
11
            edit_latents.append(latent + delta_padded)
12
    return torch.stack(edit_latents)
13

14

15
def get_delta(pca, latent, idx, strength):
16
    w_centered = latent - pca["mean"].to("cuda")
17
    lat_comp = pca["comp"].to("cuda")
18
    lat_std = pca["std"].to("cuda")
19
    w_coord = (
20
        torch.sum(w_centered[0].reshape(-1) * lat_comp[idx].reshape(-1)) / lat_std[idx]
21
    )
22
    delta = (strength - w_coord) * lat_comp[idx] * lat_std[idx]
23
    return delta
24

25

26
def edit_latent(latent, pca, edit_direction):
27
    pca_idx, start, end, strength = edit_direction
28
    delta = get_delta(pca, latent, pca_idx, strength)
29
    delta_padded = torch.zeros(latent.shape).to("cuda")
30
    delta_padded[start:end] += delta.repeat(end - start, 1)
31
    edit_latent = latent + delta_padded
32
    return edit_latent
33

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.