9
from facexlib.assessment import init_assessment_model
10
from facexlib.detection import init_detection_model
def main(args, *, num_crops=10, det_threshold=0.97):
    """Evaluate the quality of a face in an image with hyperIQA.

    Two steps:
        1) detect the face region and crop the face
        2) evaluate the face quality by hyperIQA

    Args:
        args: Parsed CLI namespace; reads ``img_path``,
            ``detection_model_name`` and ``assess_model_name``.
        num_crops: Number of randomly cropped passes the face is scored on.
            ``RandomCrop`` makes each pass different, and the reported score
            is their mean. Must be >= 1.
        det_threshold: Confidence threshold passed to ``detect_faces``.

    Returns:
        The mean quality score as a float, or ``None`` when no usable face
        crop is found (a message is printed in that case).

    Raises:
        ValueError: If ``num_crops`` < 1.
        OSError: If OpenCV cannot read ``args.img_path``.
    """
    if num_crops < 1:
        raise ValueError(f'num_crops must be >= 1, got {num_crops}')

    # NOTE(review): ``args.half`` is parsed by the CLI but not honored here;
    # both models are always built with half=False.
    det_net = init_detection_model(args.detection_model_name, half=False)
    assess_net = init_assessment_model(args.assess_model_name, half=False)

    # hyperIQA input pipeline: resize, take a random 224x224 patch, then
    # normalize with the ImageNet channel mean/std.
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((512, 384)),
        torchvision.transforms.RandomCrop(size=224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    img = cv2.imread(args.img_path)
    if img is None:
        # cv2.imread returns None instead of raising on a missing or
        # unreadable file; fail here with a clear message rather than
        # deep inside the detector.
        raise OSError(f'Failed to read image: {args.img_path}')
    img_name = os.path.basename(args.img_path)
    basename, _ = os.path.splitext(img_name)

    with torch.no_grad():
        bboxes = det_net.detect_faces(img, det_threshold)
        # len() works whether the detector returns a list or an ndarray.
        if len(bboxes) == 0:
            print(f'{basename} no face detected')
            return None

        # Only the first detection is scored; its first four values are used
        # as x1, y1, x2, y2 pixel coordinates.
        x1, y1, x2, y2 = map(int, bboxes[0][:4])
        # Boxes may extend past the image border. Clamp them so a negative
        # start index does not wrap around and select the wrong region.
        height, width = img.shape[:2]
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, width), min(y2, height)
        if x2 <= x1 or y2 <= y1:
            print(f'{basename} empty face crop')
            return None

        # OpenCV loads images as BGR; PIL/torchvision expect RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # The face crop is the same on every pass, so build it once. Only
        # the random transform differs between passes.
        face = Image.fromarray(img[y1:y2, x1:x2, :])

        pred_scores = []
        for _ in range(num_crops):
            # Assumes the assessment model lives on the default CUDA
            # device. TODO: confirm, or make the device configurable.
            face_tensor = transforms(face).unsqueeze(0).cuda()
            pred = assess_net(face_tensor)
            pred_scores.append(float(pred.item()))

        score = np.mean(pred_scores)
        print(f'{basename} {score:.4f}')
    return float(score)


if __name__ == '__main__':
    # Command-line entry point: parse options and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_path', type=str, default='assets/test2.jpg')
    parser.add_argument('--detection_model_name', type=str, default='retinaface_resnet50')
    parser.add_argument('--assess_model_name', type=str, default='hypernet')
    # NOTE(review): accepted for CLI compatibility, but nothing reads
    # ``args.half``; the models are built in full precision.
    parser.add_argument('--half', action='store_true')
    args = parser.parse_args()

    main(args)