pytorch

Форк
0
/
instance_norm_test.py 
272 строки · 9.7 Кб
1

2

3

4

5
import numpy as np
6
from hypothesis import given, assume, settings
7
import hypothesis.strategies as st
8

9
from caffe2.python import core, model_helper, brew, utils
10
import caffe2.python.hypothesis_test_util as hu
11
import caffe2.python.serialized_test.serialized_test_util as serial
12

13
import unittest
14

15

16
class TestInstanceNorm(serial.SerializedTestCase):
    """Hypothesis-driven tests for the Caffe2 ``InstanceNorm`` operator.

    The tests cover gradient checks, NCHW/NHWC layout agreement, comparison
    against a NumPy reference implementation, cross-device checks, and the
    ``brew.instance_norm`` model-helper wrapper.
    """

    def _get_inputs(self, N, C, H, W, order):
        """Build random float32 ``(input, scale, bias)`` arrays.

        The input is always drawn as ``(N, C, H, W)`` and only transposed
        afterwards for ``'NHWC'``, so that two calls made after re-seeding
        NumPy's global RNG with the same seed yield the same data in both
        layouts.

        Args:
            N, C, H, W: dimensions of the input tensor.
            order: ``'NCHW'`` or ``'NHWC'``.

        Returns:
            Tuple ``(input_data, scale_data, bias_data)``; ``scale_data`` and
            ``bias_data`` both have shape ``(C,)``.

        Raises:
            ValueError: if ``order`` is neither ``'NCHW'`` nor ``'NHWC'``.
        """
        # Fail fast, before consuming any random numbers. ValueError is a
        # subclass of Exception, so callers catching Exception still work.
        if order not in ('NCHW', 'NHWC'):
            raise ValueError('unknown order type ({})'.format(order))

        # Keep the draw order (input, scale, bias) stable: the generated
        # values depend on it for a given seed.
        input_data = np.random.rand(N, C, H, W).astype(np.float32)
        if order == 'NHWC':
            # Allocate in the same order as NCHW and transpose to make sure
            # the inputs are identical on freshly-seeded calls.
            input_data = utils.NCHW2NHWC(input_data)

        scale_data = np.random.rand(C).astype(np.float32)
        bias_data = np.random.rand(C).astype(np.float32)
        return input_data, scale_data, bias_data

    def _get_op(self, device_option, store_mean, store_inv_stdev, epsilon,
                order, inplace=False):
        """Create an ``InstanceNorm`` operator definition.

        Args:
            device_option: device option passed through to the operator.
            store_mean: request the ``'mean'`` output.
            store_inv_stdev: request the ``'inv_stdev'`` output; this also
                causes ``'mean'`` to be emitted, since it precedes
                ``'inv_stdev'`` in the output list.
            epsilon: value of the operator's ``epsilon`` argument.
            order: value of the operator's ``order`` argument.
            inplace: when true, the first output blob is named ``'input'``
                (the same name as the first input) instead of ``'output'``.

        Returns:
            The operator returned by ``core.CreateOperator``.
        """
        outputs = ['input' if inplace else 'output']
        if store_mean or store_inv_stdev:
            outputs.append('mean')
        if store_inv_stdev:
            outputs.append('inv_stdev')
        return core.CreateOperator(
            'InstanceNorm',
            ['input', 'scale', 'bias'],
            outputs,
            order=order,
            epsilon=epsilon,
            device_option=device_option)

    def _feed_inputs(self, input_blobs, device_option):
        """Feed ``input_blobs`` into ``self.ws`` as 'input', 'scale', 'bias'.

        Args:
            input_blobs: iterable of arrays, paired positionally with the
                blob names ``'input'``, ``'scale'`` and ``'bias'``.
            device_option: device option used when feeding each blob.
        """
        names = ['input', 'scale', 'bias']
        for name, blob in zip(names, input_blobs):
            self.ws.create_blob(name).feed(blob, device_option=device_option)

    @given(gc=hu.gcs['gc'],
           dc=hu.gcs['dc'],
           N=st.integers(1, 4),
           C=st.integers(1, 4),
           H=st.integers(2, 4),
           W=st.integers(2, 4),
           order=st.sampled_from(['NCHW', 'NHWC']),
           epsilon=st.floats(1e-6, 1e-4),
           store_mean=st.booleans(),
           seed=st.integers(0, 1000),
           store_inv_stdev=st.booleans())
    @settings(deadline=10000)
    def test_instance_norm_gradients(
            self, gc, dc, N, C, H, W, order, store_mean, store_inv_stdev,
            epsilon, seed):
        """Run device checks and gradient checks w.r.t. input, scale, bias."""
        np.random.seed(seed)

        # force store_inv_stdev if store_mean to match existing forward pass
        # implementation
        store_inv_stdev |= store_mean

        op = self._get_op(
            device_option=gc,
            store_mean=store_mean,
            store_inv_stdev=store_inv_stdev,
            epsilon=epsilon,
            order=order)

        # A shuffled arange gives distinct, well-separated input values.
        input_data = np.arange(N * C * H * W).astype(np.float32)
        np.random.shuffle(input_data)
        if order == "NCHW":
            input_data = input_data.reshape(N, C, H, W)
        else:
            input_data = input_data.reshape(N, H, W, C)
        scale_data = np.random.randn(C).astype(np.float32)
        bias_data = np.random.randn(C).astype(np.float32)
        input_blobs = (input_data, scale_data, bias_data)

        # Mirror the output list built by _get_op: index 1 ('mean') exists
        # whenever either statistic is requested, index 2 ('inv_stdev') only
        # when store_inv_stdev is set.
        output_indices = [0]
        if store_mean or store_inv_stdev:
            output_indices += [1]
        if store_inv_stdev:
            output_indices += [2]
        self.assertDeviceChecks(dc, op, input_blobs, output_indices)
        # The gradient only flows from output #0 since the other two only
        # store the temporary mean and inv_stdev buffers.
        # Check dl/dinput
        self.assertGradientChecks(gc, op, input_blobs, 0, [0])
        # Check dl/dscale
        self.assertGradientChecks(gc, op, input_blobs, 1, [0])
        # Check dl/dbias
        self.assertGradientChecks(gc, op, input_blobs, 2, [0])

    @given(gc=hu.gcs['gc'],
           dc=hu.gcs['dc'],
           N=st.integers(2, 10),
           C=st.integers(3, 10),
           H=st.integers(5, 10),
           W=st.integers(7, 10),
           seed=st.integers(0, 1000),
           epsilon=st.floats(1e-6, 1e-4),
           store_mean=st.booleans(),
           store_inv_stdev=st.booleans())
    def test_instance_norm_layout(self, gc, dc, N, C, H, W, store_mean,
                                  store_inv_stdev, epsilon, seed):
        """Check that NCHW and NHWC runs agree on identical input data."""
        # force store_inv_stdev if store_mean to match existing forward pass
        # implementation
        store_inv_stdev |= store_mean

        outputs = {}
        for order in ('NCHW', 'NHWC'):
            # Re-seed on every iteration so both layouts see the same data.
            np.random.seed(seed)
            input_blobs = self._get_inputs(N, C, H, W, order)
            self._feed_inputs(input_blobs, device_option=gc)
            op = self._get_op(
                device_option=gc,
                store_mean=store_mean,
                store_inv_stdev=store_inv_stdev,
                epsilon=epsilon,
                order=order)
            self.ws.run(op)
            outputs[order] = self.ws.blobs['output'].fetch()
        # Compare in a common (NCHW) layout.
        np.testing.assert_allclose(
            outputs['NCHW'],
            utils.NHWC2NCHW(outputs["NHWC"]),
            atol=1e-4,
            rtol=1e-4)

    @serial.given(gc=hu.gcs['gc'],
                  dc=hu.gcs['dc'],
                  N=st.integers(2, 10),
                  C=st.integers(3, 10),
                  H=st.integers(5, 10),
                  W=st.integers(7, 10),
                  order=st.sampled_from(['NCHW', 'NHWC']),
                  epsilon=st.floats(1e-6, 1e-4),
                  store_mean=st.booleans(),
                  seed=st.integers(0, 1000),
                  store_inv_stdev=st.booleans(),
                  inplace=st.booleans())
    def test_instance_norm_reference_check(
            self, gc, dc, N, C, H, W, order, store_mean, store_inv_stdev,
            epsilon, seed, inplace):
        """Compare the operator against a NumPy reference implementation."""
        np.random.seed(seed)

        # force store_inv_stdev if store_mean to match existing forward pass
        # implementation
        store_inv_stdev |= store_mean
        # The in-place variant is only exercised for the NCHW layout.
        if order != "NCHW":
            assume(not inplace)

        inputs = self._get_inputs(N, C, H, W, order)
        op = self._get_op(
            device_option=gc,
            store_mean=store_mean,
            store_inv_stdev=store_inv_stdev,
            epsilon=epsilon,
            order=order,
            inplace=inplace)

        def ref(input_blob, scale_blob, bias_blob):
            """NumPy instance normalization, computed in NCHW layout.

            Returns a tuple with the same arity as the operator's output
            list: ``(normalized,)`` or ``(normalized, mean, inv_stdev)``.
            """
            if order == 'NHWC':
                input_blob = utils.NHWC2NCHW(input_blob)

            # Per-(sample, channel) statistics over the flattened H*W axis.
            flat = input_blob.reshape((N, C, -1))
            mean_blob = flat.mean(axis=2)
            inv_stdev_blob = 1.0 / np.sqrt(flat.var(axis=2) + epsilon)
            # _bc indicates blobs that are reshaped for broadcast
            scale_bc = scale_blob[np.newaxis, :, np.newaxis, np.newaxis]
            mean_bc = mean_blob[:, :, np.newaxis, np.newaxis]
            inv_stdev_bc = inv_stdev_blob[:, :, np.newaxis, np.newaxis]
            bias_bc = bias_blob[np.newaxis, :, np.newaxis, np.newaxis]
            normalized_blob = scale_bc * (input_blob - mean_bc) * inv_stdev_bc \
                + bias_bc

            if order == 'NHWC':
                normalized_blob = utils.NCHW2NHWC(normalized_blob)

            # store_mean implies store_inv_stdev (forced above), so only two
            # output arities are reachable: the normalized blob alone, or
            # all three blobs.
            if not store_inv_stdev:
                return (normalized_blob,)
            return normalized_blob, mean_blob, inv_stdev_blob

        self.assertReferenceChecks(gc, op, inputs, ref)

    @given(gc=hu.gcs['gc'],
           dc=hu.gcs['dc'],
           N=st.integers(2, 10),
           C=st.integers(3, 10),
           H=st.integers(5, 10),
           W=st.integers(7, 10),
           order=st.sampled_from(['NCHW', 'NHWC']),
           epsilon=st.floats(1e-6, 1e-4),
           store_mean=st.booleans(),
           seed=st.integers(0, 1000),
           store_inv_stdev=st.booleans())
    def test_instance_norm_device_check(
            self, gc, dc, N, C, H, W, order, store_mean, store_inv_stdev,
            epsilon, seed):
        """Run the device check on output #0 for random inputs."""
        np.random.seed(seed)

        # force store_inv_stdev if store_mean to match existing forward pass
        # implementation
        store_inv_stdev |= store_mean

        inputs = self._get_inputs(N, C, H, W, order)
        op = self._get_op(
            device_option=gc,
            store_mean=store_mean,
            store_inv_stdev=store_inv_stdev,
            epsilon=epsilon,
            order=order)

        self.assertDeviceChecks(dc, op, inputs, [0])

    @given(is_test=st.booleans(),
           N=st.integers(2, 10),
           C=st.integers(3, 10),
           H=st.integers(5, 10),
           W=st.integers(7, 10),
           order=st.sampled_from(['NCHW', 'NHWC']),
           epsilon=st.floats(1e-6, 1e-4),
           seed=st.integers(0, 1000))
    def test_instance_norm_model_helper(
            self, N, C, H, W, order, epsilon, seed, is_test):
        """Build a net via ``brew.instance_norm`` and check blob shapes."""
        np.random.seed(seed)
        model = model_helper.ModelHelper(name="test_model")
        brew.instance_norm(
            model,
            'input',
            'output',
            C,
            epsilon=epsilon,
            order=order,
            is_test=is_test)

        input_blob = np.random.rand(N, C, H, W).astype(np.float32)
        if order == 'NHWC':
            input_blob = utils.NCHW2NHWC(input_blob)

        self.ws.create_blob('input').feed(input_blob)

        self.ws.create_net(model.param_init_net).run()
        self.ws.create_net(model.net).run()

        # unittest assertions are used instead of bare ``assert`` so the
        # checks survive ``python -O`` and report both values on failure.
        if is_test:
            scale = self.ws.blobs['output_s'].fetch()
            self.assertIsNotNone(scale)
            self.assertEqual(scale.shape, (C, ))
            bias = self.ws.blobs['output_b'].fetch()
            self.assertIsNotNone(bias)
            self.assertEqual(bias.shape, (C, ))

        output_blob = self.ws.blobs['output'].fetch()
        if order == 'NHWC':
            output_blob = utils.NHWC2NCHW(output_blob)

        self.assertEqual(output_blob.shape, (N, C, H, W))
273

274

275
# Script entry point: run every test in this module through unittest's
# command-line driver when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
277

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.