HairFastGAN

Форк
0
/
hair_swap.py 
139 строк · 5.6 Кб
1
import argparse
2
import typing as tp
3
from collections import defaultdict
4
from functools import wraps
5
from pathlib import Path
6

7
import numpy as np
8
import torch
9
import torchvision.transforms.functional as F
10
from PIL import Image
11
from torchvision.io import read_image, ImageReadMode
12

13
from models.Alignment import Alignment
14
from models.Blending import Blending
15
from models.Embedding import Embedding
16
from models.Net import Net
17
from utils.image_utils import equal_replacer
18
from utils.seed import seed_setter
19
from utils.shape_predictor import align_face
20
from utils.time import bench_session
21

22
# Type aliases for the inputs/outputs accepted by ``HairFast.swap``.
# These are unions rather than constrained TypeVars: every image argument may independently
# be any of the listed types (a constrained TypeVar shared by several parameters would force
# them all onto the same type), and a TypeVar used only in a return annotation carries no
# information for type checkers.
TImage = tp.Union[torch.Tensor, Image.Image, np.ndarray]
TPath = tp.Union[Path, str]
TReturn = tp.Union[torch.Tensor, tuple[torch.Tensor, ...]]
25

26

27
class HairFast:
    """
    HairFast implementation with a hairstyle-transfer interface.

    A single ``Net`` instance is shared by the three pipeline stages built here:
    embedding, alignment and blending.
    """

    def __init__(self, args):
        """
        :param args: parsed options (as produced by ``get_parser``) passed to every sub-model
        """
        self.args = args
        self.net = Net(self.args)
        self.embed = Embedding(args, net=self.net)
        # the alignment stage receives the embedding stage's ``get_e4e_embed`` callable
        self.align = Alignment(args, self.embed.get_e4e_embed, net=self.net)
        self.blend = Blending(args, net=self.net)

    @seed_setter
    @bench_session
    def __swap_from_tensors(self, face: torch.Tensor, shape: torch.Tensor, color: torch.Tensor,
                            **kwargs) -> torch.Tensor:
        """
        Run the embedding -> alignment -> blending pipeline on already-loaded image tensors.

        :param face:   tensor of the face image that receives the new hair
        :param shape:  tensor of the hair-shape reference image
        :param color:  tensor of the hair-color reference image
        :param kwargs: options forwarded to every stage (``seed``/``benchmark`` are presumably
                       consumed by the ``seed_setter``/``bench_session`` decorators — confirm)
        :return:       the blended final image
        """
        # torch.Tensor hashes by object identity, so one tensor object used for several roles
        # becomes a single key listing all of its role names.
        images_to_name = defaultdict(list)
        for image, name in zip((face, shape, color), ('face', 'shape', 'color')):
            images_to_name[image].append(name)

        # Embedding stage
        name_to_embed = self.embed.embedding_images(images_to_name, **kwargs)

        # Alignment stage
        align_shape = self.align.align_images('face', 'shape', name_to_embed, **kwargs)

        # Shape Module stage for blending; when shape and color are the very same object
        # the shape alignment is reused instead of running the shape module again.
        if shape is not color:
            align_color = self.align.shape_module('face', 'color', name_to_embed, **kwargs)
        else:
            align_color = align_shape

        # Blending and Post Process stage
        final_image = self.blend.blend_images(align_shape, align_color, name_to_embed, **kwargs)
        return final_image

    def swap(self, face_img: TImage | TPath, shape_img: TImage | TPath, color_img: TImage | TPath,
             benchmark=False, align=False, seed=None, exp_name=None, **kwargs) -> TReturn:
        """
        Run HairFast on the input images to transfer hair shape and color to the face image.

        Each image may be a Tensor, PIL Image, numpy array or file path; the three inputs
        may use different formats.

        :param face_img:  face image that receives the new hairstyle
        :param shape_img: image providing the desired hair shape
        :param color_img: image providing the desired hair color
        :param benchmark: starts counting the speed of the session
        :param align:     for arbitrary photos, crops the images to faces (via ``align_face``)
        :param seed:      fixes the seed for reproducibility (documented default: 3407)
        :param exp_name:  used as a folder name when the 'save_all' mode is enabled
        :param kwargs:    passed through to the internal embedding/alignment/blending pipeline
        :return:          the final image as a Tensor; when ``align`` is True, a tuple
                          ``(final_image, face, shape, color)`` where the last three are the
                          preprocessed input tensors (after ``align_face`` and ``equal_replacer``)
        :raises TypeError: if an input is not one of the supported formats
        """
        images: list[torch.Tensor] = []
        # Decoded files, keyed by a normalised Path so 'a.png' and Path('a.png') are read once.
        path_to_images: dict[Path, torch.Tensor] = {}

        for img in (face_img, shape_img, color_img):
            # NOTE(review): inputs are not normalised to a common dtype/range — tensors pass
            # through unchanged, PIL/ndarray inputs go through F.to_tensor (float in [0, 1] for
            # 8-bit data) while files are decoded by read_image (uint8 in [0, 255]).
            # Confirm that the downstream stages handle every variant.
            if isinstance(img, torch.Tensor):
                tensor = img
            elif isinstance(img, (Image.Image, np.ndarray)):
                tensor = F.to_tensor(img)
            elif isinstance(img, (Path, str)):
                key = Path(img)
                if key not in path_to_images:
                    path_to_images[key] = read_image(str(key), mode=ImageReadMode.RGB)
                tensor = path_to_images[key]
            else:
                raise TypeError(f'Unsupported image format {type(img)}')

            images.append(tensor)

        if align:
            images = align_face(images)
        images = equal_replacer(images)

        final_image = self.__swap_from_tensors(*images, seed=seed, benchmark=benchmark, exp_name=exp_name, **kwargs)

        if align:
            return final_image, *images
        return final_image

    @wraps(swap)
    def __call__(self, *args, **kwargs):
        # Calling the instance is an alias for ``swap``; ``wraps`` copies its docstring/signature.
        return self.swap(*args, **kwargs)
