Real-ESRGAN

Форк
0
/
test_dataset.py 
151 строка · 5.9 Кб
1
import pytest
2
import yaml
3

4
from realesrgan.data.realesrgan_dataset import RealESRGANDataset
5
from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
6

7

8
def _assert_degradation_sample(result, expected_gt_path):
    """Check one ``RealESRGANDataset`` sample against the expectations of this test.

    Args:
        result (dict): The dict returned by ``dataset[index]``.
        expected_gt_path (str): Expected ``result['gt_path']`` — a full file path for the
            disk backend (e.g. 'tests/data/gt/baboon.png'), a bare name for the lmdb
            backend (e.g. 'comic').
    """
    # Extra keys are allowed; only these must be present.
    expected_keys = {'gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path'}
    assert expected_keys.issubset(result.keys())
    assert result['gt'].shape == (3, 400, 400)
    # All three kernels are expected to be 21x21.
    for key in ('kernel1', 'kernel2', 'sinc_kernel'):
        assert result[key].shape == (21, 21), key
    assert result['gt_path'] == expected_gt_path


def test_realesrgan_dataset():
    """Exercise ``RealESRGANDataset`` with the disk and lmdb io backends.

    Covers construction from the test yml, parsing of the degradation configuration,
    ``__getitem__`` output for both backends, the ``sinc_prob == 0`` configuration, and
    the ValueError raised when the lmdb backend is given a dataroot not ending in '.lmdb'.
    """
    # safe_load: the test config is plain YAML and needs no Python-object tags.
    with open('tests/data/test_realesrgan_dataset.yml', mode='r', encoding='utf-8') as f:
        opt = yaml.safe_load(f)

    expected_kernels = [
        'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
    ]

    # ------------------ disk backend -------------------- #
    dataset = RealESRGANDataset(opt)
    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
    assert len(dataset) == 2  # meta info was read correctly
    assert dataset.kernel_list == expected_kernels  # degradation config initialized correctly
    assert dataset.betag_range2 == [0.5, 4]
    _assert_degradation_sample(dataset[0], 'tests/data/gt/baboon.png')

    # ------------------ lmdb backend -------------------- #
    opt['dataroot_gt'] = 'tests/data/gt.lmdb'
    opt['io_backend']['type'] = 'lmdb'

    dataset = RealESRGANDataset(opt)
    assert dataset.io_backend_opt['type'] == 'lmdb'  # io backend
    assert len(dataset.paths) == 2  # meta info was read correctly
    assert dataset.kernel_list == expected_kernels  # degradation config initialized correctly
    assert dataset.betag_range2 == [0.5, 4]
    _assert_degradation_sample(dataset[1], 'comic')

    # ------------------ sinc_prob = 0 (opt still points at the lmdb backend) -------------------- #
    opt['sinc_prob'] = 0
    opt['sinc_prob2'] = 0
    opt['final_sinc_prob'] = 0
    dataset = RealESRGANDataset(opt)
    # A 21x21 sinc_kernel is still expected in the output even with all sinc probabilities at 0.
    _assert_degradation_sample(dataset[0], 'baboon')

    # ------------------ lmdb backend requires a dataroot ending with '.lmdb' -------------------- #
    opt['dataroot_gt'] = 'tests/data/gt'
    opt['io_backend']['type'] = 'lmdb'
    # Only the constructor is inside the raises block, so a ValueError from setup
    # code cannot satisfy the expectation by accident.
    with pytest.raises(ValueError):
        RealESRGANDataset(opt)
80

81

82
def _assert_paired_sample(result, expected_gt_path=None, expected_lq_path=None):
    """Check one ``RealESRGANPairedDataset`` sample against the expectations of this test.

    Args:
        result (dict): The dict returned by ``dataset[index]``.
        expected_gt_path (str | None): Expected ``result['gt_path']``; not checked when None.
        expected_lq_path (str | None): Expected ``result['lq_path']``; not checked when None.
    """
    # Extra keys are allowed; only these must be present.
    expected_keys = {'gt', 'lq', 'gt_path', 'lq_path'}
    assert expected_keys.issubset(result.keys())
    # The lq crop is expected at 1/4 of the gt crop's spatial size.
    assert result['gt'].shape == (3, 128, 128)
    assert result['lq'].shape == (3, 32, 32)
    if expected_gt_path is not None:
        assert result['gt_path'] == expected_gt_path
    if expected_lq_path is not None:
        assert result['lq_path'] == expected_lq_path


def test_realesrgan_paired_dataset():
    """Exercise ``RealESRGANPairedDataset`` across path modes and with normalization.

    Covers the config from the test yml as-is (disk backend), the lmdb backend, folder
    scanning (``paired_paths_from_folder``, i.e. ``meta_info`` unset) and a dataset
    with ``mean``/``std`` normalization enabled.
    """
    # safe_load: the test config is plain YAML and needs no Python-object tags.
    with open('tests/data/test_realesrgan_paired_dataset.yml', mode='r', encoding='utf-8') as f:
        opt = yaml.safe_load(f)

    # ------------------ disk backend (config as loaded) -------------------- #
    dataset = RealESRGANPairedDataset(opt)
    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
    assert len(dataset) == 2  # meta info was read correctly
    _assert_paired_sample(dataset[0], 'tests/data/gt/baboon.png', 'tests/data/lq/baboon.png')

    # ------------------ lmdb backend -------------------- #
    opt['dataroot_gt'] = 'tests/data/gt.lmdb'
    opt['dataroot_lq'] = 'tests/data/lq.lmdb'
    opt['io_backend']['type'] = 'lmdb'

    dataset = RealESRGANPairedDataset(opt)
    assert dataset.io_backend_opt['type'] == 'lmdb'  # io backend
    assert len(dataset) == 2  # meta info was read correctly
    _assert_paired_sample(dataset[1], 'comic', 'comic')

    # ------------------ paired_paths_from_folder -------------------- #
    opt['dataroot_gt'] = 'tests/data/gt'
    opt['dataroot_lq'] = 'tests/data/lq'
    opt['io_backend'] = dict(type='disk')
    opt['meta_info'] = None  # no meta info -> paths come from scanning the folders

    dataset = RealESRGANPairedDataset(opt)
    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
    assert len(dataset) == 2  # both folder pairs were found
    _assert_paired_sample(dataset[0])

    # ------------------ normalization -------------------- #
    dataset.mean = [0.5, 0.5, 0.5]
    dataset.std = [0.5, 0.5, 0.5]
    _assert_paired_sample(dataset[0])
152

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.