pytorch
50 строк · 1.6 Кб
1
2
3
4
5
6from caffe2.python import core7from hypothesis import given, settings8from hypothesis import strategies as st9import caffe2.python.hypothesis_test_util as hu10import caffe2.python.serialized_test.serialized_test_util as serial11
12import numpy as np13import unittest14
15
class TestMathOps(serial.SerializedTestCase):
    """Reference checks for the Caffe2 ``Pow`` and ``Sign`` operators.

    Each test builds the operator with ``core.CreateOperator`` and compares
    its output against a NumPy reference implementation on
    Hypothesis-generated input tensors.
    """

    @given(X=hu.tensor(),
           exponent=st.floats(min_value=2.0, max_value=3.0),
           **hu.gcs)
    def test_elementwise_power(self, X, exponent, gc, dc):
        """Check ``Pow`` forward and gradient against ``X ** exponent``.

        ``exponent`` is drawn from [2.0, 3.0]. ``gc`` is forwarded to the
        reference check; ``dc`` is accepted because ``hu.gcs`` supplies it,
        but this test does not use it.
        """
        # negative integer raised with non-integer exponent is domain error
        X = np.abs(X)

        def powf(X):
            # Forward reference: elementwise power, returned as a 1-tuple
            # (one entry per operator output).
            return (X ** exponent,)

        def powf_grad(g_out, outputs, fwd_inputs):
            # Gradient reference: d/dx x**p = p * x**(p - 1), scaled by the
            # incoming gradient of the output.
            return (exponent * (fwd_inputs[0] ** (exponent - 1)) * g_out,)

        op = core.CreateOperator(
            "Pow", ["X"], ["Y"], exponent=exponent)

        self.assertReferenceChecks(gc, op, [X], powf,
                                   output_to_grad="Y",
                                   grad_reference=powf_grad,
                                   ensure_outputs_are_inferred=True)

    # The ``Sign`` operator takes no exponent, so no ``exponent`` strategy is
    # drawn here; only the input tensor and the device options are generated.
    @given(X=hu.tensor(),
           **hu.gcs)
    @settings(deadline=10000)
    def test_sign(self, X, gc, dc):
        """Check ``Sign`` against ``np.sign`` and run the device checks."""
        def signf(X):
            # Reference returns a list with one entry per operator output.
            return [np.sign(X)]

        op = core.CreateOperator(
            "Sign", ["X"], ["Y"])

        self.assertReferenceChecks(
            gc, op, [X], signf, ensure_outputs_are_inferred=True)
        self.assertDeviceChecks(dc, op, [X], [0])
53
# Run the test suite through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()