BasicSR
70 строк · 2.2 Кб
1import math
2import os
3import torchvision.utils
4
5from basicsr.data import build_dataloader, build_dataset
6
7
def _build_opt(mode):
    """Return the option dict used to build the paired dataset and its loader.

    Args:
        mode (str): Data source layout; one of 'lmdb', 'folder',
            'meta_info_file'.

    Returns:
        dict: Options passed to ``build_dataset`` and ``build_dataloader``.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    opt = {}
    opt['dist'] = False
    opt['phase'] = 'train'

    opt['name'] = 'DIV2K'
    opt['type'] = 'PairedImageDataset'
    if mode == 'folder':
        opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
        opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
        opt['filename_tmpl'] = '{}'
        opt['io_backend'] = dict(type='disk')
    elif mode == 'meta_info_file':
        opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
        opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
        opt['meta_info_file'] = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'  # noqa:E501
        opt['filename_tmpl'] = '{}'
        opt['io_backend'] = dict(type='disk')
    elif mode == 'lmdb':
        opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub.lmdb'
        opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb'  # noqa:E501
        opt['io_backend'] = dict(type='lmdb')
    else:
        # Fail fast: an unknown mode would otherwise leave the dataroot_* and
        # io_backend keys unset and only fail somewhere downstream.
        raise ValueError(f"Unsupported mode {mode!r}; expected one of 'lmdb', 'folder', 'meta_info_file'.")

    opt['gt_size'] = 128
    opt['use_hflip'] = True
    opt['use_rot'] = True

    opt['num_worker_per_gpu'] = 2
    opt['batch_size_per_gpu'] = 16
    opt['scale'] = 4

    opt['dataset_enlarge_ratio'] = 1
    return opt


def main(mode='folder'):
    """Test paired image dataset by dumping a few batches to ``tmp/``.

    Builds the dataset and dataloader for ``mode``, then for at most the first
    6 batches prints the batch index and the LQ/GT paths, and saves the LQ and
    GT tensors as image grids ``tmp/lq_XXX.png`` and ``tmp/gt_XXX.png``.

    Args:
        mode (str): There are three modes: 'lmdb', 'folder', 'meta_info_file'.

    Raises:
        ValueError: If ``mode`` is not a supported mode. This is raised before
            ``tmp/`` is created or any dataset is built.
    """
    opt = _build_opt(mode)

    os.makedirs('tmp', exist_ok=True)

    dataset = build_dataset(opt)
    data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=opt['dist'], sampler=None)

    # nrow images per grid row: int(sqrt(16)) == 4 for the batch size above.
    nrow = int(math.sqrt(opt['batch_size_per_gpu']))
    # phase is fixed to 'train' in _build_opt, so this currently yields 2.
    padding = 2 if opt['phase'] == 'train' else 0

    print('start...')
    for i, data in enumerate(data_loader):
        # Only inspect batches 0..5.
        if i > 5:
            break
        print(i)

        lq = data['lq']
        gt = data['gt']
        lq_path = data['lq_path']
        gt_path = data['gt_path']
        print(lq_path, gt_path)
        torchvision.utils.save_image(lq, f'tmp/lq_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
        torchvision.utils.save_image(gt, f'tmp/gt_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
67
68
if __name__ == '__main__':
    # Run the smoke test with the default on-disk folder layout; pass
    # mode='lmdb' or mode='meta_info_file' to exercise the other sources.
    main(mode='folder')
71