BasicSR

Форк
0
/
test_paired_image_dataset.py 
70 строк · 2.2 Кб
1
import math
2
import os
3
import torchvision.utils
4

5
from basicsr.data import build_dataloader, build_dataset
6

7

8
def main(mode='folder'):
9
    """Test paired image dataset.
10

11
    Args:
12
        mode: There are three modes: 'lmdb', 'folder', 'meta_info_file'.
13
    """
14
    opt = {}
15
    opt['dist'] = False
16
    opt['phase'] = 'train'
17

18
    opt['name'] = 'DIV2K'
19
    opt['type'] = 'PairedImageDataset'
20
    if mode == 'folder':
21
        opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
22
        opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
23
        opt['filename_tmpl'] = '{}'
24
        opt['io_backend'] = dict(type='disk')
25
    elif mode == 'meta_info_file':
26
        opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
27
        opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
28
        opt['meta_info_file'] = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'  # noqa:E501
29
        opt['filename_tmpl'] = '{}'
30
        opt['io_backend'] = dict(type='disk')
31
    elif mode == 'lmdb':
32
        opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub.lmdb'
33
        opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb'  # noqa:E501
34
        opt['io_backend'] = dict(type='lmdb')
35

36
    opt['gt_size'] = 128
37
    opt['use_hflip'] = True
38
    opt['use_rot'] = True
39

40
    opt['num_worker_per_gpu'] = 2
41
    opt['batch_size_per_gpu'] = 16
42
    opt['scale'] = 4
43

44
    opt['dataset_enlarge_ratio'] = 1
45

46
    os.makedirs('tmp', exist_ok=True)
47

48
    dataset = build_dataset(opt)
49
    data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=opt['dist'], sampler=None)
50

51
    nrow = int(math.sqrt(opt['batch_size_per_gpu']))
52
    padding = 2 if opt['phase'] == 'train' else 0
53

54
    print('start...')
55
    for i, data in enumerate(data_loader):
56
        if i > 5:
57
            break
58
        print(i)
59

60
        lq = data['lq']
61
        gt = data['gt']
62
        lq_path = data['lq_path']
63
        gt_path = data['gt_path']
64
        print(lq_path, gt_path)
65
        torchvision.utils.save_image(lq, f'tmp/lq_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
66
        torchvision.utils.save_image(gt, f'tmp/gt_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
67

68

69
if __name__ == '__main__':
70
    main()
71

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.