3
from common_utils import assert_equal
4
from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator
5
from torchvision.models.detection.image_list import ImageList
9
def test_incorrect_anchors(self):
    """Calling a misconfigured AnchorGenerator must raise AssertionError.

    The generator is built from ``incorrect_sizes`` / ``incorrect_aspects``
    and then invoked on a single image with a single feature map; the call
    itself (not the construction) is expected to fail.
    """
    # NOTE(review): the definition of ``incorrect_sizes`` is not visible in
    # this chunk. The value below describes two pyramid levels, which cannot
    # be matched against the single feature map passed further down --
    # confirm against the upstream test before merging.
    incorrect_sizes = (
        (2, 4, 8),
        (32, 8),
    )
    incorrect_aspects = (0.5, 1.0)
    anc = AnchorGenerator(incorrect_sizes, incorrect_aspects)

    image1 = torch.randn(3, 800, 800)
    image_list = ImageList(image1, [(800, 800)])
    feature_maps = [torch.randn(1, 50)]

    # Context-manager form keeps the failing call readable as a normal call.
    with pytest.raises(AssertionError):
        anc(image_list, feature_maps)
21
def _init_test_anchor_generator(self):
    """Build the AnchorGenerator used by the tests.

    Returns:
        An ``AnchorGenerator`` configured with a single size group ``(10,)``
        and a single aspect-ratio group ``(1,)``.
    """
    return AnchorGenerator(((10,),), ((1,),))
28
def _init_test_defaultbox_generator(self):
    """Build the DefaultBoxGenerator used by the tests.

    Returns:
        A ``DefaultBoxGenerator`` for a single feature map with one extra
        aspect ratio (2).
    """
    # One feature map with aspect ratio 2: the expected boxes in
    # test_defaultbox_generator contain four boxes, the last two being a
    # 2:1 box and its 1:2 transpose.
    aspect_ratios = [[2]]
    dbox_generator = DefaultBoxGenerator(aspect_ratios)

    # The caller invokes the returned object, so it must be handed back.
    return dbox_generator
34
def get_features(self, images):
    """Create a random single-level feature list for ``images``.

    Args:
        images: tensor whose last two dimensions are the spatial size.

    Returns:
        A one-element list holding a random tensor of shape
        ``(2, 8, s0 // 5, s1 // 5)``, where ``(s0, s1)`` are the last two
        dimensions of ``images``.
    """
    s0, s1 = images.shape[-2:]
    features = [torch.rand(2, 8, s0 // 5, s1 // 5)]
    # Callers iterate over the result, so it has to be returned.
    return features
39
def test_anchor_generator(self):
    """Check the anchors produced for a batch of two 15x15 images.

    The feature list comes from ``get_features`` (a 3x3 map for 15x15
    inputs) and the generator from ``_init_test_anchor_generator``. Both
    images must receive the same nine 10x10 anchors laid out on a 3x3 grid
    spaced 5 pixels apart.
    """
    images = torch.randn(2, 3, 15, 15)
    features = self.get_features(images)
    image_shapes = [i.shape[-2:] for i in images]
    images = ImageList(images, image_shapes)

    model = self._init_test_anchor_generator()

    anchors = model(images, features)

    # Estimate the number of target anchors
    grid_sizes = [f.shape[-2:] for f in features]
    num_anchors_estimated = sum(
        sizes[0] * sizes[1] * num_anchors_per_loc
        for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location())
    )

    # Rows must be wrapped in one outer list to form a (9, 4) tensor.
    anchors_output = torch.tensor(
        [
            [-5.0, -5.0, 5.0, 5.0],
            [0.0, -5.0, 10.0, 5.0],
            [5.0, -5.0, 15.0, 5.0],
            [-5.0, 0.0, 5.0, 10.0],
            [0.0, 0.0, 10.0, 10.0],
            [5.0, 0.0, 15.0, 10.0],
            [-5.0, 5.0, 5.0, 15.0],
            [0.0, 5.0, 10.0, 15.0],
            [5.0, 5.0, 15.0, 15.0],
        ]
    )

    assert num_anchors_estimated == 9
    assert len(anchors) == 2
    assert tuple(anchors[0].shape) == (9, 4)
    assert tuple(anchors[1].shape) == (9, 4)
    assert_equal(anchors[0], anchors_output)
    assert_equal(anchors[1], anchors_output)
76
def test_defaultbox_generator(self):
    """Check the default boxes produced for a batch of two 15x15 images.

    A single 1x1 feature map is used, and the generator comes from
    ``_init_test_defaultbox_generator``. Both images must receive the same
    four boxes, compared with a relative tolerance of 1e-5.
    """
    images = torch.zeros(2, 3, 15, 15)
    features = [torch.zeros(2, 8, 1, 1)]
    image_shapes = [i.shape[-2:] for i in images]
    images = ImageList(images, image_shapes)

    model = self._init_test_defaultbox_generator()

    dboxes = model(images, features)

    # Rows must be wrapped in one outer list to form a (4, 4) tensor.
    dboxes_output = torch.tensor(
        [
            [6.3750, 6.3750, 8.6250, 8.6250],
            [4.7443, 4.7443, 10.2557, 10.2557],
            [5.9090, 6.7045, 9.0910, 8.2955],
            [6.7045, 5.9090, 8.2955, 9.0910],
        ]
    )

    assert len(dboxes) == 2
    assert tuple(dboxes[0].shape) == (4, 4)
    assert tuple(dboxes[1].shape) == (4, 4)
    torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8)
    torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8)