1
from caffe2.python import core
2
from hypothesis import given
4
import caffe2.python.hypothesis_test_util as hu
5
import hypothesis.strategies as st
9
class TestATen(hu.HypothesisTestCase):
11
@given(inputs=hu.tensors(n=2), **hu.gcs)
12
def test_add(self, inputs, gc, dc):
13
op = core.CreateOperator(
21
self.assertReferenceChecks(gc, op, inputs, ref)
23
@given(inputs=hu.tensors(n=2, dtype=np.float16), **hu.gcs_gpu_only)
24
def test_add_half(self, inputs, gc, dc):
25
op = core.CreateOperator(
33
self.assertReferenceChecks(gc, op, inputs, ref)
35
@given(inputs=hu.tensors(n=1), **hu.gcs)
36
def test_pow(self, inputs, gc, dc):
37
op = core.CreateOperator(
41
operator="pow", exponent=2.0)
46
self.assertReferenceChecks(gc, op, inputs, ref)
48
@given(x=st.integers(min_value=2, max_value=8), **hu.gcs)
49
def test_sort(self, x, gc, dc):
50
inputs = [np.random.permutation(x)]
51
op = core.CreateOperator(
58
return [np.sort(X), np.argsort(X)]
59
self.assertReferenceChecks(gc, op, inputs, ref)
61
@given(inputs=hu.tensors(n=1), **hu.gcs)
62
def test_sum(self, inputs, gc, dc):
63
op = core.CreateOperator(
72
self.assertReferenceChecks(gc, op, inputs, ref)
75
def test_index_uint8(self, gc, dc):
77
op = core.CreateOperator(
84
return (self[mask.astype(np.bool_)],)
86
tensor = np.random.randn(2, 3, 4).astype(np.float32)
87
mask = np.array([[1, 0, 0], [1, 1, 0]]).astype(np.uint8)
89
self.assertReferenceChecks(gc, op, [tensor, mask], ref)
92
def test_index_put(self, gc, dc):
93
op = core.CreateOperator(
95
['self', 'indices', 'values'],
99
def ref(self, indices, values):
100
self[indices] = values
103
tensor = np.random.randn(3, 3).astype(np.float32)
104
mask = np.array([[True, True, True], [True, False, False], [True, True, False]])
105
values = np.random.randn(6).astype(np.float32)
107
self.assertReferenceChecks(gc, op, [tensor, mask, values], ref)
110
def test_unique(self, gc, dc):
111
op = core.CreateOperator(
121
index, _ = np.unique(self, return_index=False, return_inverse=True, return_counts=False)
124
tensor = np.array([1, 2, 6, 4, 2, 3, 2])
126
self.assertReferenceChecks(gc, op, [tensor], ref)
129
if __name__ == "__main__":