# pytorch · 84 lines · 3.4 KB
import unittest

import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings

import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
from caffe2.python import core
from caffe2.python.test_util import rand_array
class TestScatterOps(serial.SerializedTestCase):
    """Reference checks for the ScatterWeightedSum and ScatterAssign operators."""

    # TODO(dzhulgakov): add test cases for failure scenarios
    @given(num_args=st.integers(1, 5),
           first_dim=st.integers(1, 20),
           index_dim=st.integers(1, 10),
           extra_dims=st.lists(st.integers(1, 4), min_size=0, max_size=3),
           ind_type=st.sampled_from([np.int32, np.int64]),
           data_type=st.sampled_from([np.float32, np.float64]),
           **hu.gcs)
    @settings(deadline=10000)
    def testScatterWeightedSum(
            self, num_args, first_dim, index_dim, extra_dims, ind_type,
            data_type, gc, dc):
        """Check ScatterWeightedSum against a NumPy reference.

        The operator takes inputs ``data, w0, indices`` followed by
        ``num_args`` pairs ``(x_k, w_k)`` and writes its result back to the
        ``data`` blob. The reference ``ref`` below scales every row of
        ``data`` selected by ``indices`` by ``w0`` and then adds
        ``w_k * x_k[row]`` to it for each pair. The result is compared with
        ``assertReferenceChecks`` at threshold 1e-3.
        """
        ins = ['data', 'w0', 'indices']
        for i in range(1, num_args + 1):
            ins.extend(['x' + str(i), 'w' + str(i)])
        op = core.CreateOperator(
            'ScatterWeightedSum',
            ins,
            ['data'],
            device_option=gc)

        def ref(d, w0, ind, *args):
            r = d.copy()
            # The scaling is applied once per occurrence, so an index that
            # appears several times in `ind` is scaled by w0 repeatedly.
            # NOTE(review): presumably mirrors the operator's per-index
            # loop -- confirm against the kernel.
            for idx in ind:
                r[idx] *= w0
            # `args` alternates slice tensors x_k and their weights w_k.
            for k in range(0, len(args), 2):
                x = args[k]
                w = args[k + 1]
                for row, idx in enumerate(ind):
                    r[idx] += w * x[row]
            return [r]

        d = rand_array(first_dim, *extra_dims)
        # Indices are drawn with replacement, so duplicates are possible.
        ind = np.random.randint(0, first_dim, index_dim).astype(ind_type)
        # ScatterWeightedSumOp only supports w0=1.0 in CUDAContext,
        # and it only supports float32 data in CUDAContext.
        if gc == hu.gpu_do or gc == hu.hip_do:
            w0 = np.array(1.0).astype(np.float32)
            data_type = np.float32
        else:
            w0 = rand_array()
        d = d.astype(data_type)
        inputs = [d, w0, ind]
        for _ in range(num_args):
            x = rand_array(index_dim, *extra_dims).astype(data_type)
            w = rand_array()
            inputs.extend([x, w])
        self.assertReferenceChecks(gc, op, inputs, ref, threshold=1e-3)

    @given(first_dim=st.integers(1, 20),
           index_dim=st.integers(1, 10),
           extra_dims=st.lists(st.integers(1, 4), min_size=0, max_size=3),
           data_type=st.sampled_from(
               [np.float16, np.float32, np.int32, np.int64]),
           ind_type=st.sampled_from([np.int32, np.int64]),
           **hu.gcs)
    @settings(deadline=10000)
    def testScatterAssign(
            self, first_dim, index_dim, extra_dims, data_type, ind_type,
            gc, dc):
        """Check ScatterAssign against NumPy fancy-index assignment.

        The reference overwrites the rows of ``data`` selected by
        ``indices`` with the rows of ``slices``. The operator writes its
        output back to the ``data`` blob, and output shape/type inference
        is also checked (``ensure_outputs_are_inferred=True``).
        """
        op = core.CreateOperator('ScatterAssign',
                                 ['data', 'indices', 'slices'], ['data'])

        def ref(d, ind, x):
            r = d.copy()
            r[ind] = x
            return [r]

        # Indices must be unique so the result does not depend on which of
        # several writes to the same row wins. Sampling without replacement
        # requires index_dim <= first_dim, hence the swap.
        if first_dim < index_dim:
            first_dim, index_dim = index_dim, first_dim
        d = (rand_array(first_dim, *extra_dims) * 10).astype(data_type)
        ind = np.random.choice(first_dim, index_dim,
                               replace=False).astype(ind_type)
        x = (rand_array(index_dim, *extra_dims) * 10).astype(data_type)
        self.assertReferenceChecks(gc, op, [d, ind, x], ref, threshold=1e-3,
                                   ensure_outputs_are_inferred=True)
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()