HairFastGAN (fork) — Embedding.py (118 lines · 5.4 KB)
from collections import defaultdict
2

3
import torch
4
import torch.nn.functional as F
5
import torchvision.transforms as T
6
from torch import nn
7
from torch.utils.data import DataLoader
8

9
from datasets.image_dataset import ImagesDataset, image_collate
10
from models.FeatureStyleEncoder import FSencoder
11
from models.Net import Net, get_segmentation
12
from models.encoder4editing.utils.model_utils import setup_model, get_latents
13
from utils.bicubic import BicubicDownSample
14
from utils.save_utils import save_gen_image, save_latents
15

16

17
class Embedding(nn.Module):
    """Image-embedding module.

    Encodes face images into the latent spaces used downstream:
    W+ latents (via encoder4editing), S/F latents (via the FeatureStyle
    encoder + StyleGAN generator) and BiSeNet segmentation masks.
    """

    def __init__(self, opts, net=None):
        """
        Args:
            opts: options namespace; this class reads `device`, `batch_size`,
                `mixing`, `save_all` and `save_all_dir` from it.
            net: optional pre-built ``Net``; a fresh one is created when omitted.
        """
        super().__init__()
        self.opts = opts
        self.net = Net(self.opts) if net is None else net

        self.encoder = FSencoder.get_trainer(self.opts.device)
        self.e4e, _ = setup_model('pretrained_models/encoder4editing/e4e_ffhq_encode.pt', self.opts.device)

        # Maps [0, 1] images to [-1, 1], the range StyleGAN-based models expect.
        self.normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # ImageNet statistics for the BiSeNet segmentation network.
        self.to_bisenet = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

        # Factors 2 and 4 assume 1024x1024 inputs -- TODO confirm with callers.
        self.downsample_512 = BicubicDownSample(factor=2)
        self.downsample_256 = BicubicDownSample(factor=4)

    def setup_dataloader(self, images: dict[torch.Tensor, list[str]] | list[torch.Tensor], batch_size=None):
        """Build `self.dataset` / `self.dataloader` over `images`.

        `images` is either a plain list of image tensors or a mapping from
        image tensor to the list of role names that image plays.
        """
        self.dataset = ImagesDataset(images)
        self.dataloader = DataLoader(self.dataset, collate_fn=image_collate, shuffle=False,
                                     batch_size=batch_size or self.opts.batch_size)

    @torch.inference_mode()
    def get_e4e_embed(self, images: list[torch.Tensor]) -> dict[str, torch.Tensor]:
        """Encode `images` with e4e and return their W+ and F latents.

        The whole list is loaded as a single batch (`batch_size=len(images)`),
        so the loop body runs once and returns immediately.
        """
        device = self.opts.device
        self.setup_dataloader(images, batch_size=len(images))

        for image, _ in self.dataloader:
            image = image.to(device)
            latent_W = get_latents(self.e4e, image)
            # Run the first synthesis layers (0..3) to obtain the F-space tensor.
            latent_F, _ = self.net.generator([latent_W], input_is_latent=True, return_latents=False,
                                             start_layer=0, end_layer=3)
            return {"F": latent_F, "W": latent_W}

    @torch.inference_mode()
    def embedding_images(self, images_to_name: dict[torch.Tensor, list[str]], **kwargs) -> dict[
        str, dict[str, torch.Tensor]]:
        """Embed every image and return per-name latents and masks.

        Args:
            images_to_name: mapping from image tensor to the list of names
                (roles) under which its embeddings should be stored.
            **kwargs: optional `exp_name` used as a sub-directory when
                `opts.save_all` is enabled.

        Returns:
            name -> {'W', 'F', 'S', 'mask', 'image_256', 'image_norm_256'},
            each value a batch-of-one tensor.
        """
        device = self.opts.device
        self.setup_dataloader(images_to_name)

        name_to_embed = defaultdict(dict)
        for image, names in self.dataloader:
            image = image.to(device)

            im_512 = self.downsample_512(image)
            im_256 = self.downsample_256(image)
            im_256_norm = self.normalize(im_256)

            # E4E: W+ latents from the normalized 256px images.
            latent_W = get_latents(self.e4e, im_256_norm)

            # FS encoder: S latents plus an intermediate feature map.
            output = self.encoder.test(img=self.normalize(image), return_latent=True)
            latent = output.pop()  # [bs, 512, 16, 16]
            latent_S = output.pop()  # [bs, 18, 512]

            latent_F, _ = self.net.generator([latent_S], input_is_latent=True, return_latents=False,
                                             start_layer=3, end_layer=3, layer_in=latent)  # [bs, 512, 32, 32]

            # BiSeNet segmentation, one image at a time.
            masks = torch.cat([get_segmentation(im.unsqueeze(0)) for im in self.to_bisenet(im_512)])

            # Mixing if we change the color or shape (more than one input image):
            # blend hair regions of F towards the W+-derived F.
            if len(images_to_name) > 1:
                hair_mask = torch.where(masks == 13, torch.ones_like(masks, device=device),
                                        torch.zeros_like(masks, device=device))
                hair_mask = F.interpolate(hair_mask.float(), size=(32, 32), mode='bicubic')

                latent_F_from_W = self.net.generator([latent_W], input_is_latent=True, return_latents=False,
                                                     start_layer=0, end_layer=3)[0]
                latent_F = latent_F + self.opts.mixing * hair_mask * (latent_F_from_W - latent_F)

            # BUGFIX: the inner loop used to rebind `names`, so the save-all
            # branch below zipped the last image's name list against the whole
            # batch. Use a distinct variable for the per-image name list.
            for k, image_names in enumerate(names):
                for name in image_names:
                    name_to_embed[name]['W'] = latent_W[k].unsqueeze(0)
                    name_to_embed[name]['F'] = latent_F[k].unsqueeze(0)
                    name_to_embed[name]['S'] = latent_S[k].unsqueeze(0)
                    name_to_embed[name]['mask'] = masks[k].unsqueeze(0)
                    name_to_embed[name]['image_256'] = im_256[k].unsqueeze(0)
                    name_to_embed[name]['image_norm_256'] = im_256_norm[k].unsqueeze(0)

            if self.opts.save_all:
                gen_W_im, _ = self.net.generator([latent_W], input_is_latent=True, return_latents=False)
                gen_FS_im, _ = self.net.generator([latent_S], input_is_latent=True, return_latents=False,
                                                  start_layer=4, end_layer=8, layer_in=latent_F)

                exp_name = kwargs.get('exp_name') or ""
                output_dir = self.opts.save_all_dir / exp_name
                # Save each generated image under every name the source
                # image is registered with (names[k] is a list).
                for image_names, im_W, lat_W in zip(names, gen_W_im, latent_W):
                    for name in image_names:
                        save_gen_image(output_dir, 'W+', f'{name}.png', im_W)
                        save_latents(output_dir, 'W+', f'{name}.npz', latent_W=lat_W)

                for image_names, im_F, lat_S, lat_F in zip(names, gen_FS_im, latent_S, latent_F):
                    for name in image_names:
                        save_gen_image(output_dir, 'FS', f'{name}.png', im_F)
                        save_latents(output_dir, 'FS', f'{name}.npz', latent_S=lat_S, latent_F=lat_F)

        return name_to_embed
119

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.