import random
from pathlib import Path

import numpy as np
import torch
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.flow_util import dequantize_flow
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class REDSDataset(data.Dataset):
    """REDS dataset for training.

    The keys are generated from a meta info txt file, e.g.
    basicsr/data/meta_info/meta_info_REDS_GT.txt.

    Each line of the meta info file contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    white space, e.g. ``000 100 (720,1280,3)``.

    One key is created per frame of every clip that is not in the validation
    partition; the key names the center frame of the sampled temporal window.
    Key examples: "000/00000000"

    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow. If absent
                or None, no flow is loaded.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames. Must be odd.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale, which will be added automatically.

    Raises:
        ValueError: If ``val_partition`` is unknown, ``interval_list`` is
            empty, or a window cannot fit inside a clip for some interval.
    """

    # Frames are indexed 0..99 within every clip; the border handling in
    # __getitem__ relies on this fixed clip length.
    _CLIP_LENGTH = 100

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        # dataroot_flow is documented as optional, so tolerate a missing key.
        flow_root = opt.get('dataroot_flow')
        self.flow_root = Path(flow_root) if flow_root is not None else None
        assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        # Expand every clip of the meta info file into one key per frame.
        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines
                folder, frame_num, *_ = line.split()
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # Remove the validation clips from the training keys.
        if opt['val_partition'] == 'REDS4':
            val_partition = {'000', '011', '015', '020'}
        elif opt['val_partition'] == 'official':
            val_partition = {f'{v:03d}' for v in range(240, 270)}
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}. '
                             "Supported ones are ['official', 'REDS4'].")
        self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # File client (io backend); created lazily in __getitem__.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # Temporal augmentation configs.
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        if not self.interval_list:
            raise ValueError('interval_list must not be empty.')
        # A window spans (num_frame - 1) * interval frames. If that exceeds the
        # clip, no valid center exists and the resampling loop in __getitem__
        # would never terminate, so reject such configs up front.
        max_interval = max(self.interval_list)
        if (self.num_frame - 1) * max_interval > self._CLIP_LENGTH - 1:
            raise ValueError(f'A window of {self.num_frame} frames with interval {max_interval} '
                             f'does not fit into a clip of {self._CLIP_LENGTH} frames.')
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def _read_flow(self, clip_name, frame_name, suffix):
        """Read and dequantize one stored flow map for a center frame.

        Args:
            clip_name (str): Clip (subfolder) name.
            frame_name (str): Zero-padded name of the center frame.
            suffix (str): Flow suffix, ``p{i}`` (previous) or ``n{i}`` (next).

        Returns:
            The flow produced by ``dequantize_flow``.
        """
        if self.is_lmdb:
            flow_path = f'{clip_name}/{frame_name}_{suffix}'
        else:
            flow_path = self.flow_root / clip_name / f'{frame_name}_{suffix}.png'
        img_bytes = self.file_client.get(flow_path, 'flow')
        cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)
        # The two quantized flow components are stacked along axis 0 of one image.
        dx, dy = np.split(cat_flow, 2, axis=0)
        return dequantize_flow(dx, dy, max_val=20, denorm=False)

    def __getitem__(self, index):
        if self.file_client is None:
            # Copy before popping so opt['io_backend'] keeps its 'type' entry
            # (popping from the shared dict would break a second client creation).
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'), **backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')
        center_frame_idx = int(frame_name)

        # Determine the neighboring frames.
        interval = random.choice(self.interval_list)

        # Ensure the window does not exceed the clip borders: resample the
        # center until it fits. Termination is guaranteed by the check in __init__.
        last_frame_idx = self._CLIP_LENGTH - 1
        half_span = self.num_half_frames * interval
        start_frame_idx = center_frame_idx - half_span
        end_frame_idx = center_frame_idx + half_span
        while (start_frame_idx < 0) or (end_frame_idx > last_frame_idx):
            center_frame_idx = random.randint(0, last_frame_idx)
            start_frame_idx = center_frame_idx - half_span
            end_frame_idx = center_frame_idx + half_span
        frame_name = f'{center_frame_idx:08d}'
        neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))

        # Random reverse.
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')

        # Get the GT frame (the center frame of the window).
        if self.is_lmdb:
            img_gt_path = f'{clip_name}/{frame_name}'
        else:
            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # Get the neighboring LQ frames.
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # Get flows: previous flows p{num_half}..p1, then next flows n1..n{num_half}.
        if self.flow_root is not None:
            img_flows = [self._read_flow(clip_name, frame_name, f'p{i}') for i in range(self.num_half_frames, 0, -1)]
            img_flows.extend(self._read_flow(clip_name, frame_name, f'n{i}') for i in range(1, self.num_half_frames + 1))
            # Flows are cropped together with the LQ frames, so they are
            # expected to share the LQ spatial size.
            img_lqs.extend(img_flows)

        # Randomly crop.
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
        if self.flow_root is not None:
            img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]

        # Augmentation - flip, rotate.
        img_lqs.append(img_gt)
        if self.flow_root is not None:
            img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
        else:
            img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        # All entries but the last are the LQ frames; stack them along a new leading (time) dim.
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        if self.flow_root is not None:
            img_flows = img2tensor(img_flows)
            # Insert an all-zero flow for the center frame (it has no stored flow).
            img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
            img_flows = torch.stack(img_flows, dim=0)

        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
