# pytorch — 272 lines · 9.7 KB
5import numpy as np6from hypothesis import given, assume, settings7import hypothesis.strategies as st8
9from caffe2.python import core, model_helper, brew, utils10import caffe2.python.hypothesis_test_util as hu11import caffe2.python.serialized_test.serialized_test_util as serial12
13import unittest14
15
class TestInstanceNorm(serial.SerializedTestCase):
    """Tests for the Caffe2 ``InstanceNorm`` operator and its brew helper.

    Every operator-level test applies ``store_inv_stdev |= store_mean``
    before building the op, so the op is only ever asked for either no
    optional outputs or both ``mean`` and ``inv_stdev``.
    """

    def _get_inputs(self, N, C, H, W, order):
        """Return random float32 ``(input, scale, bias)`` arrays.

        Args:
            N, C, H, W: batch, channel, height and width sizes.
            order: 'NCHW' or 'NHWC'; the layout of the returned ``input``.

        Returns:
            ``input`` of shape (N, C, H, W) or (N, H, W, C) depending on
            ``order``; ``scale`` and ``bias`` of shape (C,).

        Raises:
            ValueError: if ``order`` is neither 'NCHW' nor 'NHWC'.
        """
        if order not in ('NCHW', 'NHWC'):
            raise ValueError('unknown order type ({})'.format(order))

        input_data = np.random.rand(N, C, H, W).astype(np.float32)
        if order == 'NHWC':
            # Allocate in the same order as NCHW and transpose to make sure
            # the inputs are identical on freshly-seeded calls.
            input_data = utils.NCHW2NHWC(input_data)

        scale_data = np.random.rand(C).astype(np.float32)
        bias_data = np.random.rand(C).astype(np.float32)
        return input_data, scale_data, bias_data

    @staticmethod
    def _get_output_indices(store_mean, store_inv_stdev):
        """Return the indices of the outputs ``_get_op`` declares.

        Output 0 (the normalized tensor) is always present. Index 1
        ('mean') is added when either flag is set, and index 2
        ('inv_stdev') when ``store_inv_stdev`` is set, mirroring
        ``_get_op``.
        """
        indices = [0]
        if store_mean or store_inv_stdev:
            indices.append(1)
        if store_inv_stdev:
            indices.append(2)
        return indices

    def _get_op(self, device_option, store_mean, store_inv_stdev, epsilon,
                order, inplace=False):
        """Build an ``InstanceNorm`` operator over 'input', 'scale', 'bias'.

        The normalized result is written to 'input' when ``inplace`` is
        set, otherwise to 'output'. 'mean' is appended as an output when
        either ``store_mean`` or ``store_inv_stdev`` is set, and 'inv_stdev'
        when ``store_inv_stdev`` is set.
        """
        outputs = ['input' if inplace else 'output']
        if store_mean or store_inv_stdev:
            outputs.append('mean')
        if store_inv_stdev:
            outputs.append('inv_stdev')
        return core.CreateOperator(
            'InstanceNorm',
            ['input', 'scale', 'bias'],
            outputs,
            order=order,
            epsilon=epsilon,
            device_option=device_option)

    def _feed_inputs(self, input_blobs, device_option):
        """Feed ``(input, scale, bias)`` into the workspace under those names."""
        names = ['input', 'scale', 'bias']
        for name, blob in zip(names, input_blobs):
            self.ws.create_blob(name).feed(blob, device_option=device_option)

    @given(gc=hu.gcs['gc'],
           dc=hu.gcs['dc'],
           N=st.integers(1, 4),
           C=st.integers(1, 4),
           H=st.integers(2, 4),
           W=st.integers(2, 4),
           order=st.sampled_from(['NCHW', 'NHWC']),
           epsilon=st.floats(1e-6, 1e-4),
           store_mean=st.booleans(),
           seed=st.integers(0, 1000),
           store_inv_stdev=st.booleans())
    @settings(deadline=10000)
    def test_instance_norm_gradients(
            self, gc, dc, N, C, H, W, order, store_mean, store_inv_stdev,
            epsilon, seed):
        """Device-check all emitted outputs and gradient-check all inputs."""
        np.random.seed(seed)

        # force store_inv_stdev if store_mean to match existing forward pass
        # implementation
        store_inv_stdev |= store_mean

        op = self._get_op(
            device_option=gc,
            store_mean=store_mean,
            store_inv_stdev=store_inv_stdev,
            epsilon=epsilon,
            order=order)

        # Shuffled distinct integer values laid out according to `order`.
        input_data = np.arange(N * C * H * W).astype(np.float32)
        np.random.shuffle(input_data)
        if order == "NCHW":
            input_data = input_data.reshape(N, C, H, W)
        else:
            input_data = input_data.reshape(N, H, W, C)
        scale_data = np.random.randn(C).astype(np.float32)
        bias_data = np.random.randn(C).astype(np.float32)
        input_blobs = (input_data, scale_data, bias_data)

        output_indices = self._get_output_indices(store_mean, store_inv_stdev)
        self.assertDeviceChecks(dc, op, input_blobs, output_indices)
        # The gradient only flows from output #0 since the other two only
        # store the temporary mean and inv_stdev buffers.
        # Check dl/dinput (0), dl/dscale (1) and dl/dbias (2).
        for input_index in range(len(input_blobs)):
            self.assertGradientChecks(gc, op, input_blobs, input_index, [0])

    @given(gc=hu.gcs['gc'],
           dc=hu.gcs['dc'],
           N=st.integers(2, 10),
           C=st.integers(3, 10),
           H=st.integers(5, 10),
           W=st.integers(7, 10),
           seed=st.integers(0, 1000),
           epsilon=st.floats(1e-6, 1e-4),
           store_mean=st.booleans(),
           store_inv_stdev=st.booleans())
    def test_instance_norm_layout(self, gc, dc, N, C, H, W, store_mean,
                                  store_inv_stdev, epsilon, seed):
        """NCHW and NHWC runs on identical data must agree after transpose."""
        # force store_inv_stdev if store_mean to match existing forward pass
        # implementation
        store_inv_stdev |= store_mean

        outputs = {}
        for order in ('NCHW', 'NHWC'):
            # Re-seed so both layouts see the same underlying values.
            np.random.seed(seed)
            input_blobs = self._get_inputs(N, C, H, W, order)
            self._feed_inputs(input_blobs, device_option=gc)
            op = self._get_op(
                device_option=gc,
                store_mean=store_mean,
                store_inv_stdev=store_inv_stdev,
                epsilon=epsilon,
                order=order)
            self.ws.run(op)
            outputs[order] = self.ws.blobs['output'].fetch()
        np.testing.assert_allclose(
            outputs['NCHW'],
            utils.NHWC2NCHW(outputs["NHWC"]),
            atol=1e-4,
            rtol=1e-4)

    @serial.given(gc=hu.gcs['gc'],
                  dc=hu.gcs['dc'],
                  N=st.integers(2, 10),
                  C=st.integers(3, 10),
                  H=st.integers(5, 10),
                  W=st.integers(7, 10),
                  order=st.sampled_from(['NCHW', 'NHWC']),
                  epsilon=st.floats(1e-6, 1e-4),
                  store_mean=st.booleans(),
                  seed=st.integers(0, 1000),
                  store_inv_stdev=st.booleans(),
                  inplace=st.booleans())
    def test_instance_norm_reference_check(
            self, gc, dc, N, C, H, W, order, store_mean, store_inv_stdev,
            epsilon, seed, inplace):
        """Compare the operator against a NumPy reference implementation."""
        np.random.seed(seed)

        # force store_inv_stdev if store_mean to match existing forward pass
        # implementation
        store_inv_stdev |= store_mean
        # In-place execution is only exercised for the NCHW layout.
        if order != "NCHW":
            assume(not inplace)

        inputs = self._get_inputs(N, C, H, W, order)
        op = self._get_op(
            device_option=gc,
            store_mean=store_mean,
            store_inv_stdev=store_inv_stdev,
            epsilon=epsilon,
            order=order,
            inplace=inplace)

        def ref(input_blob, scale_blob, bias_blob):
            """NumPy instance norm returning the same outputs as ``op``."""
            if order == 'NHWC':
                input_blob = utils.NHWC2NCHW(input_blob)

            # Statistics per (sample, channel) over the flattened spatial
            # dims; np.var uses the population variance (ddof=0).
            flat = input_blob.reshape((N, C, -1))
            mean_blob = flat.mean(axis=2)
            inv_stdev_blob = 1.0 / np.sqrt(flat.var(axis=2) + epsilon)
            # _bc indicates blobs that are reshaped for broadcast
            scale_bc = scale_blob[np.newaxis, :, np.newaxis, np.newaxis]
            mean_bc = mean_blob[:, :, np.newaxis, np.newaxis]
            inv_stdev_bc = inv_stdev_blob[:, :, np.newaxis, np.newaxis]
            bias_bc = bias_blob[np.newaxis, :, np.newaxis, np.newaxis]
            normalized_blob = (
                scale_bc * (input_blob - mean_bc) * inv_stdev_bc + bias_bc)

            if order == 'NHWC':
                normalized_blob = utils.NCHW2NHWC(normalized_blob)

            # Mirror the output list built by _get_op.
            results = [normalized_blob]
            if store_mean or store_inv_stdev:
                results.append(mean_blob)
            if store_inv_stdev:
                results.append(inv_stdev_blob)
            return tuple(results)

        self.assertReferenceChecks(gc, op, inputs, ref)

    @given(gc=hu.gcs['gc'],
           dc=hu.gcs['dc'],
           N=st.integers(2, 10),
           C=st.integers(3, 10),
           H=st.integers(5, 10),
           W=st.integers(7, 10),
           order=st.sampled_from(['NCHW', 'NHWC']),
           epsilon=st.floats(1e-6, 1e-4),
           store_mean=st.booleans(),
           seed=st.integers(0, 1000),
           store_inv_stdev=st.booleans())
    def test_instance_norm_device_check(
            self, gc, dc, N, C, H, W, order, store_mean, store_inv_stdev,
            epsilon, seed):
        """Check every output the op emits is consistent across devices."""
        np.random.seed(seed)

        # force store_inv_stdev if store_mean to match existing forward pass
        # implementation
        store_inv_stdev |= store_mean

        inputs = self._get_inputs(N, C, H, W, order)
        op = self._get_op(
            device_option=gc,
            store_mean=store_mean,
            store_inv_stdev=store_inv_stdev,
            epsilon=epsilon,
            order=order)

        # Previously only output 0 was compared; include the stored
        # mean/inv_stdev outputs as test_instance_norm_gradients does.
        output_indices = self._get_output_indices(store_mean, store_inv_stdev)
        self.assertDeviceChecks(dc, op, inputs, output_indices)

    @given(is_test=st.booleans(),
           N=st.integers(2, 10),
           C=st.integers(3, 10),
           H=st.integers(5, 10),
           W=st.integers(7, 10),
           order=st.sampled_from(['NCHW', 'NHWC']),
           epsilon=st.floats(1e-6, 1e-4),
           seed=st.integers(0, 1000))
    def test_instance_norm_model_helper(
            self, N, C, H, W, order, epsilon, seed, is_test):
        """``brew.instance_norm`` builds a net producing a correctly shaped output."""
        np.random.seed(seed)
        model = model_helper.ModelHelper(name="test_model")
        brew.instance_norm(
            model,
            'input',
            'output',
            C,
            epsilon=epsilon,
            order=order,
            is_test=is_test)

        input_blob = np.random.rand(N, C, H, W).astype(np.float32)
        if order == 'NHWC':
            input_blob = utils.NCHW2NHWC(input_blob)

        self.ws.create_blob('input').feed(input_blob)

        self.ws.create_net(model.param_init_net).run()
        self.ws.create_net(model.net).run()

        # Use unittest assertions rather than bare `assert`, which is
        # stripped when Python runs with -O.
        if is_test:
            scale = self.ws.blobs['output_s'].fetch()
            self.assertIsNotNone(scale)
            self.assertEqual(scale.shape, (C, ))
            bias = self.ws.blobs['output_b'].fetch()
            self.assertIsNotNone(bias)
            self.assertEqual(bias.shape, (C, ))

        output_blob = self.ws.blobs['output'].fetch()
        if order == 'NHWC':
            output_blob = utils.NHWC2NCHW(output_blob)

        self.assertEqual(output_blob.shape, (N, C, H, W))
274
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()