scikit-image

Форк
0
/
test_fisher_vector.py 
191 строка · 5.6 Кб
1
import pytest
2

3
import numpy as np
4

5
pytest.importorskip('sklearn')
6

7
from skimage.feature._fisher_vector import (  # noqa: E402
8
    learn_gmm,
9
    fisher_vector,
10
    FisherVectorException,
11
    DescriptorException,
12
)
13

14

15
def test_gmm_wrong_descriptor_format_1():
    """Passing a string where descriptors are expected must make learn_gmm
    raise DescriptorException.
    """
    not_descriptors = 'completely wrong test'
    with pytest.raises(DescriptorException):
        learn_gmm(not_descriptors, n_modes=1)
22

23

24
def test_gmm_wrong_descriptor_format_2():
    """learn_gmm must raise DescriptorException when the descriptor arrays
    disagree in their number of columns (11 vs. 10 here).
    """
    mismatched = [np.zeros((5, 11)), np.zeros((4, 10))]
    with pytest.raises(DescriptorException):
        learn_gmm(mismatched, n_modes=1)
31

32

33
def test_gmm_wrong_descriptor_format_3():
    """learn_gmm must raise DescriptorException when one of the descriptor
    arrays is not two-dimensional (a 3-D array is mixed in here).
    """
    mixed_rank = [np.zeros((5, 10)), np.zeros((4, 10, 1))]
    with pytest.raises(DescriptorException):
        learn_gmm(mixed_rank, n_modes=1)
40

41

42
def test_gmm_wrong_descriptor_format_4():
    """learn_gmm must raise DescriptorException when the descriptor list holds
    plain Python lists rather than NumPy ndarrays.
    """
    plain_lists = [[1, 2, 3], [1, 2, 3]]
    with pytest.raises(DescriptorException):
        learn_gmm(plain_lists, n_modes=1)
49

50

51
def test_gmm_wrong_num_modes_format_1():
    """learn_gmm must raise FisherVectorException when n_modes is given a
    value of the wrong type (a string here).
    """
    descriptors = [np.zeros((5, 10)), np.zeros((4, 10))]
    with pytest.raises(FisherVectorException):
        learn_gmm(descriptors, n_modes='not_valid')
58

59

60
def test_gmm_wrong_num_modes_format_2():
    """learn_gmm must raise FisherVectorException when n_modes is a number
    but not a positive integer (-1 here).
    """
    descriptors = [np.zeros((5, 10)), np.zeros((4, 10))]
    negative_modes = -1
    with pytest.raises(FisherVectorException):
        learn_gmm(descriptors, n_modes=negative_modes)
67

68

69
def test_gmm_wrong_covariance_type():
    """Test that FisherVectorException is raised when an unsupported covariance
    type ('full' here) is passed in via the ``gm_args`` keyword argument.

    The input data comes from a seeded generator so the test is reproducible.
    """
    rng = np.random.default_rng(0)
    descriptors = rng.random((10, 10))

    with pytest.raises(FisherVectorException):
        learn_gmm(descriptors, n_modes=2, gm_args={'covariance_type': 'full'})
78

79

80
def test_gmm_correct_covariance_type():
    """Test that GMM estimation succeeds when the 'diag' covariance type is
    passed in via the ``gm_args`` keyword argument.

    Beyond checking that the fitted attributes exist, verify that their shapes
    match the requested number of modes and the descriptor dimension -- the
    part of the result the ``learn_gmm`` wrapper is responsible for. Input
    data comes from a seeded generator so the test is reproducible.
    """
    n_samples, dim, n_modes = 10, 10, 2
    rng = np.random.default_rng(0)
    descriptors = rng.random((n_samples, dim))

    gmm = learn_gmm(
        descriptors, n_modes=n_modes, gm_args={'covariance_type': 'diag'}
    )

    # Expected shapes follow scikit-learn's GaussianMixture attribute docs.
    assert gmm.means_.shape == (n_modes, dim)
    # With covariance_type='diag', one variance vector is stored per mode.
    assert gmm.covariances_.shape == (n_modes, dim)
    assert gmm.weights_.shape == (n_modes,)
