facexlib

Форк
0
/
inference_hyperiqa.py 
63 строки · 2.2 Кб
1
import argparse
2
import cv2
3
import numpy as np
4
import os
5
import torch
6
import torchvision
7
from PIL import Image
8

9
from facexlib.assessment import init_assessment_model
10
from facexlib.detection import init_detection_model
11

12

13
def main(args):
14
    """Scripts about evaluating face quality.
15
        Two steps:
16
        1) detect the face region and crop the face
17
        2) evaluate the face quality by hyperIQA
18
    """
19
    # initialize model
20
    det_net = init_detection_model(args.detection_model_name, half=False)
21
    assess_net = init_assessment_model(args.assess_model_name, half=False)
22

23
    # specified face transformation in original hyperIQA
24
    transforms = torchvision.transforms.Compose([
25
        torchvision.transforms.Resize((512, 384)),
26
        torchvision.transforms.RandomCrop(size=224),
27
        torchvision.transforms.ToTensor(),
28
        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
29
    ])
30

31
    img = cv2.imread(args.img_path)
32
    img_name = os.path.basename(args.img_path)
33
    basename, _ = os.path.splitext(img_name)
34
    with torch.no_grad():
35
        bboxes = det_net.detect_faces(img, 0.97)
36
        box = list(map(int, bboxes[0]))
37
        pred_scores = []
38
        # BRG -> RGB
39
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
40

41
        for i in range(10):
42
            detect_face = img[box[1]:box[3], box[0]:box[2], :]
43
            detect_face = Image.fromarray(detect_face)
44

45
            detect_face = transforms(detect_face)
46
            detect_face = torch.tensor(detect_face.cuda()).unsqueeze(0)
47

48
            pred = assess_net(detect_face)
49
            pred_scores.append(float(pred.item()))
50
        score = np.mean(pred_scores)
51
        # quality score ranges from 0-100, a higher score indicates a better quality
52
        print(f'{basename} {score:.4f}')
53

54

55
if __name__ == '__main__':
56
    parser = argparse.ArgumentParser()
57
    parser.add_argument('--img_path', type=str, default='assets/test2.jpg')
58
    parser.add_argument('--detection_model_name', type=str, default='retinaface_resnet50')
59
    parser.add_argument('--assess_model_name', type=str, default='hypernet')
60
    parser.add_argument('--half', action='store_true')
61
    args = parser.parse_args()
62

63
    main(args)
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.