# Real-ESRGAN
# 19 lines · 561.0 bytes
import torch

from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN


def test_unetdiscriminatorsn():
    """Test arch: UNetDiscriminatorSN.

    Builds a small ``UNetDiscriminatorSN`` (3 input channels, 4 base
    features, skip connection enabled) and runs a forward pass on a random
    1x3x32x32 float32 image. It checks that the output is a single-channel
    map with the input's batch size and spatial size. When CUDA is
    available, the same forward pass and check are repeated on the GPU.
    """
    # Single-channel output with the same batch/spatial dims as the input.
    expected_shape = (1, 1, 32, 32)

    # model init and forward (cpu)
    net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
    img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    output = net(img)
    assert output.shape == expected_shape

    # model init and forward (gpu); this branch is skipped when CUDA is absent
    if torch.cuda.is_available():
        net.cuda()  # Module.cuda() moves parameters/buffers in place
        output = net(img.cuda())
        assert output.shape == expected_shape