facexlib

Форк
0
/
inference_parsing.py 
74 строки · 3.0 Кб
1
import argparse
2
import cv2
3
import numpy as np
4
import os
5
import torch
6
from torchvision.transforms.functional import normalize
7

8
from facexlib.parsing import init_parsing_model
9
from facexlib.utils.misc import img2tensor
10

11

12
def vis_parsing_maps(img, parsing_anno, stride, save_anno_path=None, save_vis_path=None):
13
    # Colors for all 20 parts
14
    part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 0, 85], [255, 0, 170], [0, 255, 0], [85, 255, 0],
15
                   [170, 255, 0], [0, 255, 85], [0, 255, 170], [0, 0, 255], [85, 0, 255], [170, 0, 255], [0, 85, 255],
16
                   [0, 170, 255], [255, 255, 0], [255, 255, 85], [255, 255, 170], [255, 0, 255], [255, 85, 255],
17
                   [255, 170, 255], [0, 255, 255], [85, 255, 255], [170, 255, 255]]
18
    # 0: 'background'
19
    # attributions = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye',
20
    #                 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose',
21
    #                 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l',
22
    #                 16 'cloth', 17 'hair', 18 'hat']
23
    vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
24
    vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
25
    if save_anno_path is not None:
26
        cv2.imwrite(save_anno_path, vis_parsing_anno)
27

28
    if save_vis_path is not None:
29
        vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
30
        num_of_class = np.max(vis_parsing_anno)
31
        for pi in range(1, num_of_class + 1):
32
            index = np.where(vis_parsing_anno == pi)
33
            vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
34

35
        vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
36
        vis_im = cv2.addWeighted(img, 0.4, vis_parsing_anno_color, 0.6, 0)
37

38
        cv2.imwrite(save_vis_path, vis_im)
39

40

41
def main(img_path, output):
42
    net = init_parsing_model(model_name='bisenet')
43

44
    img_name = os.path.basename(img_path)
45
    img_basename = os.path.splitext(img_name)[0]
46

47
    img_input = cv2.imread(img_path)
48
    img_input = cv2.resize(img_input, (512, 512), interpolation=cv2.INTER_LINEAR)
49
    img = img2tensor(img_input.astype('float32') / 255., bgr2rgb=True, float32=True)
50
    normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True)
51
    img = torch.unsqueeze(img, 0).cuda()
52

53
    with torch.no_grad():
54
        out = net(img)[0]
55
    out = out.squeeze(0).cpu().numpy().argmax(0)
56

57
    vis_parsing_maps(
58
        img_input,
59
        out,
60
        stride=1,
61
        save_anno_path=os.path.join(output, f'{img_basename}.png'),
62
        save_vis_path=os.path.join(output, f'{img_basename}_vis.png'))
63

64

65
if __name__ == '__main__':
66
    parser = argparse.ArgumentParser()
67

68
    parser.add_argument('--input', type=str, default='datasets/ffhq/ffhq_512/00000000.png')
69
    parser.add_argument('--output', type=str, default='results', help='output folder')
70
    args = parser.parse_args()
71

72
    os.makedirs(args.output, exist_ok=True)
73

74
    main(args.input, args.output)
75

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.