# BasicSR
# 54 lines · 1.3 KB
import math
import os

import torch
import torchvision.utils

from basicsr.data import build_dataloader, build_dataset
8
def main():
    """Smoke-test the FFHQ dataset pipeline.

    Builds an ``FFHQDataset`` that reads ground-truth images from an LMDB
    store, wraps it in a dataloader, and for the first six batches prints the
    batch index, the min/max pixel values and the source paths, then saves
    each batch as an image grid to ``tmp/gt_XXX.png``.
    """
    opt = {}
    opt['dist'] = False
    opt['gpu_ids'] = [0]
    opt['phase'] = 'train'

    opt['name'] = 'FFHQ'
    opt['type'] = 'FFHQDataset'

    opt['dataroot_gt'] = 'datasets/ffhq/ffhq_256.lmdb'
    opt['io_backend'] = dict(type='lmdb')

    opt['use_hflip'] = True
    # NOTE(review): presumably the dataset normalizes as (x - mean) / std,
    # mapping [0, 1] pixels to [-1, 1]; the min/max print below lets you
    # confirm, and the grid is saved with value_range=(-1, 1) accordingly.
    opt['mean'] = [0.5, 0.5, 0.5]
    opt['std'] = [0.5, 0.5, 0.5]

    opt['num_worker_per_gpu'] = 1
    opt['batch_size_per_gpu'] = 4

    opt['dataset_enlarge_ratio'] = 1

    os.makedirs('tmp', exist_ok=True)

    dataset = build_dataset(opt)
    # NOTE(review): num_gpu=0 is passed although opt['gpu_ids'] lists GPU 0;
    # confirm this is intended for the loader's batch/worker sizing.
    data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=opt['dist'], sampler=None)

    # Square-ish grid: a batch of 4 images gives nrow=2, i.e. a 2x2 grid.
    nrow = int(math.sqrt(opt['batch_size_per_gpu']))
    padding = 2 if opt['phase'] == 'train' else 0

    print('start...')
    for i, data in enumerate(data_loader):
        # Only inspect the first six batches (i = 0..5).
        if i > 5:
            break
        print(i)

        gt = data['gt']
        print(torch.min(gt), torch.max(gt))
        gt_path = data['gt_path']
        print(gt_path)
        # torchvision renamed the ``range`` keyword of save_image/make_grid to
        # ``value_range`` (deprecated in 0.10, later removed); passing
        # ``range`` raises TypeError on current releases.
        torchvision.utils.save_image(
            gt, f'tmp/gt_{i:03d}.png', nrow=nrow, padding=padding, normalize=True, value_range=(-1, 1))
52
if __name__ == "__main__":
    # Run the FFHQ smoke test only when this file is executed directly,
    # not when it is imported as a module.
    main()