pytorch
133 строки · 4.2 Кб
1
2
3
4
5
6from caffe2.python import core
7from hypothesis import assume, given, settings
8
9import caffe2.python.hypothesis_test_util as hu
10import hypothesis.strategies as st
11import numpy as np
12
13
class TestReduceFrontSum(hu.HypothesisTestCase):
    """Hypothesis property tests for the caffe2 ``Im2Col``/``Col2Im`` ops.

    NOTE(review): despite the class name, nothing here exercises a
    ReduceFrontSum operator. The name is left unchanged so existing test
    selectors keep working; consider renaming it (e.g. ``TestIm2Col``).
    """

    @given(batch_size=st.integers(1, 3),
           stride=st.integers(1, 3),
           pad=st.integers(0, 3),
           kernel=st.integers(1, 5),
           dilation=st.integers(1, 3),
           size=st.integers(7, 10),
           channels=st.integers(1, 8),
           **hu.gcs)
    def test_im2col_layout(self, batch_size, stride, pad, kernel, dilation,
                           size, channels, gc, dc):
        """NCHW and NHWC Im2Col/Col2Im must agree up to an axis reordering.

        The same random image is run through Im2Col in both layouts and the
        two column buffers are compared after reordering axes. Each column
        buffer is then folded back with Col2Im, and the two resulting images
        are compared in NCHW order.
        """
        # Spatial extent covered by one dilated kernel.
        dkernel = dilation * (kernel - 1) + 1
        # Only consider images at least as large as the dilated kernel.
        assume(size >= dkernel)

        NCHW_TO_NHWC = (0, 2, 3, 1)
        NHWC_TO_NCHW = (0, 3, 1, 2)
        # Reorders an NHWC column buffer viewed as
        # (N, out_h, out_w, kernel, kernel, C) into the NCHW column view
        # (N, C, kernel, kernel, out_h, out_w).
        COL_NHWC_TO_NCHW = (0, 5, 3, 4, 1, 2)

        N = batch_size
        C = channels
        H = size
        W = size

        # Expected output spatial size. The numerator is non-negative here
        # (size >= dkernel and pad >= 0), so floor division gives the same
        # result as truncating a float quotient, without the float round-trip.
        out_h = (H + 2 * pad - dkernel) // stride + 1
        out_w = (W + 2 * pad - dkernel) // stride + 1

        # Random values shifted into [-0.5, 0.5).
        im_nchw = np.random.rand(N, C, H, W).astype(np.float32) - 0.5
        im_nhwc = im_nchw.transpose(NCHW_TO_NHWC)

        op_im2col_nchw = core.CreateOperator(
            "Im2Col",
            ["im_nchw"], ["col_nchw"],
            stride=stride,
            kernel=kernel,
            dilation=dilation,
            pad=pad,
            order="NCHW",
            device_option=gc)

        op_im2col_nhwc = core.CreateOperator(
            "Im2Col",
            ["im_nhwc"], ["col_nhwc"],
            stride=stride,
            kernel=kernel,
            dilation=dilation,
            pad=pad,
            order="NHWC",
            device_option=gc)

        self.ws.create_blob("im_nchw").feed(im_nchw, device_option=gc)
        self.ws.create_blob("im_nhwc").feed(im_nhwc, device_option=gc)
        self.ws.run(op_im2col_nchw)
        self.ws.run(op_im2col_nhwc)

        col_nchw = self.ws.blobs["col_nchw"].fetch()
        col_nhwc = self.ws.blobs["col_nhwc"].fetch()
        # Unpack the (channel, kernel, kernel) structure that each layout
        # stores in its column dimension, then compare the whole batch in a
        # single vectorized check.
        col_nchw_ = col_nchw.reshape(N, C, kernel, kernel, out_h, out_w)
        col_nhwc_ = col_nhwc.reshape(N, out_h, out_w, kernel, kernel, C)
        np.testing.assert_allclose(
            col_nchw_,
            col_nhwc_.transpose(COL_NHWC_TO_NCHW),
            atol=1e-4,
            rtol=1e-4)

        # Col2Im receives the column buffer plus the original image blob
        # (presumably used only for its shape -- confirm in the op schema).
        op_col2im_nchw = core.CreateOperator(
            "Col2Im",
            ["col_nchw", "im_nchw"],
            ["out_nchw"],
            stride=stride,
            kernel=kernel,
            dilation=dilation,
            pad=pad,
            order="NCHW",
            device_option=gc)

        op_col2im_nhwc = core.CreateOperator(
            "Col2Im",
            ["col_nhwc", "im_nhwc"],
            ["out_nhwc"],
            stride=stride,
            kernel=kernel,
            dilation=dilation,
            pad=pad,
            order="NHWC",
            device_option=gc)

        self.ws.run(op_col2im_nchw)
        self.ws.run(op_col2im_nhwc)

        out_nchw = self.ws.blobs["out_nchw"].fetch()
        out_nhwc = self.ws.blobs["out_nhwc"].fetch()
        np.testing.assert_allclose(
            out_nchw,
            out_nhwc.transpose(NHWC_TO_NCHW),
            atol=1e-4,
            rtol=1e-4)

    @given(batch_size=st.integers(1, 3),
           stride=st.integers(1, 3),
           pad=st.integers(0, 3),
           kernel=st.integers(1, 5),
           dilation=st.integers(1, 3),
           size=st.integers(7, 10),
           channels=st.integers(1, 8),
           order=st.sampled_from(["NCHW"]),
           **hu.gcs)
    @settings(deadline=10000)
    def test_col2im_gradients(self, batch_size, stride, pad, kernel,
                              dilation, size, channels, order, gc, dc):
        """Gradient-check the Im2Col operator with respect to its input.

        NOTE(review): the test name mentions Col2Im, presumably because
        Im2Col's gradient is computed by Col2Im -- confirm in the gradient
        registry.
        """
        # Only consider images at least as large as the dilated kernel.
        dkernel = dilation * (kernel - 1) + 1
        assume(size >= dkernel)
        op = core.CreateOperator(
            "Im2Col",
            ["X"], ["Y"],
            stride=stride,
            kernel=kernel,
            dilation=dilation,
            pad=pad,
            order=order,
            device_option=gc)
        # X is built as (N, C, H, W), matching the only sampled order "NCHW".
        X = np.random.rand(batch_size, channels, size, size).astype(np.float32)
        self.assertGradientChecks(gc, op, [X], 0, [0])
139