scikit-image
191 строка · 5.6 Кб
1import pytest2
3import numpy as np4
5pytest.importorskip('sklearn')6
7from skimage.feature._fisher_vector import ( # noqa: E4028learn_gmm,9fisher_vector,10FisherVectorException,11DescriptorException,12)
13
14
def test_gmm_wrong_descriptor_format_1():
    """A plain string is not a valid descriptor collection, so learn_gmm
    must reject it with a DescriptorException.
    """
    bogus_descriptors = 'completely wrong test'

    with pytest.raises(DescriptorException):
        learn_gmm(bogus_descriptors, n_modes=1)
23
def test_gmm_wrong_descriptor_format_2():
    """learn_gmm must raise DescriptorException when the descriptor arrays
    in the list disagree on their second-axis size (11 vs. 10 here).
    """
    mismatched = [np.zeros((5, 11)), np.zeros((4, 10))]

    with pytest.raises(DescriptorException):
        learn_gmm(mismatched, n_modes=1)
32
def test_gmm_wrong_descriptor_format_3():
    """learn_gmm must raise DescriptorException when one entry of the
    descriptor list is a rank-3 array instead of a rank-2 one.
    """
    rank_two = np.zeros((5, 10))
    rank_three = np.zeros((4, 10, 1))

    with pytest.raises(DescriptorException):
        learn_gmm([rank_two, rank_three], n_modes=1)
41
def test_gmm_wrong_descriptor_format_4():
    """learn_gmm must raise DescriptorException when the descriptor list
    holds plain Python lists rather than NumPy ndarrays.
    """
    plain_lists = [[1, 2, 3], [1, 2, 3]]

    with pytest.raises(DescriptorException):
        learn_gmm(plain_lists, n_modes=1)
50
def test_gmm_wrong_num_modes_format_1():
    """learn_gmm must raise FisherVectorException when n_modes is given as
    a string instead of an integer.
    """
    valid_descriptors = [np.zeros((5, 10)), np.zeros((4, 10))]
    bad_n_modes = 'not_valid'

    with pytest.raises(FisherVectorException):
        learn_gmm(valid_descriptors, n_modes=bad_n_modes)
59
def test_gmm_wrong_num_modes_format_2():
    """learn_gmm must raise FisherVectorException when n_modes is an integer
    that is not positive (here, -1).
    """
    valid_descriptors = [np.zeros((5, 10)), np.zeros((4, 10))]
    negative_modes = -1

    with pytest.raises(FisherVectorException):
        learn_gmm(valid_descriptors, n_modes=negative_modes)
68
def test_gmm_wrong_covariance_type():
    """learn_gmm must raise FisherVectorException when gm_args requests a
    covariance type other than the supported one ('full' here).
    """
    data = np.random.random((10, 10))
    unsupported_args = {'covariance_type': 'full'}

    with pytest.raises(FisherVectorException):
        learn_gmm(data, n_modes=2, gm_args=unsupported_args)
79
def test_gmm_correct_covariance_type():
    """Test that GMM estimation is successful when the supported ``'diag'``
    covariance type is passed in via ``gm_args``.

    Success is checked through the shapes of the fitted parameters: with
    ``n_modes`` components over ``n_features``-dimensional data, a
    diagonal-covariance GaussianMixture holds one mean vector and one
    variance vector per component, plus one mixture weight per component.
    """

    n_samples, n_features, n_modes = 10, 10, 2
    # Seeded generator so the fitted model (and any failure) is reproducible.
    rng = np.random.default_rng(0)
    descriptors = rng.random((n_samples, n_features))

    gmm = learn_gmm(
        descriptors, n_modes=n_modes, gm_args={'covariance_type': 'diag'}
    )

    # `is not None` checks would pass for any fitted model; the shapes pin
    # down that n_modes components were fitted with diagonal covariances.
    assert gmm.means_.shape == (n_modes, n_features)
    assert gmm.covariances_.shape == (n_modes, n_features)
    assert gmm.weights_.shape == (n_modes,)
93
def test_gmm_e2e():
    """
    Test the GMM estimation. Since this is essentially a wrapper for the
    scikit-learn GaussianMixture class, the testing of the actual inner
    workings of the GMM estimation is left to scikit-learn and its
    dependencies.

    We instead assert that the estimation was successful by checking that
    the fitted model has mixture weights, means, and variances whose shapes
    match ``n_modes`` and the descriptor dimensionality, and that the
    mixture weights sum to one.
    """

    n_samples, n_features, n_modes = 100, 64, 5
    # Seeded generator so the fitted model (and any failure) is reproducible.
    rng = np.random.default_rng(0)

    gmm = learn_gmm(rng.random((n_samples, n_features)), n_modes=n_modes)

    assert gmm.means_.shape == (n_modes, n_features)
    # Only the leading axis is checked here: the trailing shape depends on
    # the covariance type learn_gmm configures by default.
    assert gmm.covariances_.shape[0] == n_modes
    assert gmm.weights_.shape == (n_modes,)
    assert np.isclose(gmm.weights_.sum(), 1.0)
112
def test_fv_wrong_descriptor_types():
    """
    Test that DescriptorException is raised when the incorrect type for the
    descriptors is passed into the fisher_vector function.
    """
    # Skip cleanly when scikit-learn is unavailable. Merely printing a
    # message and carrying on would leave GaussianMixture unbound and make
    # the test error out with a NameError instead.
    sklearn_mixture = pytest.importorskip('sklearn.mixture')

    with pytest.raises(DescriptorException):
        fisher_vector([[1, 2, 3, 4]], sklearn_mixture.GaussianMixture())
129
def test_fv_wrong_gmm_type():
    """
    fisher_vector must raise FisherVectorException when the supplied GMM is
    an object of an unrelated class rather than a
    sklearn.mixture.GaussianMixture.
    """

    class NotAGaussianMixture:
        pass

    impostor = NotAGaussianMixture()
    descriptors = np.zeros((10, 10))

    with pytest.raises(FisherVectorException):
        fisher_vector(descriptors, impostor)
143
def test_fv_e2e():
    """
    Test the Fisher vector computation given a GMM returned from the learn_gmm
    function. We assert that the resulting Fisher vector is one-dimensional
    and has the expected length.

    The dimensionality of a Fisher vector is given by 2KD + K, where K is the
    number of Gaussians specified in the associated GMM, and D is the
    dimensionality of the descriptors used to estimate the GMM.
    """

    dim = 128
    num_modes = 8

    expected_dim = 2 * num_modes * dim + num_modes

    # Seeded generator: ten descriptor sets of 5-29 rows each, reproducible
    # across runs so a failure can be replayed.
    rng = np.random.default_rng(0)
    descriptors = [rng.random((rng.integers(5, 30), dim)) for _ in range(10)]

    gmm = learn_gmm(descriptors, n_modes=num_modes)

    fisher_vec = fisher_vector(descriptors[0], gmm)

    assert np.shape(fisher_vec) == (expected_dim,)
168
def test_fv_e2e_improved():
    """
    Test the improved Fisher vector computation given a GMM returned from the
    learn_gmm function. We assert that the resulting Fisher vector is
    one-dimensional, has the expected length, and is L2-normalized (improved
    Fisher vectors are power- and L2-normalized).

    The dimensionality of a Fisher vector is given by 2KD + K, where K is the
    number of Gaussians specified in the associated GMM, and D is the
    dimensionality of the descriptors used to estimate the GMM.
    """

    dim = 128
    num_modes = 8

    expected_dim = 2 * num_modes * dim + num_modes

    # Seeded generator: ten descriptor sets of 5-29 rows each, reproducible
    # across runs so a failure can be replayed.
    rng = np.random.default_rng(0)
    descriptors = [rng.random((rng.integers(5, 30), dim)) for _ in range(10)]

    gmm = learn_gmm(descriptors, n_modes=num_modes)

    fisher_vec = fisher_vector(descriptors[0], gmm, improved=True)

    assert np.shape(fisher_vec) == (expected_dim,)
    # Unit L2 norm is what distinguishes the improved variant from the
    # plain Fisher vector; a length check alone would pass for both.
    assert np.isclose(np.linalg.norm(fisher_vec), 1.0)