8
from basicsr.archs.basicvsrpp_arch import BasicVSRPlusPlus
9
from basicsr.data.data_util import read_img_seq
10
from basicsr.utils.img_util import tensor2img
def inference(imgs, imgnames, model, save_path):
    """Restore one clip with ``model`` and write each output frame as a PNG.

    Args:
        imgs: Input clip tensor, already on the model's device. In this script it
            is the ``read_img_seq(...)`` result with a leading batch axis added via
            ``unsqueeze(0)``; presumably (1, T, C, H, W) — TODO confirm.
        imgnames: One name per frame, in clip order; the frame at position i is
            saved as ``<save_path>/<imgnames[i]>_BasicVSRPP.png``.
        model: Network invoked as ``model(imgs)``; assumed to return a tensor whose
            leading axis is the batch of size 1 — TODO confirm.
        save_path: Output directory. It must already exist; it is not created here.

    Raises:
        OSError: If OpenCV fails to write an output frame.
    """
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        outputs = model(imgs)
    # Drop only the batch axis. A bare squeeze() also removes the time axis when the
    # clip holds a single frame (e.g. a 1-frame tail chunk), after which the loop
    # would iterate over colour channels instead of frames.
    for output, imgname in zip(outputs.squeeze(0), imgnames):
        output = tensor2img(output)
        out_file = os.path.join(save_path, f'{imgname}_BasicVSRPP.png')
        # cv2.imwrite reports failure by returning False rather than raising.
        if not cv2.imwrite(out_file, output):
            raise OSError(f'Failed to write {out_file}')
def _parse_args():
    """Parse and validate the command-line options for BasicVSR++ inference."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/BasicVSRPP_REDS4.pth')
    parser.add_argument(
        '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder')
    parser.add_argument('--save_path', type=str, default='results/BasicVSRPP/000', help='save image path')
    parser.add_argument('--interval', type=int, default=100, help='interval size')
    args = parser.parse_args()
    # range() raises on a zero step and silently yields nothing for a negative one.
    if args.interval < 1:
        parser.error('--interval must be a positive integer')
    return args


def _extract_frames(video_path, frame_dir):
    """Decode ``video_path`` with ffmpeg into ``frame_dir/frame%08d.png``.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    import subprocess

    os.makedirs(frame_dir, exist_ok=True)
    # Argument list with no shell: paths containing spaces or shell metacharacters
    # reach ffmpeg verbatim. The output pattern is joined onto frame_dir so the
    # frames land inside it.
    subprocess.run(
        [
            'ffmpeg', '-i', video_path, '-qscale:v', '1', '-qmin', '1', '-qmax', '1', '-vsync', '0',
            os.path.join(frame_dir, 'frame%08d.png'),
        ],
        check=True,
    )


def main():
    """Run BasicVSR++ on an image folder or a video file and save the restored frames.

    ``--input_path`` is used directly when it is a directory. Otherwise it is
    treated as a video: it is decoded to PNG frames under
    ``./BasicVSRPP_tmp/<video name>``, and that temporary folder is removed
    afterwards, even on error. Frames are processed in sorted filename order,
    at most ``--interval`` frames per forward pass.

    Raises:
        FileNotFoundError: If no input frames are found, or ffmpeg is missing.
        subprocess.CalledProcessError: If ffmpeg fails to decode the video.
    """
    args = _parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set up model
    model = BasicVSRPlusPlus(mid_channels=64, num_blocks=7)
    # map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only
    # machine; the model is moved to `device` right after.
    model.load_state_dict(torch.load(args.model_path, map_location='cpu')['params'], strict=True)
    model.eval()
    model = model.to(device)

    os.makedirs(args.save_path, exist_ok=True)

    use_ffmpeg = not os.path.isdir(args.input_path)
    if use_ffmpeg:
        video_name = os.path.splitext(os.path.basename(args.input_path))[0]
        input_path = os.path.join('./BasicVSRPP_tmp', video_name)
    else:
        input_path = args.input_path

    try:
        if use_ffmpeg:
            _extract_frames(args.input_path, input_path)

        imgs_list = sorted(glob.glob(os.path.join(input_path, '*')))
        if not imgs_list:
            raise FileNotFoundError(f'No input frames found in {input_path}')

        # Feed the sequence in chunks of at most `interval` frames: too many frames
        # at once may run out of (GPU) memory. A short sequence is just one chunk.
        for idx in range(0, len(imgs_list), args.interval):
            imgs, imgnames = read_img_seq(imgs_list[idx:idx + args.interval], return_imgname=True)
            imgs = imgs.unsqueeze(0).to(device)
            inference(imgs, imgnames, model, args.save_path)
    finally:
        # Only ever delete the temporary frames this script created, never a
        # user-supplied image folder. ignore_errors keeps a cleanup problem from
        # masking the exception that got us here.
        if use_ffmpeg:
            shutil.rmtree(input_path, ignore_errors=True)
if __name__ == '__main__':