# BasicSR
#
# Fork: 0
# / data_util.py
# 315 lines · 11.5 KB
1
import cv2
2
import numpy as np
3
import torch
4
from os import path as osp
5
from torch.nn import functional as F
6

7
from basicsr.data.transforms import mod_crop
8
from basicsr.utils import img2tensor, scandir
9

10

11
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a folder path or a list of image paths.

    Args:
        path (list[str] | tuple[str] | str): List/tuple of image paths, or an
            image folder path. A folder is scanned with ``scandir`` and the
            resulting paths are sorted.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname (bool): Whether to also return image names.
            Default: False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Image names without extension (only returned when
            ``return_imgname`` is True).

    Raises:
        IOError: If an image cannot be read by ``cv2.imread``.
    """
    if isinstance(path, (list, tuple)):
        img_paths = list(path)
    else:
        img_paths = sorted(scandir(path, full_path=True))

    imgs = []
    for img_path in img_paths:
        img = cv2.imread(img_path)
        # cv2.imread does not raise on failure; it returns None, which would
        # otherwise surface as a confusing AttributeError on `.astype`.
        if img is None:
            raise IOError(f'Failed to read image: {img_path}')
        imgs.append(img.astype(np.float32) / 255.)

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    # cv2 loads images as BGR; img2tensor converts to RGB tensors, which are
    # then stacked along a new leading (time) dimension.
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)

    if return_imgname:
        imgnames = [osp.splitext(osp.basename(img_path))[0] for img_path in img_paths]
        return imgs, imgnames
    return imgs
41

42

43
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Build the list of frame indices for a window centred on ``crt_idx``.

    The window spans ``num_frames`` consecutive indices. Indices falling
    before the first frame or after the last frame are remapped according to
    ``padding``.

    Args:
        crt_idx (int): Index of the centre frame.
        max_frame_num (int): Total number of frames in the sequence (counted
            from 1); valid indices are ``0 .. max_frame_num - 1``.
        num_frames (int): Window length; must be odd.
        padding (str): One of 'replicate' | 'reflection' |
            'reflection_circle' | 'circle'.
            With current_idx = 0 and num_frames = 5 the results are:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: The (possibly remapped) frame indices, in window order.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    last = max_frame_num - 1  # last valid 0-based index
    half = num_frames // 2
    window_start = crt_idx - half
    window_end = crt_idx + half

    def _before_start(idx):
        # Remap an index that lies below 0.
        if padding == 'replicate':
            return 0
        if padding == 'reflection':
            return -idx
        if padding == 'reflection_circle':
            return window_end - idx
        return num_frames + idx  # 'circle'

    def _after_end(idx):
        # Remap an index that lies beyond the last frame.
        if padding == 'replicate':
            return last
        if padding == 'reflection':
            return 2 * last - idx
        if padding == 'reflection_circle':
            return window_start - (idx - last)
        return idx - num_frames  # 'circle'

    def _resolve(idx):
        if idx < 0:
            return _before_start(idx)
        if idx > last:
            return _after_end(idx)
        return idx

    return [_resolve(idx) for idx in range(window_start, window_end + 1)]
93

94

95
def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    ::

        lq.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records, separated by a white space:
    1) image name (with extension),
    2) image shape,
    3) compression level.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without its (final) extension as the lmdb key, so
    `img.v2.png` maps to the key `img.v2`. Blank lines are ignored.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[dict]: One dict per image, sorted by lmdb key, mapping
        ``f'{input_key}_path'`` and ``f'{gt_key}_path'`` to that lmdb key.

    Raises:
        ValueError: If either folder does not end with '.lmdb', or the two
            meta_info.txt files contain different key sets.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')

    def _read_lmdb_keys(lmdb_folder):
        # Key = first space-separated field (the image name) minus its final
        # extension. Splitting on '.' instead would truncate dotted names
        # ('a.b.png' -> 'a') and no longer match the keys stored in the lmdb.
        lmdb_keys = []
        with open(osp.join(lmdb_folder, 'meta_info.txt')) as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                lmdb_keys.append(osp.splitext(line.split(' ')[0])[0])
        return lmdb_keys

    # ensure that the two meta_info files describe the same images
    input_lmdb_keys = _read_lmdb_keys(input_folder)
    gt_lmdb_keys = _read_lmdb_keys(gt_folder)
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')

    return [{f'{input_key}_path': lmdb_key, f'{gt_key}_path': lmdb_key} for lmdb_key in sorted(input_lmdb_keys)]
154

155

156
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image name and
    image shape (usually for gt), separated by a white space. Only the first
    space-separated field (the gt image name) is used; blank lines are
    skipped.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per listed gt image, in file order, mapping
        ``f'{input_key}_path'`` to the input image path (gt basename passed
        through ``filename_tmpl``, gt extension kept) and ``f'{gt_key}_path'``
        to the gt image path.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        # Skip blank lines (e.g. a trailing empty line); they would otherwise
        # produce a bogus pair with an empty gt image name.
        gt_names = [line.strip().split(' ')[0] for line in fin if line.strip()]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name),
        })
    return paths
198

199

200
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Every entry found by ``scandir`` in the gt folder is paired with the
    input entry whose name is the gt basename passed through
    ``filename_tmpl`` (the gt extension is kept).

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per gt entry, mapping ``f'{input_key}_path'`` and
        ``f'{gt_key}_path'`` to the joined input and gt paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    # A set gives O(1) membership checks; testing against the list for every
    # gt entry made this loop quadratic in the dataset size.
    input_name_set = set(input_paths)
    paths = []
    for gt_name in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        assert input_name in input_name_set, f'{input_name} is not in {input_key}_paths.'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name),
        })
    return paths
234

235

236
def paths_from_folder(folder):
    """List every entry that ``scandir`` finds in ``folder``, joined onto it.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: ``osp.join(folder, entry)`` for each scanned entry, in
        scan order.
    """
    return [osp.join(folder, entry) for entry in scandir(folder)]
249

250

251
def paths_from_lmdb(folder):
    """Generate lmdb keys from the meta_info.txt of an lmdb folder.

    Each non-blank line of ``meta_info.txt`` starts with an image name (with
    extension), e.g. `baboon.png (120,125,3) 1`. The key is that name without
    its final extension, so `img.v2.png` maps to `img.v2`.

    Args:
        folder (str): Folder path; must end with '.lmdb'.

    Returns:
        list[str]: lmdb keys, in file order.

    Raises:
        ValueError: If ``folder`` does not end with '.lmdb'.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    paths = []
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue  # ignore blank lines such as a trailing empty line
            # Strip only the final extension of the first field; splitting
            # on '.' would truncate dotted names ('a.b.png' -> 'a').
            paths.append(osp.splitext(line.split(' ')[0])[0])
    return paths