92

93

94
def test_gmm_e2e():
    """
    Test the GMM estimation. Since this is essentially a wrapper for the
    scikit-learn GaussianMixture class, the testing of the actual inner
    workings of the GMM estimation is left to scikit-learn and its
    dependencies.

    We instead assert that the estimation was successful based on the fact
    that the GMM object has associated mixture weights, means, and variances
    after estimation. We also check that means and weights have one entry per
    requested mode, which confirms that ``n_modes`` was passed through.
    Descriptors come from a seeded generator so the test is reproducible.
    """
    n_samples, dim, n_modes = 100, 64, 5
    rng = np.random.default_rng(0)

    gmm = learn_gmm(rng.random((n_samples, dim)), n_modes=n_modes)

    # Expected shapes follow scikit-learn's GaussianMixture attribute docs.
    assert gmm.means_.shape == (n_modes, dim)
    assert gmm.covariances_ is not None
    assert gmm.weights_.shape == (n_modes,)
111

112

113
def test_fv_wrong_descriptor_types():
    """
    Test that DescriptorException is raised when the incorrect type for the
    descriptors (a plain Python list rather than a NumPy array) is passed into
    the fisher_vector function.
    """
    # No ImportError handling is needed here: the module-level
    # ``pytest.importorskip('sklearn')`` skips this whole file when
    # scikit-learn is missing. A print-and-continue fallback would only turn
    # that situation into a confusing NameError on ``GaussianMixture``.
    from sklearn.mixture import GaussianMixture

    with pytest.raises(DescriptorException):
        fisher_vector([[1, 2, 3, 4]], GaussianMixture())
128

129

130
def test_fv_wrong_gmm_type():
    """
    fisher_vector must raise FisherVectorException when the model it receives
    is not a sklearn.mixture.GaussianMixture; an empty stand-in class plays
    the role of the wrong model here.
    """

    class NotAGaussianMixture:
        pass

    descriptors = np.zeros((10, 10))
    impostor = NotAGaussianMixture()
    with pytest.raises(FisherVectorException):
        fisher_vector(descriptors, impostor)
142

143

144
def test_fv_e2e():
    """
    Test the Fisher vector computation given a GMM returned from the learn_gmm
    function. We simply assert that the dimensionality of the resulting Fisher
    vector is correct.

    The dimensionality of a Fisher vector is given by 2KD + K, where K is the
    number of Gaussians specified in the associated GMM, and D is the
    dimensionality of the descriptors used to estimate the GMM.
    """
    dim = 128
    num_modes = 8
    expected_dim = 2 * num_modes * dim + num_modes

    # Ten descriptor sets of 5-29 rows each, drawn from a seeded generator so
    # the test is reproducible across runs.
    rng = np.random.default_rng(0)
    descriptors = [rng.random((rng.integers(5, 30), dim)) for _ in range(10)]

    gmm = learn_gmm(descriptors, n_modes=num_modes)
    fisher_vec = fisher_vector(descriptors[0], gmm)

    assert len(fisher_vec) == expected_dim
167

168

169
def test_fv_e2e_improved():
    """
    Test the improved Fisher vector computation given a GMM returned from the
    learn_gmm function. We simply assert that the dimensionality of the
    resulting Fisher vector is correct.

    The dimensionality of a Fisher vector is given by 2KD + K, where K is the
    number of Gaussians specified in the associated GMM, and D is the
    dimensionality of the descriptors used to estimate the GMM.
    """
    dim = 128
    num_modes = 8
    expected_dim = 2 * num_modes * dim + num_modes

    # Ten descriptor sets of 5-29 rows each, drawn from a seeded generator so
    # the test is reproducible across runs.
    rng = np.random.default_rng(0)
    descriptors = [rng.random((rng.integers(5, 30), dim)) for _ in range(10)]

    gmm = learn_gmm(descriptors, n_modes=num_modes)
    fisher_vec = fisher_vector(descriptors[0], gmm, improved=True)

    assert len(fisher_vec) == expected_dim
192

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.