# Source: pytorch repository (173 lines, 5.5 KB)
import unittest

import numpy as np
from hypothesis import given, assume
import hypothesis.strategies as st

from caffe2.python import core, model_helper, utils
import caffe2.python.hypothesis_test_util as hu
12
class TestLeakyRelu(hu.HypothesisTestCase):
    """Hypothesis-driven tests for the Caffe2 ``LeakyRelu`` operator.

    Covers gradients, NCHW/NHWC layout agreement, a NumPy reference check,
    cross-device consistency, and construction through ``ModelHelper``.
    """

    def _get_inputs(self, N, C, H, W, order):
        """Return a 1-tuple holding a random float32 input of shape (N, C, H, W).

        Values start in [-0.5, 0.5). Entries in [0, 0.051] are replaced by
        0.051 and entries in [-0.051, 0) by -0.051, so no value sits closer
        than 0.051 to zero. When ``order == 'NHWC'`` the array is transposed
        with ``utils.NCHW2NHWC`` before being returned.
        """
        input_data = np.random.rand(N, C, H, W).astype(np.float32) - 0.5

        # default step size is 0.05
        # (presumably the gradient checker's finite-difference step; keeping
        # |x| > 0.05 stops a probe from crossing LeakyRelu's kink at 0)
        input_data[np.logical_and(
            input_data >= 0, input_data <= 0.051)] = 0.051
        input_data[np.logical_and(
            input_data <= 0, input_data >= -0.051)] = -0.051

        if order == 'NHWC':
            input_data = utils.NCHW2NHWC(input_data)

        # Trailing comma: callers expect a tuple of input blobs.
        return input_data,

    def _get_op(self, device_option, alpha, order, inplace=False):
        """Build a ``LeakyRelu`` operator that reads blob ``'input'``.

        Args:
            device_option: device option attached to the operator.
            alpha: slope applied to negative inputs.
            order: accepted but NOT forwarded to the operator; layout only
                affects how the test inputs are generated.
            inplace: when True the result overwrites ``'input'``; otherwise
                it is written to ``'output'``.

        Returns:
            The operator definition created by ``core.CreateOperator``.
        """
        output_name = 'input' if inplace else 'output'
        return core.CreateOperator(
            'LeakyRelu',
            ['input'],
            [output_name],
            alpha=alpha,
            device_option=device_option)

    def _feed_inputs(self, input_blobs, device_option):
        """Feed ``input_blobs`` into ``self.ws`` under positional names.

        ``zip`` stops at the shorter sequence. With the 1-tuple produced by
        ``_get_inputs``, only ``'input'`` is fed.
        """
        names = ['input', 'scale', 'bias']
        for name, blob in zip(names, input_blobs):
            self.ws.create_blob(name).feed(blob, device_option=device_option)

    @staticmethod
    def _leaky_relu_ref(x, alpha):
        """NumPy reference: ``x`` where ``x >= 0``, ``alpha * x`` elsewhere.

        Returns a new array; ``x`` is not modified.
        """
        result = x.copy()
        result[result < 0] *= alpha
        return result

    @given(gc=hu.gcs['gc'],
           dc=hu.gcs['dc'],
           N=st.integers(2, 3),
           C=st.integers(2, 3),
           H=st.integers(2, 3),
           W=st.integers(2, 3),
           alpha=st.floats(0, 1),
           order=st.sampled_from(['NCHW', 'NHWC']),
           seed=st.integers(0, 1000))
    def test_leaky_relu_gradients(self, gc, dc, N, C, H, W, order, alpha, seed):
        """Check device consistency and numeric gradients on small inputs."""
        np.random.seed(seed)

        op = self._get_op(
            device_option=gc,
            alpha=alpha,
            order=order)
        input_blobs = self._get_inputs(N, C, H, W, order)

        self.assertDeviceChecks(dc, op, input_blobs, [0])
        self.assertGradientChecks(gc, op, input_blobs, 0, [0])

    @given(gc=hu.gcs['gc'],
           dc=hu.gcs['dc'],
           N=st.integers(2, 10),
           C=st.integers(3, 10),
           H=st.integers(5, 10),
           W=st.integers(7, 10),
           alpha=st.floats(0, 1),
           seed=st.integers(0, 1000))
    def test_leaky_relu_layout(self, gc, dc, N, C, H, W, alpha, seed):
        """Run the op on NCHW and NHWC copies of the same data; results must agree."""
        outputs = {}
        for order in ('NCHW', 'NHWC'):
            # Re-seed so both layouts see identical (pre-transpose) data.
            np.random.seed(seed)
            input_blobs = self._get_inputs(N, C, H, W, order)
            self._feed_inputs(input_blobs, device_option=gc)
            op = self._get_op(
                device_option=gc,
                alpha=alpha,
                order=order)
            self.ws.run(op)
            outputs[order] = self.ws.blobs['output'].fetch()
        np.testing.assert_allclose(
            outputs['NCHW'],
            utils.NHWC2NCHW(outputs["NHWC"]),
            atol=1e-4,
            rtol=1e-4)

    @given(gc=hu.gcs['gc'],
           dc=hu.gcs['dc'],
           N=st.integers(2, 10),
           C=st.integers(3, 10),
           H=st.integers(5, 10),
           W=st.integers(7, 10),
           order=st.sampled_from(['NCHW', 'NHWC']),
           alpha=st.floats(0, 1),
           seed=st.integers(0, 1000),
           inplace=st.booleans())
    def test_leaky_relu_reference_check(self, gc, dc, N, C, H, W, order, alpha,
                                        seed, inplace):
        """Compare the op against the NumPy reference, optionally in place."""
        np.random.seed(seed)

        # In-place execution is only exercised for the NCHW layout.
        if order != "NCHW":
            assume(not inplace)

        inputs = self._get_inputs(N, C, H, W, order)
        op = self._get_op(
            device_option=gc,
            alpha=alpha,
            order=order,
            inplace=inplace)

        def ref(input_blob):
            return self._leaky_relu_ref(input_blob, alpha),

        self.assertReferenceChecks(gc, op, inputs, ref)

    @given(gc=hu.gcs['gc'],
           dc=hu.gcs['dc'],
           N=st.integers(2, 10),
           C=st.integers(3, 10),
           H=st.integers(5, 10),
           W=st.integers(7, 10),
           order=st.sampled_from(['NCHW', 'NHWC']),
           alpha=st.floats(0, 1),
           seed=st.integers(0, 1000))
    def test_leaky_relu_device_check(self, gc, dc, N, C, H, W, order, alpha,
                                     seed):
        """Check that all available devices produce matching outputs."""
        np.random.seed(seed)

        inputs = self._get_inputs(N, C, H, W, order)
        op = self._get_op(
            device_option=gc,
            alpha=alpha,
            order=order)

        self.assertDeviceChecks(dc, op, inputs, [0])

    @given(N=st.integers(2, 10),
           C=st.integers(3, 10),
           H=st.integers(5, 10),
           W=st.integers(7, 10),
           order=st.sampled_from(['NCHW', 'NHWC']),
           alpha=st.floats(0, 1),
           seed=st.integers(0, 1000))
    def test_leaky_relu_model_helper_helper(self, N, C, H, W, order, alpha, seed):
        """Build LeakyRelu via ``ModelHelper``; check output shape and values."""
        np.random.seed(seed)
        arg_scope = {'order': order}
        model = model_helper.ModelHelper(name="test_model", arg_scope=arg_scope)
        model.LeakyRelu(
            'input',
            'output',
            alpha=alpha)

        # Center the data on zero. np.random.rand alone yields only
        # non-negative values, which would never exercise the alpha branch.
        input_nchw = np.random.rand(N, C, H, W).astype(np.float32) - 0.5
        input_blob = input_nchw
        if order == 'NHWC':
            input_blob = utils.NCHW2NHWC(input_nchw)

        self.ws.create_blob('input').feed(input_blob)

        self.ws.create_net(model.param_init_net).run()
        self.ws.create_net(model.net).run()

        output_blob = self.ws.blobs['output'].fetch()
        if order == 'NHWC':
            output_blob = utils.NHWC2NCHW(output_blob)

        assert output_blob.shape == (N, C, H, W)
        # LeakyRelu is elementwise, so comparing in NCHW against the
        # reference applied to the NCHW input is layout-independent.
        np.testing.assert_allclose(
            output_blob,
            self._leaky_relu_ref(input_nchw, alpha),
            atol=1e-6,
            rtol=1e-5)
174
if __name__ == '__main__':
    # Run this module's tests with the standard-library runner
    # (unittest is imported at the top of the file).
    unittest.main()