@DATASET_REGISTRY.register()
class REDSRecurrentDataset(data.Dataset):
    """REDS dataset for training recurrent networks.

    The keys are generated from a meta info txt file, e.g.
    basicsr/data/meta_info/meta_info_REDS_GT.txt.

    Each line of the meta info file contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    white space, e.g. ``000 100 (720,1280,3)``.

    One key is created per frame of every selected clip; the key names the
    first frame of the sampled sequence (it is resampled when the sequence
    would run past the end of the clip).
    Key examples: "000/00000000"

    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            test_mode (bool, optional): If True, keep only the validation
                partition clips; otherwise exclude them. Default: False.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Number of frames in each sampled sequence.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list, optional): Interval list for temporal
                augmentation. Default: [1].
            random_reverse (bool, optional): Random reverse input frames.
                Default: False.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale, which will be added automatically.
            (``dataroot_flow`` is not read by this class.)

    Raises:
        ValueError: If ``val_partition`` is unknown, ``interval_list`` is
            empty, or a sequence cannot fit inside a clip for some interval.
    """

    # Frames are indexed 0..99 within every clip; the border handling in
    # __getitem__ relies on this fixed clip length.
    _CLIP_LENGTH = 100

    def __init__(self, opt):
        super(REDSRecurrentDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.num_frame = opt['num_frame']

        # Expand every clip of the meta info file into one key per frame.
        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines
                folder, frame_num, *_ = line.split()
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        if opt['val_partition'] == 'REDS4':
            val_partition = {'000', '011', '015', '020'}
        elif opt['val_partition'] == 'official':
            val_partition = {f'{v:03d}' for v in range(240, 270)}
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}. '
                             "Supported ones are ['official', 'REDS4'].")
        # The two filters are mutually exclusive; applying both would leave no keys.
        if opt.get('test_mode', False):
            self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
        else:
            self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # File client (io backend); created lazily in __getitem__.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            # This class never sets flow_root itself; the flow branch only
            # applies if a subclass defines it before calling this __init__.
            if getattr(self, 'flow_root', None) is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # Temporal augmentation configs.
        self.interval_list = opt.get('interval_list', [1])
        self.random_reverse = opt.get('random_reverse', False)
        if not self.interval_list:
            raise ValueError('interval_list must not be empty.')
        # A sequence spans (num_frame - 1) * interval frames; reject configs
        # where it cannot fit in a clip instead of failing inside random.randint.
        max_interval = max(self.interval_list)
        if (self.num_frame - 1) * max_interval > self._CLIP_LENGTH - 1:
            raise ValueError(f'A sequence of {self.num_frame} frames with interval {max_interval} '
                             f'does not fit into a clip of {self._CLIP_LENGTH} frames.')
        interval_str = ','.join(str(x) for x in self.interval_list)
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            # Copy before popping so opt['io_backend'] keeps its 'type' entry
            # (popping from the shared dict would break a second client creation).
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'), **backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')

        # Determine the neighboring frames.
        interval = random.choice(self.interval_list)

        # Ensure the sequence does not exceed the clip border: the last frame
        # start + (num_frame - 1) * interval must be <= last frame index.
        start_frame_idx = int(frame_name)
        max_start_idx = (self._CLIP_LENGTH - 1) - (self.num_frame - 1) * interval
        if start_frame_idx > max_start_idx:
            start_frame_idx = random.randint(0, max_start_idx)
        end_frame_idx = start_frame_idx + self.num_frame * interval

        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))

        # Random reverse.
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        # Get the LQ and GT frames.
        img_lqs = []
        img_gts = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
                img_gt_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
                img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'

            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lqs.append(imfrombytes(img_bytes, float32=True))

            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gts.append(imfrombytes(img_bytes, float32=True))

        # Randomly crop (the last GT path is only passed along for error reporting).
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # Augmentation - flip, rotate; LQ and GT frames are augmented together.
        num_lqs = len(img_lqs)
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        # First half are LQ frames, second half GT frames; stack along a new leading (time) dim.
        img_gts = torch.stack(img_results[num_lqs:], dim=0)
        img_lqs = torch.stack(img_results[:num_lqs], dim=0)

        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)