pytorch
57 строк · 1.7 Кб
1
2
3
4
5
6from caffe2.python import core
7import caffe2.python.hypothesis_test_util as hu
8import caffe2.python.serialized_test.serialized_test_util as serial
9import hypothesis.strategies as st
10import numpy as np
11
12
class TestUnmaskOp(serial.SerializedTestCase):
    """Reference and device checks for the ``BooleanUnmask`` operator."""

    @serial.given(
        N=st.integers(min_value=2, max_value=20),
        dtype=st.sampled_from([
            np.bool_,
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.uint8,
            np.uint16,
            np.float16,
            np.float32,
            np.float64,
        ]),
        **hu.gcs
    )
    def test(self, N, dtype, gc, dc):
        """Check that mask/value pairs are reassembled into the full vector.

        A random vector of length ``N`` is generated, a random permutation
        of its indices is cut into pieces, and each piece is turned into a
        boolean mask plus the values selected by that mask. The operator
        is fed every (mask, values) pair and its output is compared with
        the original vector.

        Args:
            N: Length of the generated vector (drawn from 2..20).
            dtype: NumPy dtype of the generated vector.
            gc: Drawn from ``hu.gcs``; forwarded to
                ``assertReferenceChecks`` (presumably a device option --
                confirm in ``hypothesis_test_util``).
            dc: Drawn from ``hu.gcs``; forwarded to ``assertDeviceChecks``
                (presumably a list of device options -- confirm likewise).
        """
        # Booleans are sampled directly; every other dtype is cast from
        # uniform floats scaled into [0, N).
        if dtype is not np.bool_:
            all_value = (np.random.rand(N) * N).astype(dtype)
        else:
            all_value = np.random.choice(a=[True, False], size=N)

        # Cut a shuffled index range at random (sorted) positions so the
        # pieces together cover each index exactly once.
        num_cuts = np.random.randint(1, N)
        cut_points = np.random.randint(1, N, size=num_cuts)
        cut_points = sorted(cut_points)
        shuffled = np.random.permutation(N)
        groups = np.split(shuffled, cut_points)

        def expected_output(*args, **kwargs):
            # The reference ignores the operator inputs: the expected
            # result is always the original, unsplit vector.
            return (all_value,)

        blobs = []
        for group in groups:
            # Sorted positions keep the values in the same order as the
            # True entries of the mask.
            positions = np.sort(group)
            selector = np.zeros(N, dtype=np.bool_)
            selector[positions] = True
            blobs.append(selector)
            blobs.append(all_value[positions])

        # Names interleave as mask0, value0, mask1, value1, ... to line up
        # with the order of ``blobs``.
        blob_names = [
            template % idx
            for idx, _ in enumerate(groups)
            for template in ("mask%d", "value%d")
        ]

        op = core.CreateOperator('BooleanUnmask', blob_names, 'output')

        self.assertReferenceChecks(gc, op, blobs, expected_output)
        # NOTE(review): [0] presumably selects output index 0 for the
        # cross-device comparison -- confirm against hypothesis_test_util.
        self.assertDeviceChecks(dc, op, blobs, [0])
58
59
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    from unittest import main as run_tests

    run_tests()
63