import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize

from facexlib.matting import init_matting_model
from facexlib.utils import img2tensor
def _to_three_channels(img):
    """Return *img* as an H x W x 3 array.

    Grayscale input (H x W or H x W x 1) is replicated across three channels
    and a fourth (alpha) channel is dropped. With its default flag
    (IMREAD_COLOR) ``cv2.imread`` already returns 3-channel BGR, so this is a
    guard that keeps the pipeline valid if the loader or its flags change.
    """
    if img.ndim == 2:
        # np.repeat along axis=2 needs a channel axis to exist first.
        img = img[:, :, None]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, :3]
    return img


def _input_size(im_h, im_w, ref_size):
    """Return the (height, width) the network input is resized to.

    If both sides are smaller than ``ref_size``, or both are larger, the image
    is rescaled so its shorter side equals ``ref_size`` (aspect ratio kept).
    Otherwise the original size is kept. Both sides are then floored to a
    multiple of 32, with a minimum of 32 so a thin image never yields a
    zero-sized tensor.
    """
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh, im_rw = im_h, im_w
    # NOTE(review): multiple-of-32 presumably matches MODNet's total encoder
    # stride (as in the upstream MODNet demo) -- confirm.
    im_rh = max(32, im_rh - im_rh % 32)
    im_rw = max(32, im_rw - im_rw % 32)
    return im_rh, im_rw


def _composite_foreground(img, matte):
    """Blend *img* (H x W x 3, values in [0, 1]) over white using *matte* (H x W)."""
    alpha = matte[:, :, None]
    return img * alpha + (1.0 - alpha)


def _to_uint8(values):
    """Scale [0, 1] floats to uint8, clipping so out-of-range values cannot wrap."""
    return np.clip(np.rint(values * 255.0), 0, 255).astype(np.uint8)


def _foreground_path(save_path):
    """Insert ``_fg`` before the extension of *save_path* (default ``.png``)."""
    root, ext = os.path.splitext(save_path)
    return f'{root}_fg{ext or ".png"}'


def _write_image(path, image):
    """Write *image* with OpenCV, raising instead of failing silently."""
    if not cv2.imwrite(path, image):
        raise OSError(f'Failed to write image: {path}')


def main(args):
    """Run portrait matting on one image and save the matte and foreground.

    Two files are written: the 8-bit alpha matte to ``args.save_path``, and
    the input image composited over a white background to the same path with
    ``_fg`` inserted before the extension.

    Args:
        args: namespace with ``img_path`` and ``save_path`` attributes and an
            optional ``ref_size`` (defaults to 512 when absent).

    Raises:
        FileNotFoundError: if ``args.img_path`` cannot be read as an image.
        OSError: if OpenCV reports failure writing an output file.
    """
    ref_size = getattr(args, 'ref_size', 512)
    modnet = init_matting_model()

    raw = cv2.imread(args.img_path)
    # cv2.imread signals failure by returning None rather than raising.
    if raw is None:
        raise FileNotFoundError(f'Cannot read image: {args.img_path}')
    img = _to_three_channels(raw / 255.)

    img_t = img2tensor(img, bgr2rgb=True, float32=True)
    # Maps [0, 1] to [-1, 1] per channel.
    normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
    img_t = img_t.unsqueeze(0).cuda()

    # Resize the input to a network-friendly size; remember the original.
    _, _, im_h, im_w = img_t.shape
    im_rh, im_rw = _input_size(im_h, im_w, ref_size)
    img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area')

    # Inference only: skip autograd bookkeeping to save memory.
    # NOTE(review): the second positional arg presumably selects inference
    # mode in MODNet.forward -- confirm.
    with torch.no_grad():
        _, _, matte = modnet(img_t, True)

    # Resize the matte back to the original resolution and save it.
    matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
    matte = matte[0, 0].cpu().numpy()
    _write_image(args.save_path, _to_uint8(matte))

    # Foreground over white, kept in BGR to match what cv2 writes.
    foreground = _composite_foreground(img, matte)
    _write_image(_foreground_path(args.save_path), _to_uint8(foreground))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_path', type=str, default='assets/test.jpg')
    parser.add_argument('--save_path', type=str, default='test_matting.png')
    parser.add_argument('--ref_size', type=int, default=512)
    args = parser.parse_args()
    main(args)