HairFastGAN / hair_swap.py

import argparse
import typing as tp
from collections import defaultdict
from functools import wraps
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.io import read_image, ImageReadMode

from models.Alignment import Alignment
from models.Blending import Blending
from models.Embedding import Embedding
from models.Net import Net
from utils.image_utils import equal_replacer
from utils.seed import seed_setter
from utils.shape_predictor import align_face
from utils.time import bench_session

TImage = tp.TypeVar('TImage', torch.Tensor, Image.Image, np.ndarray)
TPath = tp.TypeVar('TPath', Path, str)
TReturn = tp.TypeVar('TReturn', torch.Tensor, tuple[torch.Tensor, ...])
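# TImage / TPath constrain the accepted swap() inputs: tensors, PIL images,
# numpy arrays, or filesystem paths (Path / str).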


class HairFast:
    """
    HairFast implementation with hairstyle transfer interface
    """

    def __init__(self, args):
        self.args = args
        self.net = Net(self.args)
        self.embed = Embedding(args, net=self.net)
        self.align = Alignment(args, self.embed.get_e4e_embed, net=self.net)
        self.blend = Blending(args, net=self.net)

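    # The Embedding, Alignment and Blending modules share the single Net instance created in
    # __init__ and implement the three stages of the pipeline below.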
    @seed_setter
    @bench_session
    def __swap_from_tensors(self, face: torch.Tensor, shape: torch.Tensor, color: torch.Tensor,
                            **kwargs) -> torch.Tensor:
        images_to_name = defaultdict(list)
        for image, name in zip((face, shape, color), ('face', 'shape', 'color')):
            images_to_name[image].append(name)

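        # Identical tensors collapse to a single key above, so each unique image is embedded only
        # once and is tagged with every role ('face', 'shape', 'color') it plays.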
        # Embedding stage
        name_to_embed = self.embed.embedding_images(images_to_name, **kwargs)

        # Alignment stage
        align_shape = self.align.align_images('face', 'shape', name_to_embed, **kwargs)

        # Shape Module stage for blending
        if shape is not color:
            align_color = self.align.shape_module('face', 'color', name_to_embed, **kwargs)
        else:
            align_color = align_shape

        # Blending and Post Process stage
        final_image = self.blend.blend_images(align_shape, align_color, name_to_embed, **kwargs)
        return final_image

    def swap(self, face_img: TImage | TPath, shape_img: TImage | TPath, color_img: TImage | TPath,
             benchmark=False, align=False, seed=None, exp_name=None, **kwargs) -> TReturn:
        """
        Run HairFast on the input images to transfer hair shape and color to the desired image.

        :param face_img: face image as a Tensor, PIL Image, numpy array or file path
        :param shape_img: shape image as a Tensor, PIL Image, numpy array or file path
        :param color_img: color image as a Tensor, PIL Image, numpy array or file path
        :param benchmark: enables timing of the session
        :param align: crops arbitrary photos to the face region
        :param seed: fixes the seed for reproducibility, default 3407
        :param exp_name: used as the folder name when 'save_all' mode is enabled
        :return: the final image as a Tensor
        """
        images: list[torch.Tensor] = []
        path_to_images: dict[TPath, torch.Tensor] = {}

        for img in (face_img, shape_img, color_img):
            if isinstance(img, (torch.Tensor, Image.Image, np.ndarray)):
                if not isinstance(img, torch.Tensor):
                    img = F.to_tensor(img)
            elif isinstance(img, (Path, str)):
                path_img = img
                if path_img not in path_to_images:
                    path_to_images[path_img] = read_image(str(path_img), mode=ImageReadMode.RGB)
                img = path_to_images[path_img]
            else:
                raise TypeError(f'Unsupported image format {type(img)}')

            images.append(img)

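        # align_face crops arbitrary photos to the face region; equal_replacer is assumed to
        # replace equal images with a single shared tensor so duplicates are processed once.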
        if align:
            images = align_face(images)
        images = equal_replacer(images)

        final_image = self.__swap_from_tensors(*images, seed=seed, benchmark=benchmark, exp_name=exp_name, **kwargs)

        if align:
            return final_image, *images
        return final_image

    @wraps(swap)
    def __call__(self, *args, **kwargs):
        return self.swap(*args, **kwargs)


def get_parser():
    parser = argparse.ArgumentParser(description='HairFast')

    # I/O arguments
    parser.add_argument('--save_all_dir', type=Path, default=Path('output'),
                        help='the directory to save the latent codes and inversion images')

    # StyleGAN2 setting
    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--ckpt', type=str, default="pretrained_models/StyleGAN/ffhq.pt")
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--latent', type=int, default=512)
    parser.add_argument('--n_mlp', type=int, default=8)

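    # The defaults above correspond to the standard 1024x1024 FFHQ StyleGAN2 generator
    # (512-dimensional latent space, 8-layer mapping network).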
    # Arguments
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=3, help='batch size for encoding images')
    parser.add_argument('--save_all', action='store_true', help='save and print mode information')

    # HairFast setting
    parser.add_argument('--mixing', type=float, default=0.95, help='hair blending in alignment')
    parser.add_argument('--smooth', type=int, default=5, help='dilation and erosion parameter')
    parser.add_argument('--rotate_checkpoint', type=str, default='pretrained_models/Rotate/rotate_best.pth')
    parser.add_argument('--blending_checkpoint', type=str, default='pretrained_models/Blending/checkpoint.pth')
    parser.add_argument('--pp_checkpoint', type=str, default='pretrained_models/PostProcess/pp_model.pth')
    return parser


if __name__ == '__main__':
    model_args = get_parser()
    args = model_args.parse_args()
    hair_fast = HairFast(args)
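    # Example usage sketch (the image paths below are placeholders, not files shipped with the repo):
    #   final_image, face, shape, color = hair_fast.swap('face.png', 'shape.png', 'color.png', align=True)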