HairFastGAN
from collections import defaultdict

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader

from datasets.image_dataset import ImagesDataset, image_collate
from models.FeatureStyleEncoder import FSencoder
from models.Net import Net, get_segmentation
from models.encoder4editing.utils.model_utils import setup_model, get_latents
from utils.bicubic import BicubicDownSample
from utils.save_utils import save_gen_image, save_latents


class Embedding(nn.Module):
    """
    Module for image embedding: inverts images into W+ latents (via e4e) and
    FS-space latents (via the FS encoder), and computes BiSeNet segmentation masks.
    """

    def __init__(self, opts, net=None):
        super().__init__()
        self.opts = opts
        if net is None:
            self.net = Net(self.opts)
        else:
            self.net = net

        # Encoders: FS encoder for (S, F) inversion, e4e for W+ inversion
        self.encoder = FSencoder.get_trainer(self.opts.device)
        self.e4e, _ = setup_model('pretrained_models/encoder4editing/e4e_ffhq_encode.pt', self.opts.device)

        # Normalization to [-1, 1] for the encoders and ImageNet statistics for BiSeNet
        self.normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        self.to_bisenet = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

        # Bicubic downsampling of the full-resolution input to 512 and 256
        self.downsample_512 = BicubicDownSample(factor=2)
        self.downsample_256 = BicubicDownSample(factor=4)

    def setup_dataloader(self, images: dict[torch.Tensor, list[str]] | list[torch.Tensor], batch_size=None):
        self.dataset = ImagesDataset(images)
        self.dataloader = DataLoader(self.dataset, collate_fn=image_collate, shuffle=False,
                                     batch_size=batch_size or self.opts.batch_size)

    @torch.inference_mode()
    def get_e4e_embed(self, images: list[torch.Tensor]) -> dict[str, torch.Tensor]:
        """Encode the images with e4e and return the W+ latents and the F tensor from generator layers 0-3."""
        device = self.opts.device
        self.setup_dataloader(images, batch_size=len(images))

        for image, _ in self.dataloader:
            image = image.to(device)
            latent_W = get_latents(self.e4e, image)
            latent_F, _ = self.net.generator([latent_W], input_is_latent=True, return_latents=False,
                                             start_layer=0, end_layer=3)
            return {"F": latent_F, "W": latent_W}

    @torch.inference_mode()
    def embedding_images(self, images_to_name: dict[torch.Tensor, list[str]], **kwargs) -> dict[
            str, dict[str, torch.Tensor]]:
        """
        Embed every image and return a mapping
        name -> {'W', 'F', 'S', 'mask', 'image_256', 'image_norm_256'}.
        If opts.save_all is set, the W+ and FS reconstructions and latents are also saved to disk.
        """
        device = self.opts.device
        self.setup_dataloader(images_to_name)

        name_to_embed = defaultdict(dict)
        for image, names in self.dataloader:
            image = image.to(device)

            im_512 = self.downsample_512(image)
            im_256 = self.downsample_256(image)
            im_256_norm = self.normalize(im_256)
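
            # Each downstream model gets its own resolution: e4e takes the normalized 256px image,
            # the FS encoder the full-resolution image, and BiSeNet the 512px image.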
            # E4E
            latent_W = get_latents(self.e4e, im_256_norm)

            # FS encoder
            output = self.encoder.test(img=self.normalize(image), return_latent=True)
            latent = output.pop()  # [bs, 512, 16, 16]
            latent_S = output.pop()  # [bs, 18, 512]

            latent_F, _ = self.net.generator([latent_S], input_is_latent=True, return_latents=False,
                                             start_layer=3, end_layer=3, layer_in=latent)  # [bs, 512, 32, 32]

            # BiSeNet
            masks = torch.cat([get_segmentation(image.unsqueeze(0)) for image in self.to_bisenet(im_512)])

            # Mixing if we change the color or shape
            if len(images_to_name) > 1:
                # Binary hair mask (segmentation class 13) resized to the 32x32 resolution of latent_F
                hair_mask = torch.where(masks == 13, torch.ones_like(masks, device=device),
                                        torch.zeros_like(masks, device=device))
                hair_mask = F.interpolate(hair_mask.float(), size=(32, 32), mode='bicubic')

                # Blend the hair region of latent_F toward the F tensor generated from the W+ latent
                latent_F_from_W = self.net.generator([latent_W], input_is_latent=True, return_latents=False,
                                                     start_layer=0, end_layer=3)[0]
                latent_F = latent_F + self.opts.mixing * hair_mask * (latent_F_from_W - latent_F)
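
            # Collect the per-image embeddings under every name (role) that refers to this image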
            for k, names in enumerate(names):
                for name in names:
                    name_to_embed[name]['W'] = latent_W[k].unsqueeze(0)
                    name_to_embed[name]['F'] = latent_F[k].unsqueeze(0)
                    name_to_embed[name]['S'] = latent_S[k].unsqueeze(0)
                    name_to_embed[name]['mask'] = masks[k].unsqueeze(0)
                    name_to_embed[name]['image_256'] = im_256[k].unsqueeze(0)
                    name_to_embed[name]['image_norm_256'] = im_256_norm[k].unsqueeze(0)

            if self.opts.save_all:
                gen_W_im, _ = self.net.generator([latent_W], input_is_latent=True, return_latents=False)
                gen_FS_im, _ = self.net.generator([latent_S], input_is_latent=True, return_latents=False,
                                                  start_layer=4, end_layer=8, layer_in=latent_F)

                exp_name = exp_name if (exp_name := kwargs.get('exp_name')) is not None else ""
                output_dir = self.opts.save_all_dir / exp_name
                for name, im_W, lat_W in zip(names, gen_W_im, latent_W):
                    save_gen_image(output_dir, 'W+', f'{name}.png', im_W)
                    save_latents(output_dir, 'W+', f'{name}.npz', latent_W=lat_W)

                for name, im_F, lat_S, lat_F in zip(names, gen_FS_im, latent_S, latent_F):
                    save_gen_image(output_dir, 'FS', f'{name}.png', im_F)
                    save_latents(output_dir, 'FS', f'{name}.npz', latent_S=lat_S, latent_F=lat_F)

        return name_to_embed
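
# A minimal usage sketch (not part of the original file): it shows how Embedding is
# typically driven and relies only on attributes accessed above; the `opts` fields,
# image variables and role names below are assumptions, not the project's exact API.
#
#   embedding = Embedding(opts)
#   # map each image tensor to the list of roles it plays, e.g. face / shape / color
#   images_to_name = {face_img: ['face'], shape_img: ['shape'], color_img: ['color']}
#   name_to_embed = embedding.embedding_images(images_to_name, exp_name='demo')
#   latent_W = name_to_embed['face']['W']  # W+ latent from e4e
#   latent_F = name_to_embed['face']['F']  # [1, 512, 32, 32] F tensor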