# BasicSR — video_test_dataset.py (283 lines, 11.7 KB)
import glob
2
import torch
3
from os import path as osp
4
from torch.utils import data as data
5

6
from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
7
from basicsr.utils import get_root_logger, scandir
8
from basicsr.utils.registry import DATASET_REGISTRY
9

10

11
@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
    """Video test dataset.

    Supported datasets: Vid4, REDS4, REDSofficial.
    More generally, it supports testing dataset with following structures:

    ::

        dataroot
        ├── subfolder1
            ├── frame000
            ├── frame001
            ├── ...
        ├── subfolder2
            ├── frame000
            ├── frame001
            ├── ...
        ├── ...

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
        dataroot_gt (str): Data root path for gt.
        dataroot_lq (str): Data root path for lq.
        io_backend (dict): IO backend type and other kwarg.
        cache_data (bool): Whether to cache testing datasets.
        name (str): Dataset name.
        meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
            in the dataroot will be used.
        num_frame (int): Window size for input frames.
        padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        # Per-frame metadata lists, all indexed in lockstep by __getitem__'s index.
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        self.imgs_lq, self.imgs_gt = {}, {}
        if 'meta_info_file' in opt:
            with open(opt['meta_info_file'], 'r') as fin:
                # Each line starts with the subfolder name. rstrip() so a line
                # without extra fields does not keep its trailing newline in
                # the folder name (which would break the joined paths below).
                subfolders = [line.rstrip().split(' ')[0] for line in fin]
                subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
                subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
            for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
                # get frame list for lq and gt
                subfolder_name = osp.basename(subfolder_lq)
                img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
                img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))

                max_idx = len(img_paths_lq)
                assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
                                                      f' and gt folders ({len(img_paths_gt)})')

                self.data_info['lq_path'].extend(img_paths_lq)
                self.data_info['gt_path'].extend(img_paths_gt)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append(f'{i}/{max_idx}')
                # Mark the first/last num_frame//2 frames of each clip as
                # border frames: their temporal window needs padding.
                border_l = [0] * max_idx
                for i in range(self.opt['num_frame'] // 2):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)

                # cache data or save the frame list
                if self.cache_data:
                    logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                    self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                    self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
                else:
                    self.imgs_lq[subfolder_name] = img_paths_lq
                    self.imgs_gt[subfolder_name] = img_paths_gt
        else:
            # Report the offending dataset name itself (not its type, which is
            # always str and therefore useless in the error message).
            raise ValueError(f'Non-supported video test dataset: {opt["name"]}')

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        # Temporal window of num_frame LQ frame indices centered on `idx`.
        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])
131

132

133
@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
    """Video test dataset for Vimeo90k-Test dataset.

    It only keeps the center frame for testing.
    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
        dataroot_gt (str): Data root path for gt.
        dataroot_lq (str): Data root path for lq.
        io_backend (dict): IO backend type and other kwarg.
        cache_data (bool): Whether to cache testing datasets. Must be False:
            caching is not implemented for this dataset.
        name (str): Dataset name.
        meta_info_file (str): The path to the file storing the list of test
            folders. Required for this dataset.
        num_frame (int): Window size for input frames.
        padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestVimeo90KDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        if self.cache_data:
            raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # Vimeo90K clips have 7 frames (im1..im7, center im4); pick a window of
        # num_frame frame numbers centered on im4.
        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        # Log the actual class name (was wrongly "VideoTestDataset").
        logger.info(f'Generate data info for VideoTestVimeo90KDataset - {opt["name"]}')
        with open(opt['meta_info_file'], 'r') as fin:
            # Each line starts with the clip subfolder; rstrip() avoids a
            # trailing newline leaking into the path when no extra fields follow.
            subfolders = [line.rstrip().split(' ')[0] for line in fin]
        for idx, subfolder in enumerate(subfolders):
            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
            self.data_info['gt_path'].append(gt_path)
            lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
            self.data_info['lq_path'].append(lq_paths)
            self.data_info['folder'].append('vimeo90k')
            self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
            # Vimeo90K samples are independent clips, so none is a border frame.
            self.data_info['border'].append(0)

    def __getitem__(self, index):
        lq_path = self.data_info['lq_path'][index]
        gt_path = self.data_info['gt_path'][index]
        imgs_lq = read_img_seq(lq_path)
        img_gt = read_img_seq([gt_path])
        img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': self.data_info['folder'][index],  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/843
            'border': self.data_info['border'][index],  # 0 for non-border
            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])
199

200

201
@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
    """ Video test dataset for DUF dataset.

    Args:
        opt (dict): Config for train dataset. Most of keys are the same as VideoTestDataset.
            It has the following extra keys:
        use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames.
        scale (bool): Scale, which will be added automatically.
    """

    def __getitem__(self, index):
        info = self.data_info
        folder = info['folder'][index]
        frame_idx, num_frames = (int(v) for v in info['idx'][index].split('/'))
        border = info['border'][index]
        lq_path = info['lq_path'][index]

        # Temporal window of num_frame indices centered on the current frame.
        select_idx = generate_frame_indices(frame_idx, num_frames, self.opt['num_frame'], padding=self.opt['padding'])

        use_duf = self.opt['use_duf_downsampling']
        scale = self.opt['scale']

        if self.cache_data:
            if use_duf:
                # Synthesize LQ frames by DUF-downsampling the cached GT frames.
                imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=scale)
            else:
                imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][frame_idx]
        else:
            if use_duf:
                # Read GT frames from disk, then DUF-downsample them to get LQ.
                frame_paths = [self.imgs_gt[folder][i] for i in select_idx]
                imgs_lq = read_img_seq(frame_paths, require_mod_crop=True, scale=scale)
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=scale)
            else:
                frame_paths = [self.imgs_lq[folder][i] for i in select_idx]
                imgs_lq = read_img_seq(frame_paths)
            img_gt = read_img_seq([self.imgs_gt[folder][frame_idx]], require_mod_crop=True, scale=scale)
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }
249

250

251
@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
    """Video test dataset for recurrent architectures, which takes LR video
    frames as input and output corresponding HR video frames.

    Args:
        opt (dict): Same as VideoTestDataset. Unused opt:
        padding (str): Padding mode.

    """

    def __init__(self, opt):
        super(VideoRecurrentTestDataset, self).__init__(opt)
        # Collapse the parent's per-frame entries into one entry per unique
        # clip folder: a recurrent model consumes a whole clip at once.
        self.folders = sorted(set(self.data_info['folder']))

    def __getitem__(self, index):
        folder = self.folders[index]

        # Only cached tensors are supported; lazy per-frame loading is not.
        if not self.cache_data:
            raise NotImplementedError('Without cache_data is not implemented.')

        return {
            'lq': self.imgs_lq[folder],
            'gt': self.imgs_gt[folder],
            'folder': folder,
        }

    def __len__(self):
        return len(self.folders)
284

# (Removed site footer: a GitVerse cookie-consent notice captured by the page
# scrape; it is not part of the original source file.)