# BasicSR / inference_dfdnet.py (forks: 0; 197 lines, 8.6 KB)
import argparse
2
import glob
3
import numpy as np
4
import os
5
import torch
6
import torchvision.transforms as transforms
7
from skimage import io
8

9
from basicsr.archs.dfdnet_arch import DFDNet
10
from basicsr.utils import imwrite, tensor2img
11

12
try:
13
    from facexlib.utils.face_restoration_helper import FaceRestoreHelper
14
except ImportError:
15
    print('Please install facexlib: pip install facexlib')
16

17
# TODO: update this script to match the revised FaceRestoreHelper API.
18

19

20
def get_part_location(landmarks, min_half_len=16):
    """Get square bounding boxes of facial parts from 68-point landmarks.

    For each part (left eye, right eye, nose, mouth) the box is centred on the
    mean of that part's landmarks. Its half side length is half of the larger
    x/y extent of those landmarks, clamped from below by ``min_half_len``.

    Args:
        landmarks (array-like): 68 landmark coordinates, ``landmarks[i]`` being
            an ``(x, y)`` point. Presumably the dlib 68-point layout; the
            index groups used below assume that ordering.
        min_half_len (int | float): Lower bound on the half side length of
            every box. Default: 16.

    Returns:
        tuple[Tensor]: ``(loc_left_eye, loc_right_eye, loc_nose, loc_mouth)``.
        Each is an integer tensor of shape (1, 4) holding ``[x1, y1, x2, y2]``,
        the two corners on the box diagonal.
    """
    landmarks = np.asarray(landmarks)
    # Landmark indices of each part; the eye groups also include points 17-26.
    part_indices = (
        list(range(17, 22)) + list(range(36, 42)),  # left eye
        list(range(22, 27)) + list(range(42, 48)),  # right eye
        list(range(29, 36)),  # nose
        list(range(48, 68)),  # mouth
    )
    return tuple(_part_box(landmarks[indices], min_half_len) for indices in part_indices)


def _part_box(points, min_half_len):
    """Return a (1, 4) integer tensor ``[x1, y1, x2, y2]`` boxing ``points`` (N, 2)."""
    center = np.mean(points, 0)  # (x, y)
    # Half of the largest extent along x or y, but never below min_half_len.
    half_len = np.max((np.max(np.max(points, 0) - np.min(points, 0)) / 2, min_half_len))
    # NOTE(review): the top-left corner carries a +1 offset; kept unchanged,
    # presumably to match the official DFDNet implementation.
    box = np.hstack((center - half_len + 1, center + half_len)).astype(int)
    return torch.from_numpy(box).unsqueeze(0)
56

57

58
def str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``'False'``) into True, so this converter is used for boolean options.

    Args:
        value (str | bool): Raw option value.

    Returns:
        bool: The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}.')


