BasicSR

Форк
0
/
test_reds_dataset.py 
72 строки · 2.1 Кб
1
import math
2
import os
3
import torchvision.utils
4

5
from basicsr.data import build_dataloader, build_dataset
6

7

8
def main(mode='folder'):
9
    """Test reds dataset.
10

11
    Args:
12
        mode: There are two modes: 'lmdb', 'folder'.
13
    """
14
    opt = {}
15
    opt['dist'] = False
16
    opt['phase'] = 'train'
17

18
    opt['name'] = 'REDS'
19
    opt['type'] = 'REDSDataset'
20
    if mode == 'folder':
21
        opt['dataroot_gt'] = 'datasets/REDS/train_sharp'
22
        opt['dataroot_lq'] = 'datasets/REDS/train_sharp_bicubic'
23
        opt['dataroot_flow'] = None
24
        opt['meta_info_file'] = 'basicsr/data/meta_info/meta_info_REDS_GT.txt'
25
        opt['io_backend'] = dict(type='disk')
26
    elif mode == 'lmdb':
27
        opt['dataroot_gt'] = 'datasets/REDS/train_sharp_with_val.lmdb'
28
        opt['dataroot_lq'] = 'datasets/REDS/train_sharp_bicubic_with_val.lmdb'
29
        opt['dataroot_flow'] = None
30
        opt['meta_info_file'] = 'basicsr/data/meta_info/meta_info_REDS_GT.txt'
31
        opt['io_backend'] = dict(type='lmdb')
32

33
    opt['val_partition'] = 'REDS4'
34
    opt['num_frame'] = 5
35
    opt['gt_size'] = 256
36
    opt['interval_list'] = [1]
37
    opt['random_reverse'] = True
38
    opt['use_hflip'] = True
39
    opt['use_rot'] = True
40

41
    opt['num_worker_per_gpu'] = 1
42
    opt['batch_size_per_gpu'] = 16
43
    opt['scale'] = 4
44

45
    opt['dataset_enlarge_ratio'] = 1
46

47
    os.makedirs('tmp', exist_ok=True)
48

49
    dataset = build_dataset(opt)
50
    data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=opt['dist'], sampler=None)
51

52
    nrow = int(math.sqrt(opt['batch_size_per_gpu']))
53
    padding = 2 if opt['phase'] == 'train' else 0
54

55
    print('start...')
56
    for i, data in enumerate(data_loader):
57
        if i > 5:
58
            break
59
        print(i)
60

61
        lq = data['lq']
62
        gt = data['gt']
63
        key = data['key']
64
        print(key)
65
        for j in range(opt['num_frame']):
66
            torchvision.utils.save_image(
67
                lq[:, j, :, :, :], f'tmp/lq_{i:03d}_frame{j}.png', nrow=nrow, padding=padding, normalize=False)
68
        torchvision.utils.save_image(gt, f'tmp/gt_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
69

70

71
if __name__ == '__main__':
72
    main()
73

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.