pytorch
39 строк · 1.5 Кб
1
2
3
4
5
6from caffe2.python import workspace, core7import caffe2.python.hypothesis_test_util as hu8import caffe2.python.serialized_test.serialized_test_util as serial9from hypothesis import given, settings10import hypothesis.strategies as st11import numpy as np12
13
class TestNegateGradient(serial.SerializedTestCase):
    """Tests for the NegateGradient operator.

    The assertions below require the forward output to equal the input, and
    the gradient reaching the input to be the negation of the gradient of
    the output.
    """

    @given(X=hu.tensor(), inplace=st.booleans(), **hu.gcs)
    @settings(deadline=10000)
    def test_forward(self, X, inplace, gc, dc):
        """Check the forward pass is the identity, out-of-place and in-place."""
        def neg_grad_ref(X):
            # Reference forward: the single output equals the input.
            return (X,)

        # With `inplace`, the op writes its result back into blob "X".
        op = core.CreateOperator("NegateGradient", ["X"], ["Y" if not inplace else "X"])
        self.assertReferenceChecks(gc, op, [X], neg_grad_ref)
        # Compare output 0 across the devices supplied in `dc`.
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(size=st.lists(st.integers(min_value=1, max_value=20),
                         min_size=1, max_size=5))
    # Same generous deadline as test_forward: resetting the workspace and
    # building/running a net with gradient ops can exceed Hypothesis's
    # default per-example deadline and make this test flaky.
    @settings(deadline=10000)
    def test_grad(self, size):
        """Check the backward pass negates the gradient flowing into X.

        Args:
            size: Shape (list of 1-5 dims, each 1-20) of the random input.
        """
        X = np.random.random_sample(size)
        # Start from a clean workspace so blobs from earlier examples
        # cannot leak into this one.
        workspace.ResetWorkspace()
        workspace.FeedBlob("X", X.astype(np.float32))

        net = core.Net("negate_grad_test")
        Y = net.NegateGradient(["X"], ["Y"])

        # grad_map maps blob names ('X', 'Y') to their gradient blobs.
        grad_map = net.AddGradientOperators([Y])
        workspace.RunNetOnce(net)

        # Forward is the identity; backward yields X_grad == -Y_grad.
        x_val, y_val = workspace.FetchBlobs(['X', 'Y'])
        x_grad_val, y_grad_val = workspace.FetchBlobs([grad_map['X'],
                                                        grad_map['Y']])
        np.testing.assert_array_equal(x_val, y_val)
        np.testing.assert_array_equal(x_grad_val, y_grad_val * (-1))