vision

Форк
0
/
test_models_detection_anchor_utils.py 
99 строк · 3.4 Кб
1
import pytest
2
import torch
3
from common_utils import assert_equal
4
from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator
5
from torchvision.models.detection.image_list import ImageList
6

7

8
class Tester:
9
    def test_incorrect_anchors(self):
10
        incorrect_sizes = (
11
            (2, 4, 8),
12
            (32, 8),
13
        )
14
        incorrect_aspects = (0.5, 1.0)
15
        anc = AnchorGenerator(incorrect_sizes, incorrect_aspects)
16
        image1 = torch.randn(3, 800, 800)
17
        image_list = ImageList(image1, [(800, 800)])
18
        feature_maps = [torch.randn(1, 50)]
19
        pytest.raises(AssertionError, anc, image_list, feature_maps)
20

21
    def _init_test_anchor_generator(self):
22
        anchor_sizes = ((10,),)
23
        aspect_ratios = ((1,),)
24
        anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
25

26
        return anchor_generator
27

28
    def _init_test_defaultbox_generator(self):
29
        aspect_ratios = [[2]]
30
        dbox_generator = DefaultBoxGenerator(aspect_ratios)
31

32
        return dbox_generator
33

34
    def get_features(self, images):
35
        s0, s1 = images.shape[-2:]
36
        features = [torch.rand(2, 8, s0 // 5, s1 // 5)]
37
        return features
38

39
    def test_anchor_generator(self):
40
        images = torch.randn(2, 3, 15, 15)
41
        features = self.get_features(images)
42
        image_shapes = [i.shape[-2:] for i in images]
43
        images = ImageList(images, image_shapes)
44

45
        model = self._init_test_anchor_generator()
46
        model.eval()
47
        anchors = model(images, features)
48

49
        # Estimate the number of target anchors
50
        grid_sizes = [f.shape[-2:] for f in features]
51
        num_anchors_estimated = 0
52
        for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location()):
53
            num_anchors_estimated += sizes[0] * sizes[1] * num_anchors_per_loc
54

55
        anchors_output = torch.tensor(
56
            [
57
                [-5.0, -5.0, 5.0, 5.0],
58
                [0.0, -5.0, 10.0, 5.0],
59
                [5.0, -5.0, 15.0, 5.0],
60
                [-5.0, 0.0, 5.0, 10.0],
61
                [0.0, 0.0, 10.0, 10.0],
62
                [5.0, 0.0, 15.0, 10.0],
63
                [-5.0, 5.0, 5.0, 15.0],
64
                [0.0, 5.0, 10.0, 15.0],
65
                [5.0, 5.0, 15.0, 15.0],
66
            ]
67
        )
68

69
        assert num_anchors_estimated == 9
70
        assert len(anchors) == 2
71
        assert tuple(anchors[0].shape) == (9, 4)
72
        assert tuple(anchors[1].shape) == (9, 4)
73
        assert_equal(anchors[0], anchors_output)
74
        assert_equal(anchors[1], anchors_output)
75

76
    def test_defaultbox_generator(self):
77
        images = torch.zeros(2, 3, 15, 15)
78
        features = [torch.zeros(2, 8, 1, 1)]
79
        image_shapes = [i.shape[-2:] for i in images]
80
        images = ImageList(images, image_shapes)
81

82
        model = self._init_test_defaultbox_generator()
83
        model.eval()
84
        dboxes = model(images, features)
85

86
        dboxes_output = torch.tensor(
87
            [
88
                [6.3750, 6.3750, 8.6250, 8.6250],
89
                [4.7443, 4.7443, 10.2557, 10.2557],
90
                [5.9090, 6.7045, 9.0910, 8.2955],
91
                [6.7045, 5.9090, 8.2955, 9.0910],
92
            ]
93
        )
94

95
        assert len(dboxes) == 2
96
        assert tuple(dboxes[0].shape) == (4, 4)
97
        assert tuple(dboxes[1].shape) == (4, 4)
98
        torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8)
99
        torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8)
100

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.