# pytorch
# 45 lines · 1.3 KB
from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
import hypothesis.strategies as st
import numpy as np
import unittest


class TestPad(serial.SerializedTestCase):
    """Hypothesis-driven tests for the Caffe2 ``PadImage`` operator."""

    @serial.given(pad_t=st.integers(-5, 0),
                  pad_l=st.integers(-5, 0),
                  pad_b=st.integers(-5, 0),
                  pad_r=st.integers(-5, 0),
                  mode=st.sampled_from(["constant", "reflect", "edge"]),
                  size_w=st.integers(16, 128),
                  size_h=st.integers(16, 128),
                  size_c=st.integers(1, 4),
                  size_n=st.integers(1, 4),
                  **hu.gcs)
    def test_crop(self,
                  pad_t, pad_l, pad_b, pad_r,
                  mode,
                  size_w, size_h, size_c, size_n,
                  gc, dc):
        """Check that non-positive pads make ``PadImage`` crop its input.

        Every pad is drawn from [-5, 0], so the operator is expected only to
        remove rows/columns, never to synthesise new values. Because nothing
        is added, the expected output is the same plain crop for every
        ``mode``; forwarding ``mode`` checks exactly that.

        Args:
            pad_t, pad_l, pad_b, pad_r: top/left/bottom/right pad amounts,
                each in [-5, 0]; a negative value trims that many rows
                (top/bottom) or columns (left/right) from the input.
            mode: padding mode passed to the operator.
            size_w, size_h, size_c, size_n: extents of the random input,
                which is built with shape (size_n, size_c, size_h, size_w)
                (presumably NCHW).
            gc, dc: presumably the reference device option and the list of
                devices to cross-check, both supplied via ``hu.gcs``.
        """
        op = core.CreateOperator(
            "PadImage",
            ["X"],
            ["Y"],
            pad_t=pad_t,
            pad_l=pad_l,
            pad_b=pad_b,
            pad_r=pad_r,
            # `mode` used to be drawn by hypothesis but never forwarded, so
            # only the operator's default mode was ever exercised.
            # NOTE(review): if stored serialized outputs exist for this test,
            # they may need regenerating after this change -- confirm.
            mode=mode,
        )
        X = np.random.rand(
            size_n, size_c, size_h, size_w).astype(np.float32)

        def ref(X):
            # Trim |pad_t|/|pad_b| rows from axis 2 and |pad_l|/|pad_r|
            # columns from axis 3. A start of -0 is 0 (keep everything), but
            # an end of 0 must become None: X[..., :0] would be empty.
            return (X[:, :, -pad_t:pad_b or None, -pad_l:pad_r or None],)

        self.assertReferenceChecks(gc, op, [X], ref)
        self.assertDeviceChecks(dc, op, [X], [0])


if __name__ == "__main__":
    # When the file is executed directly, hand control to unittest's
    # command-line runner, which runs the test cases defined in this module.
    unittest.main()