GFPGAN

Форк
0
/
test_arcface_arch.py 
49 строк · 1.8 Кб
1
import torch
2

3
from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace
4

5

6
def test_resnetarcface():
7
    """Test arch: ResNetArcFace."""
8

9
    # model init and forward (gpu)
10
    if torch.cuda.is_available():
11
        net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval()
12
        img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
13
        output = net(img)
14
        assert output.shape == (1, 512)
15

16
        # -------------------- without SE block ----------------------- #
17
        net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval()
18
        output = net(img)
19
        assert output.shape == (1, 512)
20

21

22
def test_basicblock():
23
    """Test the BasicBlock in arcface_arch"""
24
    block = BasicBlock(1, 3, stride=1, downsample=None).cuda()
25
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
26
    output = block(img)
27
    assert output.shape == (1, 3, 12, 12)
28

29
    # ----------------- use the downsmaple module--------------- #
30
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
31
    block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda()
32
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
33
    output = block(img)
34
    assert output.shape == (1, 3, 6, 6)
35

36

37
def test_bottleneck():
38
    """Test the Bottleneck in arcface_arch"""
39
    block = Bottleneck(1, 1, stride=1, downsample=None).cuda()
40
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
41
    output = block(img)
42
    assert output.shape == (1, 4, 12, 12)
43

44
    # ----------------- use the downsmaple module--------------- #
45
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
46
    block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda()
47
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
48
    output = block(img)
49
    assert output.shape == (1, 4, 6, 6)
50

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.