pytorch
266 строк · 8.3 Кб
1
2
3
4
5
6from caffe2.proto import caffe2_pb27from caffe2.python import core, workspace8import caffe2.python.hypothesis_test_util as hu9import caffe2.python.serialized_test.serialized_test_util as serial10
11from hypothesis import given, settings12import hypothesis.strategies as st13import numpy as np14
15
def _fill_diagonal(shape, value):
    """NumPy reference implementation for the ``DiagonalFill`` operator.

    Builds a float64 array of the requested ``shape`` that is zero
    everywhere except the main diagonal, which is set to ``value``.
    For arrays with more than two dimensions ``np.fill_diagonal`` requires
    all dimensions to be equal.

    Returns:
        A 1-tuple holding the array. Reference functions in this file
        return their outputs as tuples.
    """
    diag = np.zeros(shape)
    np.fill_diagonal(diag, value)
    outputs = (diag,)
    return outputs
21
class TestFillerOperator(serial.SerializedTestCase):
    """Tests for the Caffe2 filler operators.

    Covers GaussianFill, ConstantFill, UniformFill, UniformIntFill,
    DiagonalFill, LengthsRangeFill, MSRAFill and Float16UniformFill.
    """

    @given(**hu.gcs)
    @settings(deadline=10000)
    def test_shape_error(self, gc, dc):
        """An integer ``shape`` must raise; ``shape=[]`` must fill a scalar."""
        op = core.CreateOperator(
            'GaussianFill',
            [],
            'out',
            shape=32,  # illegal parameter: an int instead of a list of dims
            mean=0.0,
            std=1.0,
        )
        with self.assertRaises(
            Exception, msg="Did not throw exception on illegal shape"
        ):
            workspace.RunOperatorOnce(op)

        op = core.CreateOperator(
            'ConstantFill',
            [],
            'out',
            shape=[],  # scalar
            value=2.0,
        )
        self.assertTrue(workspace.RunOperatorOnce(op))
        self.assertEqual(workspace.FetchBlob('out'), [2.0])

    @given(**hu.gcs)
    @settings(deadline=10000)
    def test_int64_shape(self, gc, dc):
        """A dimension larger than int32 can hold must be accepted.

        The leading dimension is 0, so the output has no elements.
        """
        large_dim = 2 ** 31 + 1  # does not fit in a signed 32-bit int
        net = core.Net("test_shape_net")
        net.UniformFill(
            [],
            'out',
            shape=[0, large_dim],
            min=0.0,
            max=1.0,
        )
        self.assertTrue(workspace.CreateNet(net))
        self.assertTrue(workspace.RunNet(net.Name()))
        self.assertEqual(workspace.blobs['out'].shape, (0, large_dim))

    @given(
        shape=hu.dims().flatmap(
            lambda dims: hu.arrays(
                [dims], dtype=np.int64,
                elements=st.integers(min_value=0, max_value=20)
            )
        ),
        a=st.integers(min_value=0, max_value=100),
        b=st.integers(min_value=0, max_value=100),
        **hu.gcs
    )
    @settings(deadline=10000)
    def test_uniform_int_fill_op_blob_input(self, shape, a, b, gc, dc):
        """UniformIntFill that reads shape, min and max from input blobs.

        When ``b < a`` the test expects the leading output dimension to be
        0; otherwise every value must lie in ``[a, b]``.
        """
        net = core.Net('test_net')

        # The shape blob is created under an explicit CPU device scope.
        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
            shape_blob = net.Const(shape, dtype=np.int64)
        a_blob = net.Const(a, dtype=np.int32)
        b_blob = net.Const(b, dtype=np.int32)
        uniform_fill = net.UniformIntFill([shape_blob, a_blob, b_blob],
                                          1, input_as_shape=1)

        workspace.RunNetOnce(net)

        blob_out = workspace.FetchBlob(uniform_fill)
        if b < a:
            # Take a real copy: slicing a NumPy array (``shape[:]``) returns
            # a view, and the assignment below would mutate the drawn input.
            new_shape = np.array(shape)
            new_shape[0] = 0
            np.testing.assert_array_equal(new_shape, blob_out.shape)
        else:
            np.testing.assert_array_equal(shape, blob_out.shape)
            self.assertTrue((blob_out >= a).all())
            self.assertTrue((blob_out <= b).all())

    @given(**hu.gcs)
    def test_uniform_fill_using_arg(self, gc, dc):
        """UniformIntFill driven purely by ``shape``/``min``/``max`` arguments."""
        net = core.Net('test_net')
        shape = [2**3, 5]
        # uncomment this to test filling large blob
        # shape = [2**30, 5]
        min_v = -100
        max_v = 100
        output_blob = net.UniformIntFill(
            [],
            ['output_blob'],
            shape=shape,
            min=min_v,
            max=max_v,
        )

        workspace.RunNetOnce(net)
        output_data = workspace.FetchBlob(output_blob)

        np.testing.assert_array_equal(shape, output_data.shape)
        min_data = np.min(output_data)
        max_data = np.max(output_data)

        self.assertGreaterEqual(min_data, min_v)
        self.assertLessEqual(max_data, max_v)

        # 40 values drawn between -100 and 100 should not all be identical.
        self.assertNotEqual(min_data, max_data)

    @serial.given(
        shape=st.sampled_from(
            [
                [3, 3],
                [5, 5, 5],
                [7, 7, 7, 7],
            ]
        ),
        **hu.gcs
    )
    def test_diagonal_fill_op_float(self, shape, gc, dc):
        """DiagonalFill (float) on hypercube shapes vs. ``_fill_diagonal``."""
        value = 2.5
        op = core.CreateOperator(
            'DiagonalFill',
            [],
            'out',
            shape=shape,  # N-d shape with all dimensions equal
            value=value,
        )

        for device_option in dc:
            op.device_option.CopyFrom(device_option)
            # Check against numpy reference. The op has no inputs, so
            # [shape, value] are only forwarded to the reference function.
            self.assertReferenceChecks(gc, op, [shape, value], _fill_diagonal)

    @given(**hu.gcs)
    def test_diagonal_fill_op_int(self, gc, dc):
        """DiagonalFill with ``dtype=INT32`` vs. ``_fill_diagonal``."""
        value = 2
        shape = [3, 3]
        op = core.CreateOperator(
            'DiagonalFill',
            [],
            'out',
            shape=shape,
            dtype=core.DataType.INT32,
            value=value,
        )

        # Check against numpy reference
        self.assertReferenceChecks(gc, op, [shape, value], _fill_diagonal)

    @serial.given(lengths=st.lists(st.integers(min_value=0, max_value=10),
                                   min_size=0,
                                   max_size=10),
                  **hu.gcs)
    def test_lengths_range_fill(self, lengths, gc, dc):
        """LengthsRangeFill emits ``0..n-1`` for each length ``n``, concatenated."""
        op = core.CreateOperator(
            "LengthsRangeFill",
            ["lengths"],
            ["increasing_seq"])

        def _len_range_fill(lens):
            # Concatenation of range(n) for every n in lens.
            seq = [i for n in lens for i in range(n)]
            return (np.array(seq, dtype=np.int32),)

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[np.array(lengths, dtype=np.int32)],
            reference=_len_range_fill)

    @given(**hu.gcs)
    def test_gaussian_fill_op(self, gc, dc):
        """GaussianFill runs on every device and produces non-zero samples."""
        op = core.CreateOperator(
            'GaussianFill',
            [],
            'out',
            shape=[17, 3, 3],  # sample odd dimensions
            mean=0.0,
            std=1.0,
        )

        for device_option in dc:
            op.device_option.CopyFrom(device_option)
            self.assertTrue(
                workspace.RunOperatorOnce(op),
                "GaussianFill op did not run successfully",
            )

            blob_out = workspace.FetchBlob('out')
            self.assertGreater(
                np.count_nonzero(blob_out), 0,
                "All generated elements are zeros. Is the random generator "
                "functioning correctly?",
            )

    @given(**hu.gcs)
    def test_msra_fill_op(self, gc, dc):
        """MSRAFill runs on every device and produces non-zero samples."""
        op = core.CreateOperator(
            'MSRAFill',
            [],
            'out',
            shape=[15, 5, 3],  # sample odd dimensions
        )
        for device_option in dc:
            op.device_option.CopyFrom(device_option)
            self.assertTrue(
                workspace.RunOperatorOnce(op),
                "MSRAFill op did not run successfully",
            )

            blob_out = workspace.FetchBlob('out')
            self.assertGreater(
                np.count_nonzero(blob_out), 0,
                "All generated elements are zeros. Is the random generator "
                "functioning correctly?",
            )

    @given(
        min_val=st.integers(min_value=0, max_value=5),
        span=st.integers(min_value=1, max_value=10),
        emb_size=st.sampled_from((10000, 20000, 30000)),
        dim_size=st.sampled_from((16, 32, 64)),
        **hu.gcs)
    @settings(deadline=None)
    def test_fp16_uniformfill_op(self, min_val, span, emb_size, dim_size,
                                 gc, dc):
        """Float16UniformFill yields float16 data matching U(min, min+span).

        Checks the output shape and dtype, value bounds, and sample mean and
        variance (within 0.1) against the uniform distribution.
        """
        op = core.CreateOperator(
            'Float16UniformFill',
            [],
            'out',
            shape=[emb_size, dim_size],
            min=float(min_val),
            max=float(min_val + span),
        )

        # Statistics of U(a, a + span); they do not depend on the device.
        expected_type = "float16"
        expected_mean = min_val + span / 2.0
        expected_var = span * span / 12.0
        expected_min = min_val
        expected_max = min_val + span

        for device_option in dc:
            op.device_option.CopyFrom(device_option)
            self.assertTrue(
                workspace.RunOperatorOnce(op),
                "Float16UniformFill op did not run successfully",
            )

            self.assertEqual(workspace.blobs['out'].shape,
                             (emb_size, dim_size))

            blob_out = workspace.FetchBlob('out')

            self.assertEqual(blob_out.dtype.name, expected_type)
            # Accumulate in float32 to avoid float16 precision loss.
            self.assertAlmostEqual(np.mean(blob_out, dtype=np.float32),
                                   expected_mean, delta=0.1)
            self.assertAlmostEqual(np.var(blob_out, dtype=np.float32),
                                   expected_var, delta=0.1)
            self.assertGreaterEqual(np.min(blob_out), expected_min)
            self.assertLessEqual(np.max(blob_out), expected_max)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    from unittest import main as _unittest_main
    _unittest_main()