BasicSR

Форк
0
/
test_sr_model.py 
165 строк · 4.7 Кб
1
import tempfile
2
import torch
3
import yaml
4

5
from basicsr.archs.srresnet_arch import MSRResNet
6
from basicsr.data.paired_image_dataset import PairedImageDataset
7
from basicsr.losses.basic_loss import L1Loss, PerceptualLoss
8
from basicsr.models.sr_model import SRModel
9

10

11
def test_srmodel():
12
    """Test model: SRModel"""
13

14
    opt_str = r"""
15
scale: 4
16
num_gpu: 1
17
manual_seed: 0
18
is_train: True
19
dist: False
20

21
# network structures
22
network_g:
23
  type: MSRResNet
24
  num_in_ch: 3
25
  num_out_ch: 3
26
  num_feat: 4
27
  num_block: 1
28
  upscale: 4
29

30
# path
31
path:
32
  pretrain_network_g: ~
33
  strict_load_g: true
34
  resume_state: ~
35

36
# training settings
37
train:
38
  ema_decay: 0.999
39
  optim_g:
40
    type: Adam
41
    lr: !!float 2e-4
42
    weight_decay: 0
43
    betas: [0.9, 0.99]
44

45
  scheduler:
46
    type: CosineAnnealingRestartLR
47
    periods: [250000, 250000, 250000, 250000]
48
    restart_weights: [1, 1, 1, 1]
49
    eta_min: !!float 1e-7
50

51
  total_iter: 1000000
52
  warmup_iter: -1  # no warm up
53

54
  # losses
55
  pixel_opt:
56
    type: L1Loss
57
    loss_weight: 1.0
58
    reduction: mean
59
  perceptual_opt:
60
    type: PerceptualLoss
61
    layer_weights:
62
      'conv5_4': 1  # before relu
63
    vgg_type: vgg19
64
    use_input_norm: true
65
    range_norm: false
66
    perceptual_weight: 1.0
67
    style_weight: 1.0
68
    criterion: l1
69

70
# validation settings
71
val:
72
  val_freq: !!float 5e3
73
  save_img: True
74

75
  metrics:
76
    psnr: # metric name
77
      type: calculate_psnr
78
      crop_border: 4
79
      test_y_channel: false
80
      better: higher  # the higher, the better. Default: higher
81
"""
82

83
    opt = yaml.safe_load(opt_str)
84

85
    # build model
86
    model = SRModel(opt)
87
    # test attributes
88
    assert model.__class__.__name__ == 'SRModel'
89
    assert isinstance(model.net_g, MSRResNet)
90
    assert isinstance(model.cri_pix, L1Loss)
91
    assert isinstance(model.cri_perceptual, PerceptualLoss)
92
    assert isinstance(model.optimizers[0], torch.optim.Adam)
93
    assert model.ema_decay == 0.999
94

95
    # prepare data
96
    gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
97
    lq = torch.rand((1, 3, 8, 8), dtype=torch.float32)
98
    data = dict(gt=gt, lq=lq)
99
    model.feed_data(data)
100
    # check data shape
101
    assert model.lq.shape == (1, 3, 8, 8)
102
    assert model.gt.shape == (1, 3, 32, 32)
103

104
    # ----------------- test optimize_parameters -------------------- #
105
    model.optimize_parameters(1)
106
    assert model.output.shape == (1, 3, 32, 32)
107
    assert isinstance(model.log_dict, dict)
108
    # check returned keys
109
    expected_keys = ['l_pix', 'l_percep', 'l_style']
110
    assert set(expected_keys).issubset(set(model.log_dict.keys()))
111

112
    # ----------------- test save -------------------- #
113
    with tempfile.TemporaryDirectory() as tmpdir:
114
        model.opt['path']['models'] = tmpdir
115
        model.opt['path']['training_states'] = tmpdir
116
        model.save(0, 1)
117

118
    # ----------------- test the test function -------------------- #
119
    model.test()
120
    assert model.output.shape == (1, 3, 32, 32)
121
    # delete net_g_ema
122
    model.__delattr__('net_g_ema')
123
    model.test()
124
    assert model.output.shape == (1, 3, 32, 32)
125
    assert model.net_g.training is True  # should back to training mode after testing
126

127
    # ----------------- test nondist_validation -------------------- #
128
    # construct dataloader
129
    dataset_opt = dict(
130
        name='Test',
131
        dataroot_gt='tests/data/gt',
132
        dataroot_lq='tests/data/lq',
133
        io_backend=dict(type='disk'),
134
        scale=4,
135
        phase='val')
136
    dataset = PairedImageDataset(dataset_opt)
137
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
138
    assert model.is_train is True
139
    with tempfile.TemporaryDirectory() as tmpdir:
140
        model.opt['path']['visualization'] = tmpdir
141
        model.nondist_validation(dataloader, 1, None, save_img=True)
142
        assert model.is_train is True
143
        # check metric_results
144
        assert 'psnr' in model.metric_results
145
        assert isinstance(model.metric_results['psnr'], float)
146

147
    # in validation mode
148
    with tempfile.TemporaryDirectory() as tmpdir:
149
        model.opt['is_train'] = False
150
        model.opt['val']['suffix'] = 'test'
151
        model.opt['path']['visualization'] = tmpdir
152
        model.opt['val']['pbar'] = True
153
        model.nondist_validation(dataloader, 1, None, save_img=True)
154
        # check metric_results
155
        assert 'psnr' in model.metric_results
156
        assert isinstance(model.metric_results['psnr'], float)
157

158
        # if opt['val']['suffix'] is None
159
        model.opt['val']['suffix'] = None
160
        model.opt['name'] = 'demo'
161
        model.opt['path']['visualization'] = tmpdir
162
        model.nondist_validation(dataloader, 1, None, save_img=True)
163
        # check metric_results
164
        assert 'psnr' in model.metric_results
165
        assert isinstance(model.metric_results['psnr'], float)
166

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.