BasicSR

Форк
0
/
inference_basicvsrpp.py 
73 строки · 2.7 Кб
1
import argparse
2
import cv2
3
import glob
4
import os
5
import shutil
6
import torch
7

8
from basicsr.archs.basicvsrpp_arch import BasicVSRPlusPlus
9
from basicsr.data.data_util import read_img_seq
10
from basicsr.utils.img_util import tensor2img
11

12

13
def inference(imgs, imgnames, model, save_path):
14
    with torch.no_grad():
15
        outputs = model(imgs)
16
    # save imgs
17
    outputs = outputs.squeeze()
18
    outputs = list(outputs)
19
    for output, imgname in zip(outputs, imgnames):
20
        output = tensor2img(output)
21
        cv2.imwrite(os.path.join(save_path, f'{imgname}_BasicVSRPP.png'), output)
22

23

24
def main():
25
    parser = argparse.ArgumentParser()
26
    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/BasicVSRPP_REDS4.pth')
27
    parser.add_argument(
28
        '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder')
29
    parser.add_argument('--save_path', type=str, default='results/BasicVSRPP/000', help='save image path')
30
    parser.add_argument('--interval', type=int, default=100, help='interval size')
31
    args = parser.parse_args()
32

33
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
34

35
    # set up model
36
    model = BasicVSRPlusPlus(mid_channels=64, num_blocks=7)
37
    model.load_state_dict(torch.load(args.model_path)['params'], strict=True)
38
    model.eval()
39
    model = model.to(device)
40

41
    os.makedirs(args.save_path, exist_ok=True)
42

43
    # extract images from video format files
44
    input_path = args.input_path
45
    use_ffmpeg = False
46
    if not os.path.isdir(input_path):
47
        use_ffmpeg = True
48
        video_name = os.path.splitext(os.path.split(args.input_path)[-1])[0]
49
        input_path = os.path.join('./BasicVSRPP_tmp', video_name)
50
        os.makedirs(os.path.join('./BasicVSRPP_tmp', video_name), exist_ok=True)
51
        os.system(f'ffmpeg -i {args.input_path} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0  {input_path} /frame%08d.png')
52

53
    # load data and inference
54
    imgs_list = sorted(glob.glob(os.path.join(input_path, '*')))
55
    num_imgs = len(imgs_list)
56
    if len(imgs_list) <= args.interval:  # too many images may cause CUDA out of memory
57
        imgs, imgnames = read_img_seq(imgs_list, return_imgname=True)
58
        imgs = imgs.unsqueeze(0).to(device)
59
        inference(imgs, imgnames, model, args.save_path)
60
    else:
61
        for idx in range(0, num_imgs, args.interval):
62
            interval = min(args.interval, num_imgs - idx)
63
            imgs, imgnames = read_img_seq(imgs_list[idx:idx + interval], return_imgname=True)
64
            imgs = imgs.unsqueeze(0).to(device)
65
            inference(imgs, imgnames, model, args.save_path)
66

67
    # delete ffmpeg output images
68
    if use_ffmpeg:
69
        shutil.rmtree(input_path)
70

71

72
if __name__ == '__main__':
73
    main()
74

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.