106

107

108
def get_parser():
    """
    Build the command-line parser holding every HairFast model option.

    :return: an ``argparse.ArgumentParser``; its parsed namespace is what ``HairFast`` expects
    """
    parser = argparse.ArgumentParser(description='HairFast')

    # (flag, add_argument keyword arguments) in registration order
    option_table = (
        # I/O arguments
        ('--save_all_dir', dict(type=Path, default=Path('output'),
                                help='the directory to save the latent codes and inversion images')),

        # StyleGAN2 setting
        ('--size', dict(type=int, default=1024)),
        ('--ckpt', dict(type=str, default="pretrained_models/StyleGAN/ffhq.pt")),
        ('--channel_multiplier', dict(type=int, default=2)),
        ('--latent', dict(type=int, default=512)),
        ('--n_mlp', dict(type=int, default=8)),

        # Arguments
        ('--device', dict(type=str, default='cuda')),
        ('--batch_size', dict(type=int, default=3, help='batch size for encoding images')),
        ('--save_all', dict(action='store_true', help='save and print mode information')),

        # HairFast setting
        ('--mixing', dict(type=float, default=0.95, help='hair blending in alignment')),
        ('--smooth', dict(type=int, default=5, help='dilation and erosion parameter')),
        ('--rotate_checkpoint', dict(type=str, default='pretrained_models/Rotate/rotate_best.pth')),
        ('--blending_checkpoint', dict(type=str, default='pretrained_models/Blending/checkpoint.pth')),
        ('--pp_checkpoint', dict(type=str, default='pretrained_models/PostProcess/pp_model.pth')),
    )

    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser
134

135

136
if __name__ == '__main__':
    # Script entry point: parse the CLI options and construct the model from them.
    # No swap is performed here; this only instantiates HairFast.
    hair_fast = HairFast(get_parser().parse_args())
140

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.