# BasicSR
# 165 lines · 4.7 KB
import tempfile

import torch
import yaml

from basicsr.archs.srresnet_arch import MSRResNet
from basicsr.data.paired_image_dataset import PairedImageDataset
from basicsr.losses.basic_loss import L1Loss, PerceptualLoss
from basicsr.models.sr_model import SRModel


def test_srmodel():
    """Smoke-test SRModel end to end on a tiny MSRResNet.

    Covers, in order:
      * building the model from a YAML option dict and checking the network,
        loss, optimizer and EMA settings it ends up with;
      * ``feed_data`` with a random LQ/GT pair (LQ is 1/4 of GT, scale 4);
      * ``optimize_parameters``: output shape and logged loss keys;
      * ``save`` into a temporary directory;
      * ``test`` with and without ``net_g_ema``, and that ``net_g`` is back in
        training mode afterwards;
      * ``nondist_validation`` with ``opt['is_train']`` True, then False
        (with ``val.suffix`` set to 'test' and to None), checking that a float
        ``psnr`` entry is produced each time.
    """
    # The indentation inside this YAML document is significant: network_g,
    # path, train and val are nested mappings, and each ``type:`` key belongs
    # to its own sub-section.
    # NOTE(review): ``num_gpu: 1`` presumably makes the model target a CUDA
    # device -- confirm before running this test on a CPU-only machine.
    opt_str = r"""
scale: 4
num_gpu: 1
manual_seed: 0
is_train: True
dist: False

# network structures
network_g:
  type: MSRResNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 4
  num_block: 1
  upscale: 4

# path
path:
  pretrain_network_g: ~
  strict_load_g: true
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 2e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: CosineAnnealingRestartLR
    periods: [250000, 250000, 250000, 250000]
    restart_weights: [1, 1, 1, 1]
    eta_min: !!float 1e-7

  total_iter: 1000000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      'conv5_4': 1  # before relu
    vgg_type: vgg19
    use_input_norm: true
    range_norm: false
    perceptual_weight: 1.0
    style_weight: 1.0
    criterion: l1

# validation settings
val:
  val_freq: !!float 5e3
  save_img: True

  metrics:
    psnr: # metric name
      type: calculate_psnr
      crop_border: 4
      test_y_channel: false
      better: higher  # the higher, the better. Default: higher
"""

    opt = yaml.safe_load(opt_str)

    # ----------------- build model -------------------- #
    model = SRModel(opt)
    # test attributes
    assert type(model).__name__ == 'SRModel'
    assert isinstance(model.net_g, MSRResNet)
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.cri_perceptual, PerceptualLoss)
    assert isinstance(model.optimizers[0], torch.optim.Adam)
    assert model.ema_decay == 0.999

    # ----------------- prepare data -------------------- #
    # lq is 1/4 of gt spatially, matching ``scale: 4`` above
    gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    lq = torch.rand((1, 3, 8, 8), dtype=torch.float32)
    model.feed_data({'gt': gt, 'lq': lq})
    # check data shape
    assert model.lq.shape == (1, 3, 8, 8)
    assert model.gt.shape == (1, 3, 32, 32)

    # ----------------- test optimize_parameters -------------------- #
    model.optimize_parameters(1)
    assert model.output.shape == (1, 3, 32, 32)
    assert isinstance(model.log_dict, dict)
    # pixel, perceptual and style weights are all non-zero in the config,
    # so all three losses are expected in the log
    expected_keys = {'l_pix', 'l_percep', 'l_style'}
    assert expected_keys.issubset(model.log_dict)

    # ----------------- test save -------------------- #
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['path']['models'] = tmpdir
        model.opt['path']['training_states'] = tmpdir
        model.save(0, 1)

    # ----------------- test the test function -------------------- #
    model.test()
    assert model.output.shape == (1, 3, 32, 32)
    # run test() again without the EMA network
    del model.net_g_ema
    model.test()
    assert model.output.shape == (1, 3, 32, 32)
    assert model.net_g.training is True  # should be back in training mode after testing

    # ----------------- test nondist_validation -------------------- #
    # construct dataloader
    # NOTE(review): the data roots are relative paths, so this presumably has
    # to be run from the repository root -- confirm.
    dataset_opt = dict(
        name='Test',
        dataroot_gt='tests/data/gt',
        dataroot_lq='tests/data/lq',
        io_backend=dict(type='disk'),
        scale=4,
        phase='val')
    dataset = PairedImageDataset(dataset_opt)
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=1, shuffle=False, num_workers=0)

    # in training mode
    assert model.is_train is True
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['path']['visualization'] = tmpdir
        model.nondist_validation(dataloader, 1, None, save_img=True)
        assert model.is_train is True
        # check metric_results
        assert 'psnr' in model.metric_results
        assert isinstance(model.metric_results['psnr'], float)

    # in validation mode
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['is_train'] = False
        model.opt['val']['suffix'] = 'test'
        model.opt['path']['visualization'] = tmpdir
        model.opt['val']['pbar'] = True
        model.nondist_validation(dataloader, 1, None, save_img=True)
        # check metric_results
        assert 'psnr' in model.metric_results
        assert isinstance(model.metric_results['psnr'], float)

        # if opt['val']['suffix'] is None -- kept inside the ``with`` so the
        # visualization path still points at a live temporary directory that
        # is cleaned up on exit
        model.opt['val']['suffix'] = None
        model.opt['name'] = 'demo'
        model.opt['path']['visualization'] = tmpdir
        model.nondist_validation(dataloader, 1, None, save_img=True)
        # check metric_results
        assert 'psnr' in model.metric_results
        assert isinstance(model.metric_results['psnr'], float)