facexlib

Форк
0
/
inference_headpose.py 
54 строки · 2.0 Кб
1
import argparse
2
import cv2
3
import numpy as np
4
import torch
5
from torchvision.transforms.functional import normalize
6

7
from facexlib.detection import init_detection_model
8
from facexlib.headpose import init_headpose_model
9
from facexlib.utils.misc import img2tensor
10
from facexlib.visualization import visualize_headpose
11

12

13
def main(args):
14
    # initialize model
15
    det_net = init_detection_model(args.detection_model_name, half=args.half)
16
    headpose_net = init_headpose_model(args.headpose_model_name, half=args.half)
17

18
    img = cv2.imread(args.img_path)
19
    with torch.no_grad():
20
        bboxes = det_net.detect_faces(img, 0.97)
21
        # x0, y0, x1, y1, confidence_score, five points (x, y)
22
        bbox = list(map(int, bboxes[0]))
23
        # crop face region
24
        thld = 10
25
        h, w, _ = img.shape
26
        top = max(bbox[1] - thld, 0)
27
        bottom = min(bbox[3] + thld, h)
28
        left = max(bbox[0] - thld, 0)
29
        right = min(bbox[2] + thld, w)
30

31
        det_face = img[top:bottom, left:right, :].astype(np.float32) / 255.
32

33
        # resize
34
        det_face = cv2.resize(det_face, (224, 224), interpolation=cv2.INTER_LINEAR)
35
        det_face = img2tensor(np.copy(det_face), bgr2rgb=False)
36

37
        # normalize
38
        normalize(det_face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], inplace=True)
39
        det_face = det_face.unsqueeze(0).cuda()
40

41
        yaw, pitch, roll = headpose_net(det_face)
42
        visualize_headpose(img, yaw, pitch, roll, args.save_path)
43

44

45
if __name__ == '__main__':
46
    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
47
    parser.add_argument('--img_path', type=str, default='assets/test.jpg')
48
    parser.add_argument('--save_path', type=str, default='assets/test_headpose.png')
49
    parser.add_argument('--detection_model_name', type=str, default='retinaface_resnet50')
50
    parser.add_argument('--headpose_model_name', type=str, default='hopenet')
51
    parser.add_argument('--half', action='store_true')
52
    args = parser.parse_args()
53

54
    main(args)
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.