BasicSR

Форк
0
/
reds_dataset.py 
352 строки · 14.8 Кб
1
import numpy as np
2
import random
3
import torch
4
from pathlib import Path
5
from torch.utils import data as data
6

7
from basicsr.data.transforms import augment, paired_random_crop
8
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
9
from basicsr.utils.flow_util import dequantize_flow
10
from basicsr.utils.registry import DATASET_REGISTRY
11

12

13
@DATASET_REGISTRY.register()
class REDSDataset(data.Dataset):
    """REDS dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
        dataroot_gt (str): Data root path for gt.
        dataroot_lq (str): Data root path for lq.
        dataroot_flow (str, optional): Data root path for flow. May be omitted or None.
        meta_info_file (str): Path for meta information file.
        val_partition (str): Validation partition types. 'REDS4' or 'official'.
        io_backend (dict): IO backend type and other kwarg. The dict is copied, not mutated.
        num_frame (int): Window size for input frames. Must be odd.
        gt_size (int): Cropped patched size for gt patches.
        interval_list (list): Interval list for temporal augmentation.
        random_reverse (bool): Random reverse input frames.
        use_hflip (bool): Use horizontal flips.
        use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
        scale (int): Scale, which will be added automatically.

    Raises:
        ValueError: If num_frame is even or val_partition is unknown (in __init__), or if a
            sampled temporal window cannot fit inside a clip (in __getitem__).
    """

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        # 'dataroot_flow' is optional: a missing key behaves like an explicit None.
        flow_root = opt.get('dataroot_flow')
        self.flow_root = Path(flow_root) if flow_root is not None else None
        if opt['num_frame'] % 2 != 1:
            raise ValueError(f'num_frame should be odd number, but got {opt["num_frame"]}')
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        # Build keys like '000/00000000' and remember how many frames each clip has,
        # so border handling uses the real clip length instead of a hard-coded 100.
        self.keys = []
        self.clip_frame_nums = {}
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                fields = line.split()
                if not fields:  # tolerate blank lines, e.g. a trailing empty line
                    continue
                folder, frame_num = fields[0], int(fields[1])
                self.clip_frame_nums[folder] = frame_num
                self.keys.extend(f'{folder}/{i:08d}' for i in range(frame_num))

        # remove the video clips used in validation
        val_partition = self._get_val_partition(opt['val_partition'])
        self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend). Work on a copy so the caller's opt['io_backend'] dict
        # is never mutated (it may be reused to build another dataset).
        self.file_client = None
        self.io_backend_opt = dict(opt['io_backend'])
        self.is_lmdb = self.io_backend_opt['type'] == 'lmdb'
        if self.is_lmdb:
            if self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    @staticmethod
    def _get_val_partition(partition):
        """Return the set of clip names reserved for the given validation partition."""
        if partition == 'REDS4':
            return {'000', '011', '015', '020'}
        if partition == 'official':
            return {f'{v:03d}' for v in range(240, 270)}
        raise ValueError(f'Wrong validation partition {partition}. '
                         f"Supported ones are ['official', 'REDS4'].")

    def _sample_center_frame(self, clip_name, center_frame_idx, interval):
        """Return a center index whose temporal window lies inside the clip.

        If the window around ``center_frame_idx`` crosses a clip border, a new center is
        drawn uniformly among all valid centers. This is the same distribution the former
        rejection-sampling loop produced, but it cannot spin forever.
        """
        half_span = self.num_half_frames * interval
        # REDS clips have 100 frames (0..99); fall back to that if the clip is unknown.
        last_idx = self.clip_frame_nums.get(clip_name, 100) - 1
        if 2 * half_span > last_idx:
            raise ValueError(f'Temporal window (num_frame={self.num_frame}, interval={interval}) '
                             f'does not fit in clip {clip_name} with {last_idx + 1} frames.')
        if center_frame_idx - half_span < 0 or center_frame_idx + half_span > last_idx:
            center_frame_idx = random.randint(half_span, last_idx - half_span)
        return center_frame_idx

    def _read_flow(self, clip_name, frame_name, suffix):
        """Read one quantized flow image (e.g. suffix 'p1' / 'n1') and dequantize it."""
        if self.is_lmdb:
            flow_path = f'{clip_name}/{frame_name}_{suffix}'
        else:
            flow_path = self.flow_root / clip_name / f'{frame_name}_{suffix}.png'
        img_bytes = self.file_client.get(flow_path, 'flow')
        cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
        # dx and dy are stacked along the height axis of a single grayscale image
        dx, dy = np.split(cat_flow, 2, axis=0)
        return dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.

    def __getitem__(self, index):
        # Create the file client lazily on first access. Pass a copy so popping 'type'
        # leaves the stored backend config intact.
        if self.file_client is None:
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'), **backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000

        # determine the neighboring frames, moving the center if the window crosses a border
        interval = random.choice(self.interval_list)
        center_frame_idx = self._sample_center_frame(clip_name, int(frame_name), interval)
        frame_name = f'{center_frame_idx:08d}'
        start_frame_idx = center_frame_idx - self.num_half_frames * interval
        end_frame_idx = center_frame_idx + self.num_half_frames * interval
        neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
        # random reverse
        # NOTE(review): when flows are loaded they are not reordered to follow a reversed
        # neighbor_list (p*/n* flows are always stacked in forward order) — confirm intended.
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the GT frame (as the center frame)
        if self.is_lmdb:
            img_gt_path = f'{clip_name}/{frame_name}'
        else:
            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lqs.append(imfrombytes(img_bytes, float32=True))

        # get flows: previous flows (p{half} .. p1) followed by next flows (n1 .. n{half})
        if self.flow_root is not None:
            suffixes = [f'p{i}' for i in range(self.num_half_frames, 0, -1)]
            suffixes += [f'n{i}' for i in range(1, self.num_half_frames + 1)]
            img_flows = [self._read_flow(clip_name, frame_name, suffix) for suffix in suffixes]
            # for random crop, here, img_flows and img_lqs have the same spatial size
            img_lqs.extend(img_flows)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
        if self.flow_root is not None:
            img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        if self.flow_root is not None:
            img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
        else:
            img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        if self.flow_root is not None:
            img_flows = img2tensor(img_flows)
            # add the zero center flow
            img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
            img_flows = torch.stack(img_flows, dim=0)

        # img_lqs: (t, c, h, w)
        # img_flows: (t, 2, h, w)
        # img_gt: (c, h, w)
        # key: str
        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
