import cv2
import numpy as np
import torch
from os import path as osp
from torch.nn import functional as F

from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname (bool): Whether to return image names. Default: False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Image name list (without extension), only returned when
            ``return_imgname`` is True.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        # Treat `path` as a folder and collect all images inside it.
        img_paths = sorted(list(scandir(path, full_path=True)))
    # cv2 reads BGR uint8 in [0, 255]; normalize to float32 in [0, 1].
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)

    if return_imgname:
        imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
        return imgs, imgnames
    else:
        return imgs
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate an index list for reading `num_frames` frames from a sequence
    of images.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    max_frame_num = max_frame_num - 1  # start from 0
    num_pad = num_frames // 2

    indices = []
    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
        if i < 0:
            # Pad indices that fall before the first frame.
            if padding == 'replicate':
                pad_idx = 0
            elif padding == 'reflection':
                pad_idx = -i
            elif padding == 'reflection_circle':
                pad_idx = crt_idx + num_pad - i
            else:  # circle
                pad_idx = num_frames + i
        elif i > max_frame_num:
            # Pad indices that fall after the last frame.
            if padding == 'replicate':
                pad_idx = max_frame_num
            elif padding == 'reflection':
                pad_idx = max_frame_num * 2 - i
            elif padding == 'reflection_circle':
                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
            else:  # circle
                pad_idx = i - num_frames
        else:
            pad_idx = i
        indices.append(pad_idx)
    return indices
def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    ::

        lq.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records
    1)image name (with extension),
    2)image shape,
    3)compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[dict]: Returned path list, each element mapping
            '{input_key}_path' and '{gt_key}_path' to the shared lmdb key.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')
    # The lmdb key is the image name without extension (text before the
    # first dot on each meta_info line).
    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
        input_lmdb_keys = [line.split('.')[0] for line in fin]
    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
        gt_lmdb_keys = [line.split('.')[0] for line in fin]
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    paths = []
    for lmdb_key in sorted(input_lmdb_keys):
        paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
    return paths
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from an meta information file.

    Each line in the meta information file contains the image names and
    image shape (usually for gt), separated by a white space.

    Example of an meta information file:

    ::

        0001_s001.png (480,480,3)
        0001_s002.png (480,480,3)

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: Returned path list, each element mapping
            '{input_key}_path' and '{gt_key}_path' to the full file paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    # Only the first whitespace-separated token (the image name) is used.
    with open(meta_info_file, 'r') as fin:
        gt_names = [line.strip().split(' ')[0] for line in fin]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        gt_path = osp.join(gt_folder, gt_name)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: Returned path list, each element mapping
            '{input_key}_path' and '{gt_key}_path' to the full file paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        # Derive the input filename from the gt basename via the template.
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
        gt_path = osp.join(gt_folder, gt_path)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths
def paths_from_folder(folder):
    """Generate paths from folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list (folder joined with each file name).
    """
    paths = list(scandir(folder))
    paths = [osp.join(folder, path) for path in paths]
    return paths
def paths_from_lmdb(folder):
    """Generate paths from lmdb.

    Args:
        folder (str): Folder path, must end with '.lmdb'.

    Returns:
        list[str]: Returned path list (lmdb keys, i.e. image names without
            extension, read from meta_info.txt).

    Raises:
        ValueError: If `folder` is not in lmdb format.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The (kernel_size, kernel_size) Gaussian kernel; filtering a
            unit impulse preserves its sum, so the kernel sums to 1.
    """
    # Local import to keep scipy an optional dependency for this module.
    # NOTE: the legacy `scipy.ndimage.filters` namespace was removed in
    # SciPy 1.15; import gaussian_filter from scipy.ndimage directly.
    from scipy.ndimage import gaussian_filter as gaussian_filter_fn
    kernel = np.zeros((kernel_size, kernel_size))
    # Place a unit impulse at the center; Gaussian-filtering it yields the
    # Gaussian kernel itself.
    kernel[kernel_size // 2, kernel_size // 2] = 1
    return gaussian_filter_fn(kernel, sigma)
def duf_downsample(x, kernel_size=13, scale=4):
    """Downsamping with Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w);
            a 4-dim (t, c, h, w) input is also accepted and squeezed back.
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = False
    if x.ndim == 4:
        # Accept (t, c, h, w) by adding a batch dim, restored before return.
        squeeze_flag = True
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    x = x.view(-1, 1, h, w)
    # Reflect-pad so the strided convolution covers the borders; the extra
    # `scale * 2` margin is trimmed off after the convolution below.
    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')

    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x