6
import torchvision.transforms as transforms
9
from basicsr.archs.dfdnet_arch import DFDNet
10
from basicsr.utils import imwrite, tensor2img
13
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
15
print('Please install facexlib: pip install facexlib')
20
def _part_box(landmarks, indices, min_half_len=16):
    """Build the square box that encloses one group of landmarks.

    Args:
        landmarks (ndarray): Landmark array; rows are selected with `indices`.
        indices (list[int]): Row indices of the landmarks forming the part.
        min_half_len (int | float): Lower bound for the half side length.
            Default: 16.

    Returns:
        Tensor: Integer tensor of shape (1, 4) for two-column landmarks,
            laid out as (center - half_len + 1, center + half_len).
    """
    points = landmarks[indices]
    center = np.mean(points, 0)
    # Half side length: half of the largest per-column extent of the points,
    # but never smaller than min_half_len so tiny parts still get a usable box.
    extent = np.max(np.max(points, 0) - np.min(points, 0))
    half_len = max(extent / 2, min_half_len)
    # astype(int) truncates the fractional part of the box corners.
    box = np.hstack((center - half_len + 1, center + half_len)).astype(int)
    return torch.from_numpy(box).unsqueeze(0)


def get_part_location(landmarks):
    """Get part locations from landmarks.

    Each part is described by a square box centered on the mean of its
    landmarks. The rows used per part are: left eye 17-21 and 36-41,
    right eye 22-26 and 42-47, nose 29-35, mouth 48-67.

    Args:
        landmarks (array_like): Landmark coordinates with at least 68 rows
            (row 67 is the highest index read). Anything accepted by
            `np.asarray` works, e.g. an ndarray or a nested list.
            NOTE(review): presumably shape (68, 2) from the 68-point
            landmark detector - confirm against the caller.

    Returns:
        tuple[Tensor]: (loc_left_eye, loc_right_eye, loc_nose, loc_mouth).
            For two-column landmarks each tensor has shape (1, 4); the four
            numbers form the two diagonal corners of the box.
    """
    # Accept nested lists as well as arrays; index lists need array indexing.
    landmarks = np.asarray(landmarks)

    map_left_eye = [*range(17, 22), *range(36, 42)]
    map_right_eye = [*range(22, 27), *range(42, 48)]
    map_nose = list(range(29, 36))
    map_mouth = list(range(48, 68))

    loc_left_eye = _part_box(landmarks, map_left_eye)
    loc_right_eye = _part_box(landmarks, map_right_eye)
    loc_nose = _part_box(landmarks, map_nose)
    loc_mouth = _part_box(landmarks, map_mouth)

    return loc_left_eye, loc_right_eye, loc_nose, loc_mouth
58
if __name__ == '__main__':
59
"""We try to align to the official codes. But there are still slight
60
differences: 1) we use dlib for 68-landmark detection; 2) the image
packages used are different (especially for reading and writing).
"""
63
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
parser = argparse.ArgumentParser()
66
parser.add_argument('--upscale_factor', type=int, default=2)
71
'experiments/pretrained_models/DFDNet/DFDNet_official-d1fa5650.pth')
76
'experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth')
77
parser.add_argument('--test_path', type=str, default='datasets/TestWhole')
78
parser.add_argument('--upsample_num_times', type=int, default=1)
79
parser.add_argument('--save_inverse_affine', action='store_true')
80
parser.add_argument('--only_keep_largest', action='store_true')
86
parser.add_argument('--official_adaption', type=bool, default=True)
93
'experiments/pretrained_models/dlib/mmod_human_face_detector-4cb19393.dat'
99
'experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat'
105
'experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat'
108
args = parser.parse_args()
109
if args.test_path.endswith('/'):
110
args.test_path = args.test_path[:-1]
111
result_root = f'results/DFDNet/{os.path.basename(args.test_path)}'
114
net = DFDNet(64, dict_path=args.dict_path).to(device)
115
checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
116
net.load_state_dict(checkpoint['params'])
119
save_crop_root = os.path.join(result_root, 'cropped_faces')
120
save_inverse_affine_root = os.path.join(result_root, 'inverse_affine')
121
os.makedirs(save_inverse_affine_root, exist_ok=True)
122
save_restore_root = os.path.join(result_root, 'restored_faces')
123
save_final_root = os.path.join(result_root, 'final_results')
125
face_helper = FaceRestoreHelper(args.upscale_factor, face_size=512)
128
for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
129
img_name = os.path.basename(img_path)
130
print(f'Processing {img_name} image ...')
131
save_crop_path = os.path.join(save_crop_root, img_name)
132
if args.save_inverse_affine:
133
save_inverse_affine_path = os.path.join(save_inverse_affine_root, img_name)
135
save_inverse_affine_path = None
137
face_helper.init_dlib(args.detection_path, args.landmark5_path, args.landmark68_path)
139
num_det_faces = face_helper.detect_faces(
140
img_path, upsample_num_times=args.upsample_num_times, only_keep_largest=args.only_keep_largest)
142
num_landmarks = face_helper.get_face_landmarks_5()
143
print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.')
145
face_helper.warp_crop_faces(save_crop_path, save_inverse_affine_path)
147
if args.official_adaption:
148
path, ext = os.path.splitext(save_crop_path)
149
paths = sorted(glob.glob(f'{path}_[0-9]*.png'))
150
cropped_faces = [io.imread(path) for path in paths]
152
cropped_faces = face_helper.cropped_faces
155
num_landmarks = face_helper.get_face_landmarks_68()
156
print(f'\tDetect {num_landmarks} faces for 68 landmarks.')
158
face_helper.free_dlib_gpu_memory()
160
print('\tFace restoration ...')
162
assert len(cropped_faces) == len(face_helper.all_landmarks_68)
163
for idx, (cropped_face, landmarks) in enumerate(zip(cropped_faces, face_helper.all_landmarks_68)):
164
if landmarks is None:
165
print(f'Landmarks is None, skip cropped faces with idx {idx}.')
167
restored_face = cropped_face
170
part_locations = get_part_location(landmarks)
171
cropped_face = transforms.ToTensor()(cropped_face)
172
cropped_face = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(cropped_face)
173
cropped_face = cropped_face.unsqueeze(0).to(device)
176
with torch.no_grad():
177
output = net(cropped_face, part_locations)
178
restored_face = tensor2img(output, min_max=(-1, 1))
180
torch.cuda.empty_cache()
181
except Exception as e:
182
print(f'DFDNet inference fail: {e}')
183
restored_face = tensor2img(cropped_face, min_max=(-1, 1))
185
path = os.path.splitext(os.path.join(save_restore_root, img_name))[0]
186
save_path = f'{path}_{idx:02d}.png'
187
imwrite(restored_face, save_path)
188
face_helper.add_restored_face(restored_face)
190
print('\tGenerate the final result ...')
192
face_helper.paste_faces_to_input_image(os.path.join(save_final_root, img_name))
195
face_helper.clean_all()
197
print(f'\nAll results are saved in {result_root}')