import glob

import torch
from os import path as osp
from torch.utils import data as data

from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
    """Video test dataset.

    Supported datasets: Vid4, REDS4, REDSofficial.
    More generally, it supports testing datasets with the following structure:
    one subfolder per clip under the data root, each subfolder containing that
    clip's frame images.

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for test dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test
                folders. If not provided, all the folders in the dataroot will
                be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt  # kept: __getitem__ reads self.opt['num_frame'] / ['padding']
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # file client (io backend); lmdb is rejected because test data is read directly from disk
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        self.imgs_lq, self.imgs_gt = {}, {}
        if 'meta_info_file' in opt:
            # restrict testing to the folders listed in the meta info file
            with open(opt['meta_info_file'], 'r') as fin:
                subfolders = [line.split(' ')[0] for line in fin]
                subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
                subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
        else:
            # otherwise, every folder under the data roots is a test clip
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
            for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
                # get frame list for lq and gt
                subfolder_name = osp.basename(subfolder_lq)
                img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
                img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))

                max_idx = len(img_paths_lq)
                assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
                                                      f' and gt folders ({len(img_paths_gt)})')

                self.data_info['lq_path'].extend(img_paths_lq)
                self.data_info['gt_path'].extend(img_paths_gt)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append(f'{i}/{max_idx}')
                # mark the first and last num_frame // 2 frames as border frames,
                # where the temporal window needs padding
                border_l = [0] * max_idx
                for i in range(self.opt['num_frame'] // 2):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)

                # cache decoded frames, or only keep the frame path lists
                if self.cache_data:
                    logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                    self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                    self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
                else:
                    self.imgs_lq[subfolder_name] = img_paths_lq
                    self.imgs_gt[subfolder_name] = img_paths_gt
        else:
            # report the offending name itself, not its type
            raise ValueError(f'Non-supported video test dataset: {opt["name"]}')

    def __getitem__(self, index):
        """Return the lq frame window and gt frame for the index-th test frame."""
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
            img_gt.squeeze_(0)  # drop the singleton frame dimension

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])
@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
    """Video test dataset for Vimeo90k-Test dataset.

    It only keeps the center frame for testing.
    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for test dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test
                folders. If not provided, all the folders in the dataroot will
                be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestVimeo90KDataset, self).__init__()
        self.opt = opt  # kept: __getitem__ reads self.opt['num_frame']
        self.cache_data = opt['cache_data']
        # caching is only an error when actually requested, not unconditionally
        if self.cache_data:
            raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # frame indices of the num_frame inputs, centered on im4 of the 7-frame clips
        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # file client (io backend); lmdb is rejected because test data is read directly from disk
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        with open(opt['meta_info_file'], 'r') as fin:
            subfolders = [line.split(' ')[0] for line in fin]
        for idx, subfolder in enumerate(subfolders):
            # gt is the center frame (im4); lq is the neighbor window around it
            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
            self.data_info['gt_path'].append(gt_path)
            lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
            self.data_info['lq_path'].append(lq_paths)
            self.data_info['folder'].append('vimeo90k')
            self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
            self.data_info['border'].append(0)

    def __getitem__(self, index):
        """Return the lq frame window and the center gt frame for one clip."""
        lq_path = self.data_info['lq_path'][index]
        gt_path = self.data_info['gt_path'][index]
        imgs_lq = read_img_seq(lq_path)
        img_gt = read_img_seq([gt_path])
        img_gt.squeeze_(0)  # drop the singleton frame dimension

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': self.data_info['folder'][index],  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/843
            'border': self.data_info['border'][index],  # 0 for non-border
            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])
@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
    """Video test dataset for DUF dataset.

    Args:
        opt (dict): Config for test dataset. Most of keys are the same as
            VideoTestDataset. It has the following extra keys:
            use_duf_downsampling (bool): Whether to use duf downsampling to
                generate low-resolution frames.
            scale (bool): Scale, which will be added automatically.
    """

    def __getitem__(self, index):
        """Return the lq frame window and gt frame, optionally synthesizing
        the lq frames by DUF-downsampling the gt frames."""
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            if self.opt['use_duf_downsampling']:
                # read imgs_gt to generate low-resolution frames
                imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            if self.opt['use_duf_downsampling']:
                img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
                # read imgs_gt to generate low-resolution frames; mod-crop so
                # the downsampled size divides evenly by the scale
                imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
                imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
            img_gt.squeeze_(0)  # drop the singleton frame dimension

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }
@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
    """Video test dataset for recurrent architectures, which takes LR video
    frames as input and output corresponding HR video frames.

    Args:
        opt (dict): Same as VideoTestDataset. Unused opt:
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoRecurrentTestDataset, self).__init__(opt)
        # one dataset item per clip: collapse the per-frame folder list
        # to the unique, sorted folder names
        self.folders = sorted(list(set(self.data_info['folder'])))

    def __getitem__(self, index):
        """Return all cached lq and gt frames of the index-th clip."""
        folder = self.folders[index]

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder]
            imgs_gt = self.imgs_gt[folder]
        else:
            raise NotImplementedError('Without cache_data is not implemented.')

        return {
            'lq': imgs_lq,
            'gt': imgs_gt,
            'folder': folder,
        }

    def __len__(self):
        return len(self.folders)