facexlib

Форк
0
/
inference_matting.py 
65 строк · 1.9 Кб
1
import argparse
2
import cv2
3
import numpy as np
4
import torch.nn.functional as F
5
from torchvision.transforms.functional import normalize
6

7
from facexlib.matting import init_matting_model
8
from facexlib.utils import img2tensor
9

10

11
def main(args):
12
    modnet = init_matting_model()
13

14
    # read image
15
    img = cv2.imread(args.img_path) / 255.
16
    # unify image channels to 3
17
    if len(img.shape) == 2:
18
        img = img[:, :, None]
19
    if img.shape[2] == 1:
20
        img = np.repeat(img, 3, axis=2)
21
    elif img.shape[2] == 4:
22
        img = img[:, :, 0:3]
23

24
    img_t = img2tensor(img, bgr2rgb=True, float32=True)
25
    normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
26
    img_t = img_t.unsqueeze(0).cuda()
27

28
    # resize image for input
29
    _, _, im_h, im_w = img_t.shape
30
    ref_size = 512
31
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
32
        if im_w >= im_h:
33
            im_rh = ref_size
34
            im_rw = int(im_w / im_h * ref_size)
35
        elif im_w < im_h:
36
            im_rw = ref_size
37
            im_rh = int(im_h / im_w * ref_size)
38
    else:
39
        im_rh = im_h
40
        im_rw = im_w
41
    im_rw = im_rw - im_rw % 32
42
    im_rh = im_rh - im_rh % 32
43
    img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area')
44

45
    # inference
46
    _, _, matte = modnet(img_t, True)
47

48
    # resize and save matte
49
    matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
50
    matte = matte[0][0].data.cpu().numpy()
51
    cv2.imwrite(args.save_path, (matte * 255).astype('uint8'))
52

53
    # get foreground
54
    matte = matte[:, :, None]
55
    foreground = img * matte + np.full(img.shape, 1) * (1 - matte)
56
    cv2.imwrite(args.save_path.replace('.png', '_fg.png'), foreground * 255)
57

58

59
if __name__ == '__main__':
60
    parser = argparse.ArgumentParser()
61
    parser.add_argument('--img_path', type=str, default='assets/test.jpg')
62
    parser.add_argument('--save_path', type=str, default='test_matting.png')
63
    args = parser.parse_args()
64

65
    main(args)
66

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.