if __name__ == '__main__':
    # We try to align with the official code, but there are still slight
    # differences: 1) we use dlib for 68-landmark detection; 2) the image
    # packages used differ (especially for reading and writing).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    parser = argparse.ArgumentParser()

    parser.add_argument('--upscale_factor', type=int, default=2)
    parser.add_argument(
        '--model_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/DFDNet/DFDNet_official-d1fa5650.pth')
    parser.add_argument(
        '--dict_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth')
    parser.add_argument('--test_path', type=str, default='datasets/TestWhole')
    parser.add_argument('--upsample_num_times', type=int, default=1)
    parser.add_argument('--save_inverse_affine', action='store_true')
    parser.add_argument('--only_keep_largest', action='store_true')
    # The official code uses skimage.io to read the cropped faces back from
    # disk instead of using the in-memory intermediate results (as we would
    # otherwise do). That round trip causes slight differences, so set
    # official_adaption to True (the default) to align with official results.
    # str2bool is used because type=bool would parse 'False' as True.
    parser.add_argument('--official_adaption', type=str2bool, default=True)

    # The following are the paths for dlib models
    parser.add_argument(
        '--detection_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/dlib/mmod_human_face_detector-4cb19393.dat'  # noqa: E501
    )
    parser.add_argument(
        '--landmark5_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat'  # noqa: E501
    )
    parser.add_argument(
        '--landmark68_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat'  # noqa: E501
    )

    args = parser.parse_args()
    # normpath strips any number of trailing separators, so basename() below
    # never returns an empty string (e.g. for 'datasets/TestWhole//').
    args.test_path = os.path.normpath(args.test_path)
    result_root = f'results/DFDNet/{os.path.basename(args.test_path)}'

    # set up the DFDNet
    net = DFDNet(64, dict_path=args.dict_path).to(device)
    checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint['params'])
    net.eval()

    save_crop_root = os.path.join(result_root, 'cropped_faces')
    save_inverse_affine_root = os.path.join(result_root, 'inverse_affine')
    if args.save_inverse_affine:
        # Only create the directory when inverse affine matrices are saved.
        os.makedirs(save_inverse_affine_root, exist_ok=True)
    save_restore_root = os.path.join(result_root, 'restored_faces')
    save_final_root = os.path.join(result_root, 'final_results')

    face_helper = FaceRestoreHelper(args.upscale_factor, face_size=512)

    # Input preprocessing, built once: ToTensor scales to [0, 1] and the
    # 0.5/0.5 normalisation then maps values to [-1, 1].
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    # scan all the jpg and png images; glob.escape keeps characters such as
    # '[' in the directory name from being read as wildcards.
    img_pattern = os.path.join(glob.escape(args.test_path), '*.[jp][pn]g')
    for img_path in sorted(glob.glob(img_pattern)):
        img_name = os.path.basename(img_path)
        print(f'Processing {img_name} image ...')
        save_crop_path = os.path.join(save_crop_root, img_name)
        if args.save_inverse_affine:
            save_inverse_affine_path = os.path.join(save_inverse_affine_root, img_name)
        else:
            save_inverse_affine_path = None

        # NOTE(review): dlib is re-initialised for every image, presumably
        # because free_dlib_gpu_memory() below releases it -- confirm.
        face_helper.init_dlib(args.detection_path, args.landmark5_path, args.landmark68_path)
        # detect faces
        num_det_faces = face_helper.detect_faces(
            img_path, upsample_num_times=args.upsample_num_times, only_keep_largest=args.only_keep_largest)
        # get 5 face landmarks for each face
        num_landmarks = face_helper.get_face_landmarks_5()
        print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.')
        # warp and crop each face
        face_helper.warp_crop_faces(save_crop_path, save_inverse_affine_path)

        if args.official_adaption:
            # Re-read the crops from disk, presumably written by
            # warp_crop_faces as '<stem>_<idx>.png'.
            crop_stem = os.path.splitext(save_crop_path)[0]
            crop_paths = sorted(glob.glob(f'{glob.escape(crop_stem)}_[0-9]*.png'))
            cropped_faces = [io.imread(crop_path) for crop_path in crop_paths]
        else:
            cropped_faces = face_helper.cropped_faces

        # get 68 landmarks for each cropped face
        num_landmarks = face_helper.get_face_landmarks_68()
        print(f'\tDetect {num_landmarks} faces for 68 landmarks.')

        face_helper.free_dlib_gpu_memory()

        print('\tFace restoration ...')
        # face restoration for each cropped face
        assert len(cropped_faces) == len(face_helper.all_landmarks_68)
        for idx, (cropped_face, landmarks) in enumerate(zip(cropped_faces, face_helper.all_landmarks_68)):
            if landmarks is None:
                print(f'Landmarks is None, skip cropped faces with idx {idx}.')
                # just copy the cropped faces to the restored faces
                restored_face = cropped_face
            else:
                # prepare data
                part_locations = get_part_location(landmarks)
                cropped_face = normalize(to_tensor(cropped_face))
                cropped_face = cropped_face.unsqueeze(0).to(device)

                try:
                    with torch.no_grad():
                        output = net(cropped_face, part_locations)
                        restored_face = tensor2img(output, min_max=(-1, 1))
                    del output
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                except Exception as e:
                    # Best effort: keep going and fall back to the input crop.
                    print(f'DFDNet inference fail: {e}')
                    restored_face = tensor2img(cropped_face, min_max=(-1, 1))

            restore_stem = os.path.splitext(os.path.join(save_restore_root, img_name))[0]
            save_path = f'{restore_stem}_{idx:02d}.png'
            imwrite(restored_face, save_path)
            face_helper.add_restored_face(restored_face)

        print('\tGenerate the final result ...')
        # paste each restored face to the input image
        face_helper.paste_faces_to_input_image(os.path.join(save_final_root, img_name))

        # clean all the intermediate results to process the next image
        face_helper.clean_all()

    print(f'\nAll results are saved in {result_root}')
198

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.