265

266

267
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    The kernel is obtained by Gaussian-smoothing a 2-D Dirac delta placed at
    the centre of a ``kernel_size x kernel_size`` grid.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.ndarray: The Gaussian kernel, shape (kernel_size, kernel_size).
    """
    # `scipy.ndimage.filters` is a deprecated alias namespace; import the
    # function from `scipy.ndimage` directly.
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return gaussian_filter(kernel, sigma)
283

284

285
def duf_downsample(x, kernel_size=13, scale=4):
    """Downsampling with Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w), or
            (t, c, h, w), in which case a batch dim is added for processing
            and removed again before returning. Non-contiguous tensors are
            accepted.
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames, with the same number of dims as ``x``.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = False
    if x.ndim == 4:
        squeeze_flag = True
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    # Filter every channel of every frame as an independent 1-channel image.
    # `reshape` (unlike `view`) also works for non-contiguous inputs such as
    # permuted tensors.
    x = x.reshape(-1, 1, h, w)
    # Pad by half the kernel plus an extra 2 * scale margin on each side; at
    # stride `scale` that margin corresponds to 2 output pixels per side,
    # which are cropped off after the convolution.
    pad = kernel_size // 2 + scale * 2
    x = F.pad(x, (pad, pad, pad, pad), 'reflect')

    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    x = x[:, :, 2:-2, 2:-2]
    x = x.reshape(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x
316

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give SberTech JSC consent to the processing of your personal data for the purpose of improving our website and the GitVerse Service, as well as making them more convenient to use.
#
# You can disable cookies yourself in your browser settings.