pytorch
46 строк · 1.5 Кб
1
2
3
4
5
6import unittest
7import hypothesis.strategies as st
8from hypothesis import given, settings
9import numpy as np
10from caffe2.python import core, workspace
11import caffe2.python.hypothesis_test_util as hu
12import caffe2.python.mkl_test_util as mu
13
14
@unittest.skipIf(
    not workspace.C.has_mkldnn,
    "Skipping as we do not have mkldnn.")
class MKLConvTest(hu.HypothesisTestCase):
    """Hypothesis-driven test of the ``Conv`` operator.

    The whole class is skipped unless ``workspace.C.has_mkldnn`` is truthy.
    """

    @given(
        stride=st.integers(1, 3),
        pad=st.integers(0, 3),
        kernel=st.integers(3, 5),
        size=st.integers(8, 8),
        input_channels=st.integers(1, 3),
        output_channels=st.integers(1, 3),
        batch_size=st.integers(1, 3),
        **mu.gcs
    )
    @settings(max_examples=2, deadline=100)
    def test_mkl_convolution(
            self, stride, pad, kernel, size,
            input_channels, output_channels, batch_size, gc, dc):
        # Hypothesis supplies the convolution geometry (stride, pad, kernel),
        # the spatial size, the channel counts and the batch size; ``gc`` and
        # ``dc`` come from ``mu.gcs``. ``gc`` is accepted but not used here.
        conv_op = core.CreateOperator(
            "Conv",
            ["X", "w", "b"],
            ["Y"],
            stride=stride,
            pad=pad,
            kernel=kernel,
        )

        data_shape = (batch_size, input_channels, size, size)
        filter_shape = (output_channels, input_channels, kernel, kernel)

        # rand() samples from [0, 1); subtracting 0.5 shifts the float32
        # values into [-0.5, 0.5). Draw order (X, w, b) is kept as-is so the
        # global NumPy random stream is consumed identically.
        X = np.random.rand(*data_shape).astype(np.float32) - 0.5
        w = np.random.rand(*filter_shape).astype(np.float32) - 0.5
        b = np.random.rand(output_channels).astype(np.float32) - 0.5

        # NOTE(review): [0] is presumably the index of the output that is
        # compared across the device options in ``dc`` -- confirm against
        # hypothesis_test_util.assertDeviceChecks.
        self.assertDeviceChecks(dc, conv_op, [X, w, b], [0])
47
48
if __name__ == "__main__":
    # Run every test case in this module when the file is executed directly.
    # ``unittest`` is already imported at module level, so it is not
    # re-imported here.
    unittest.main()
52