5
from torchvision.transforms.functional import normalize
7
from facexlib.detection import init_detection_model
8
from facexlib.headpose import init_headpose_model
9
from facexlib.utils.misc import img2tensor
10
from facexlib.visualization import visualize_headpose
# Head-pose inference steps: detect a face, crop it with a margin, preprocess
# the crop and run the head-pose network on it.
# NOTE(review): `args`, `thld`, `h` and `w` are read below but are not assigned
# anywhere in the visible lines -- confirm the surrounding code defines them.

# Build the face detector and the head-pose estimator selected by `args`.
det_net = init_detection_model(args.detection_model_name, half=args.half)
headpose_net = init_headpose_model(args.headpose_model_name, half=args.half)

img = cv2.imread(args.img_path)
# cv2.imread reports failure by returning None instead of raising, so stop
# here with a clear message rather than failing later inside the detector.
if img is None:
    raise FileNotFoundError('Cannot read image: {}'.format(args.img_path))

# NOTE(review): 0.97 is presumably the detection confidence threshold -- confirm.
bboxes = det_net.detect_faces(img, 0.97)
# Without this guard, bboxes[0] fails with an opaque IndexError when nothing
# was detected. len() is used because the return type is not visible here and
# the truth value of an array would be ambiguous.
if len(bboxes) == 0:
    raise RuntimeError('No face detected in image: {}'.format(args.img_path))
# Each detection: x0, y0, x1, y1, confidence_score, five points (x, y).
# Only the first detection is used; its values are truncated to ints.
bbox = list(map(int, bboxes[0]))

# Grow the box by `thld` on every side, clamped to [0, h] vertically and
# [0, w] horizontally (presumably the image height and width -- confirm).
top = max(bbox[1] - thld, 0)
bottom = min(bbox[3] + thld, h)
left = max(bbox[0] - thld, 0)
right = min(bbox[2] + thld, w)

# Crop the face region and rescale the pixel values by 1/255 as float32.
det_face = img[top:bottom, left:right, :].astype(np.float32) / 255.

# Resize the crop to the fixed 224x224 size passed to the head-pose network.
det_face = cv2.resize(det_face, (224, 224), interpolation=cv2.INTER_LINEAR)
det_face = img2tensor(np.copy(det_face), bgr2rgb=False)

# Per-channel normalization, applied in place.
# NOTE(review): the channel order is left as loaded (bgr2rgb=False) while the
# mean/std look like the usual ImageNet RGB statistics -- confirm this matches
# how the head-pose model was trained.
normalize(det_face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], inplace=True)
# Add a leading batch dimension and move the tensor to the CUDA device.
det_face = det_face.unsqueeze(0).cuda()

yaw, pitch, roll = headpose_net(det_face)
visualize_headpose(img, yaw, pitch, roll, args.save_path)
def build_parser():
    """Return the command-line parser for the head-pose demo.

    Options (all optional):
        --img_path: input image path (default ``assets/test.jpg``).
        --save_path: output image path (default ``assets/test_headpose.png``).
        --detection_model_name: face detector name
            (default ``retinaface_resnet50``).
        --headpose_model_name: head-pose model name (default ``hopenet``).
        --half: boolean flag, ``False`` unless given.
    """
    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    parser.add_argument('--img_path', type=str, default='assets/test.jpg')
    parser.add_argument('--save_path', type=str, default='assets/test_headpose.png')
    parser.add_argument('--detection_model_name', type=str, default='retinaface_resnet50')
    parser.add_argument('--headpose_model_name', type=str, default='hopenet')
    parser.add_argument('--half', action='store_true')
    return parser


if __name__ == '__main__':
    # Parse sys.argv only when run as a script, so importing this module has
    # no command-line side effects.
    # NOTE(review): nothing in the visible lines consumes `args` after this
    # point -- confirm the call that uses it follows.
    args = build_parser().parse_args()