pytorch
132 строки · 4.5 Кб
1
2
3
4
5
6from caffe2.python import core7import caffe2.python.hypothesis_test_util as hu8import caffe2.python.serialized_test.serialized_test_util as serial9from hypothesis import given, settings10import hypothesis.strategies as st11import numpy as np12import unittest13
14
def mux(select, left, right):
    """Reference implementation of the element-wise ``Where`` operator.

    Picks ``left`` where ``select`` is truthy and ``right`` elsewhere,
    broadcasting the three inputs against each other.

    ``np.where`` is used instead of ``np.vectorize``. ``np.vectorize`` infers
    its output dtype from the first element alone, so mixing int and float
    inputs could silently truncate results, and it raises on size-0 inputs.

    Args:
        select: condition values (array-like, truthiness decides the branch).
        left: values taken where ``select`` is truthy.
        right: values taken where ``select`` is falsy.

    Returns:
        A one-element list holding the result array (one entry per operator
        output, the form handed to ``assertReferenceChecks`` in this module).
    """
    return [np.where(select, left, right)]
18
def rowmux(select_vec, left, right):
    """Reference for ``Where`` with ``broadcast_on_rows=True``.

    ``select_vec`` carries one flag per row, i.e. per index along the first
    dimension of ``left``/``right``. A whole row is taken from ``left`` when
    its flag is truthy and from ``right`` otherwise.

    The flags are reshaped to ``(rows, 1, 1, ...)`` so they broadcast across
    every trailing dimension. The previous ``[s] * len(left)`` expansion used
    the row count as the row width and so only worked for square inputs. For
    1-D inputs each row is a single element, so the selection is element-wise.

    Args:
        select_vec: 1-D sequence of per-row conditions.
        left: values taken for rows whose flag is truthy.
        right: values taken for rows whose flag is falsy.

    Returns:
        A one-element list holding the result array.
    """
    left = np.asarray(left)
    select = np.asarray(select_vec, dtype=bool)
    # Append singleton axes so each flag spans its entire row.
    select = select.reshape(select.shape + (1,) * max(left.ndim - 1, 0))
    return [np.where(select, left, right)]
23
class TestWhere(serial.SerializedTestCase):
    """Checks the element-wise ``Where`` operator against ``mux``."""

    def test_reference(self):
        # Sanity-check the numpy reference itself on 1-D and 2-D inputs.
        self.assertTrue((
            np.array([1, 4]) == mux([True, False],
                                    [1, 2],
                                    [3, 4])[0]
        ).all())
        self.assertTrue((
            np.array([[1], [4]]) == mux([[True], [False]],
                                        [[1], [2]],
                                        [[3], [4]])[0]
        ).all())

    def _check_where(self, shape, gc, dc, engine):
        """Run device and reference checks for ``Where`` on random inputs.

        Args:
            shape: shape shared by the condition and both value tensors.
            gc, dc: device options supplied by hypothesis.
            engine: operator engine name.
        """
        # ``rand < 0.5`` yields a mix of True/False. ``rand(...).astype(bool)``
        # is True for every nonzero draw, so the Y branch was never exercised.
        C = np.random.rand(*shape) < 0.5
        X = np.random.rand(*shape).astype(np.float32)
        Y = np.random.rand(*shape).astype(np.float32)
        op = core.CreateOperator("Where", ["C", "X", "Y"], ["Z"], engine=engine)
        self.assertDeviceChecks(dc, op, [C, X, Y], [0])
        self.assertReferenceChecks(gc, op, [C, X, Y], mux)

    @given(N=st.integers(min_value=1, max_value=10),
           engine=st.sampled_from(["", "CUDNN"]),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_where(self, N, gc, dc, engine):
        self._check_where((N,), gc, dc, engine)

    @given(N=st.integers(min_value=1, max_value=10),
           engine=st.sampled_from(["", "CUDNN"]),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_where_dim2(self, N, gc, dc, engine):
        self._check_where((N, N), gc, dc, engine)
62
class TestRowWhere(hu.HypothesisTestCase):
    """Checks ``Where`` with ``broadcast_on_rows=True`` (one flag per row)."""

    def test_reference(self):
        # ``==`` + ``.all()`` broadcasts, so only the selected values are
        # compared, not the exact result shape.
        self.assertTrue((
            np.array([1, 2]) == rowmux([True],
                                       [1, 2],
                                       [3, 4])[0]
        ).all())
        self.assertTrue((
            np.array([[1, 2], [7, 8]]) == rowmux([True, False],
                                                 [[1, 2], [3, 4]],
                                                 [[5, 6], [7, 8]])[0]
        ).all())

    @given(N=st.integers(min_value=1, max_value=10),
           engine=st.sampled_from(["", "CUDNN"]),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_rowwhere(self, N, gc, dc, engine):
        # ``rand < 0.5`` yields a mix of True/False. ``rand(...).astype(bool)``
        # is True for every nonzero draw, so Y would never be selected.
        C = np.random.rand(N) < 0.5
        X = np.random.rand(N).astype(np.float32)
        Y = np.random.rand(N).astype(np.float32)
        op = core.CreateOperator(
            "Where",
            ["C", "X", "Y"],
            ["Z"],
            broadcast_on_rows=True,
            engine=engine,
        )
        self.assertDeviceChecks(dc, op, [C, X, Y], [0])
        # With 1-D inputs each row is a single element, so the element-wise
        # reference is used here.
        self.assertReferenceChecks(gc, op, [C, X, Y], mux)

    @given(N=st.integers(min_value=1, max_value=10),
           engine=st.sampled_from(["", "CUDNN"]),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_rowwhere_dim2(self, N, gc, dc, engine):
        # One condition flag per row of the (N, N) inputs; mixed True/False.
        C = np.random.rand(N) < 0.5
        X = np.random.rand(N, N).astype(np.float32)
        Y = np.random.rand(N, N).astype(np.float32)
        op = core.CreateOperator(
            "Where",
            ["C", "X", "Y"],
            ["Z"],
            broadcast_on_rows=True,
            engine=engine,
        )
        self.assertDeviceChecks(dc, op, [C, X, Y], [0])
        self.assertReferenceChecks(gc, op, [C, X, Y], rowmux)
111
class TestIsMemberOf(serial.SerializedTestCase):
    """Checks the ``IsMemberOf`` operator against a numpy membership test."""

    @given(N=st.integers(min_value=1, max_value=10),
           engine=st.sampled_from(["", "CUDNN"]),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_is_member_of(self, N, gc, dc, engine):
        allowed = [0, 3, 4, 6, 8]
        X = np.random.randint(10, size=N).astype(np.int64)

        op = core.CreateOperator(
            "IsMemberOf",
            ["X"],
            ["Y"],
            value=np.array(allowed),
            engine=engine,
        )
        self.assertDeviceChecks(dc, op, [X], [0])

        def reference(data):
            # Boolean mask: True wherever an element of ``data`` is in ``allowed``.
            return [np.isin(data, allowed)]

        self.assertReferenceChecks(gc, op, [X], reference)
135
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()