206

207

208
@DATASET_REGISTRY.register()
class REDSRecurrentDataset(data.Dataset):
    """REDS dataset for training recurrent networks.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
        dataroot_gt (str): Data root path for gt.
        dataroot_lq (str): Data root path for lq.
        meta_info_file (str): Path for meta information file.
        val_partition (str): Validation partition types. 'REDS4' or 'official'.
        test_mode (bool, optional): If true, keep only the validation clips; otherwise
            exclude them. Defaults to False.
        io_backend (dict): IO backend type and other kwarg. The dict is copied, not mutated.
        num_frame (int): Number of consecutive (interval-spaced) frames per sample.
        gt_size (int): Cropped patched size for gt patches.
        interval_list (list, optional): Interval list for temporal augmentation. Defaults to [1].
        random_reverse (bool, optional): Random reverse input frames. Defaults to False.
        use_hflip (bool): Use horizontal flips.
        use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
        scale (int): Scale, which will be added automatically.

    Raises:
        ValueError: If val_partition is unknown (in __init__), or if a sampled temporal
            window is longer than its clip (in __getitem__).
    """

    def __init__(self, opt):
        super(REDSRecurrentDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.num_frame = opt['num_frame']

        # Build keys like '000/00000000' and remember how many frames each clip has,
        # so border handling uses the real clip length instead of a hard-coded 100.
        self.keys = []
        self.clip_frame_nums = {}
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                fields = line.split()
                if not fields:  # tolerate blank lines, e.g. a trailing empty line
                    continue
                folder, frame_num = fields[0], int(fields[1])
                self.clip_frame_nums[folder] = frame_num
                self.keys.extend(f'{folder}/{i:08d}' for i in range(frame_num))

        # keep only the validation clips in test mode, otherwise remove them
        val_partition = self._get_val_partition(opt['val_partition'])
        if opt.get('test_mode', False):
            self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
        else:
            self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend). Work on a copy so the caller's opt['io_backend'] dict
        # is never mutated. This dataset reads no flows, so LMDB only needs lq and gt.
        self.file_client = None
        self.io_backend_opt = dict(opt['io_backend'])
        self.is_lmdb = self.io_backend_opt['type'] == 'lmdb'
        if self.is_lmdb:
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt.get('interval_list', [1])
        self.random_reverse = opt.get('random_reverse', False)
        interval_str = ','.join(str(x) for x in self.interval_list)
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    @staticmethod
    def _get_val_partition(partition):
        """Return the set of clip names reserved for the given validation partition."""
        if partition == 'REDS4':
            return {'000', '011', '015', '020'}
        if partition == 'official':
            return {f'{v:03d}' for v in range(240, 270)}
        raise ValueError(f'Wrong validation partition {partition}. '
                         f"Supported ones are ['official', 'REDS4'].")

    def _sample_start_frame(self, clip_name, start_frame_idx, interval):
        """Return a start index whose interval-spaced window of num_frame frames fits in the clip.

        If the window starting at ``start_frame_idx`` would run past the last frame, a new
        start is drawn uniformly among all valid starts.
        """
        # REDS clips have 100 frames (0..99); fall back to that if the clip is unknown.
        num_clip_frames = self.clip_frame_nums.get(clip_name, 100)
        # The window's last frame is start + (num_frame - 1) * interval; it must be <= last index.
        max_start = num_clip_frames - 1 - (self.num_frame - 1) * interval
        if max_start < 0:
            raise ValueError(f'Temporal window (num_frame={self.num_frame}, interval={interval}) '
                             f'does not fit in clip {clip_name} with {num_clip_frames} frames.')
        if start_frame_idx > max_start:
            start_frame_idx = random.randint(0, max_start)
        return start_frame_idx

    def __getitem__(self, index):
        # Create the file client lazily on first access. Pass a copy so popping 'type'
        # leaves the stored backend config intact.
        if self.file_client is None:
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'), **backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000

        # determine the neighboring frames: num_frame frames spaced by `interval`,
        # starting at the key frame unless that would cross the end of the clip
        interval = random.choice(self.interval_list)
        start_frame_idx = self._sample_start_frame(clip_name, int(frame_name), interval)
        neighbor_list = list(range(start_frame_idx, start_frame_idx + self.num_frame * interval, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = img_gt_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
                img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'

            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lqs.append(imfrombytes(img_bytes, float32=True))

            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gts.append(imfrombytes(img_bytes, float32=True))

        # randomly crop
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate; lq and gt frames are passed in one call
        # (presumably so every frame receives the same random transform)
        num_lqs = len(img_lqs)
        img_results = augment(img_lqs + img_gts, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[:num_lqs], dim=0)
        img_gts = torch.stack(img_results[num_lqs:], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
353

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.