Real-ESRGAN

Форк
0
/
test_discriminator_arch.py 
19 строк · 561.0 Байт
1
import torch
2

3
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
4

5

6
def test_unetdiscriminatorsn():
7
    """Test arch: UNetDiscriminatorSN."""
8

9
    # model init and forward (cpu)
10
    net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
11
    img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
12
    output = net(img)
13
    assert output.shape == (1, 1, 32, 32)
14

15
    # model init and forward (gpu)
16
    if torch.cuda.is_available():
17
        net.cuda()
18
        output = net(img.cuda())
19
        assert output.shape == (1, 1, 32, 32)
20

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.