pytorch
40 строк · 1.2 Кб
1
2
3
4
5
import unittest

import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings

from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
15
@st.composite
def _glu_old_input(draw):
    """Hypothesis strategy producing ``(X, axis)`` inputs for the Glu test.

    A base shape of 1 to 3 extents, each in ``[1, 5]``, is drawn first. An
    extra extent of size 2 or 4 is then spliced in at a randomly drawn
    position ``axis`` (``0 <= axis <= len(base shape)``). The resulting array
    therefore has rank 2 to 4, and its ``axis`` extent is always even.

    Returns:
        tuple: ``(X, axis)``. ``X`` is a float32 array drawn from
        ``hu.arrays`` with that shape; the ``None`` argument presumably
        selects hu's default element strategy (TODO confirm). ``axis`` is
        the index of the even-sized extent.
    """
    base_extents = draw(
        st.lists(st.integers(min_value=1, max_value=5), min_size=1, max_size=3)
    )
    split_axis = draw(st.integers(min_value=0, max_value=len(base_extents)))
    # The operator halves its input along split_axis, so that extent must be
    # divisible by two.
    half_extent = draw(st.integers(min_value=1, max_value=2))
    full_shape = (
        base_extents[:split_axis] + [2 * half_extent] + base_extents[split_axis:]
    )
    X = draw(hu.arrays(full_shape, np.float32, None))
    return X, split_axis
26
class TestGlu(serial.SerializedTestCase):
    """Checks the caffe2 ``Glu`` operator against a NumPy reference."""

    @given(X_axis=_glu_old_input(), **hu.gcs)
    @settings(deadline=10000)
    def test_glu_old(self, X_axis, gc, dc):
        """Compare ``Glu`` against ``x1 * 1 / (1 + exp(-x2))`` in NumPy.

        ``X_axis`` is an ``(X, axis)`` pair from ``_glu_old_input``. The
        reference splits ``X`` at the midpoint of ``axis`` into ``x1`` and
        ``x2``. ``gc`` and ``dc`` are supplied through ``hu.gcs``; only
        ``gc`` is passed to ``assertReferenceChecks`` here.
        """
        tensor, split_axis = X_axis

        def glu_ref(data):
            # First half of split_axis is the value, second half the gate.
            midpoint = data.shape[split_axis] // 2
            value_half, gate_half = np.split(data, [midpoint], axis=split_axis)
            # Logistic sigmoid of the gate, computed as a reciprocal first so
            # the float arithmetic matches x1 * (1. / (1. + exp(-x2))).
            gate = 1. / (1. + np.exp(-gate_half))
            return [value_half * gate]

        glu_op = core.CreateOperator("Glu", ["X"], ["Y"], dim=split_axis)
        self.assertReferenceChecks(gc, glu_op, [tensor], glu_ref)
if __name__ == "__main__":
    # Executed only when this file is run directly, not when it is imported.
    unittest.main()