21
from hypothesis import given
22
import hypothesis.strategies as st
26
from caffe2.python.transformations import Transformer
27
from caffe2.python import core, workspace
28
from caffe2.python import test_util as tu
30
# Module-level Transformer instance shared by every test in this file. Its
# passes (AddNNPACK, FuseNNPACKConvRelu, FuseConvBN) are called on a net and
# their effect is then read back from net.Proto(), i.e. they rewrite in place.
transformer = Transformer()
33
# Unit tests for the graph transformations in caffe2.python.transformations.
# NOTE(review): this copy of the file has lost its indentation and a number of
# lines (see the gaps in the interleaved original line numbers), so the class
# body is not valid Python as-is; restore from the canonical source.
class TestTransformations(tu.TestCase):
34
def _base_test_net(self):
    """Build a fresh net holding one Conv op: (X, w, b) -> Y.

    The conv uses kernel=3, stride=1, pad=0 and NCHW order.

    Returns:
        The newly created ``core.Net``; callers append further ops to it.
    """
    # The previous body used ``net`` without creating it and returned
    # nothing, although every caller does ``net = self._base_test_net()``.
    net = core.Net("net")
    net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
    return net
39
def _add_nnpack(self, net):
    """Run the AddNNPACK pass on *net*, then require its first op to be
    assigned the "NNPACK" engine."""
    transformer.AddNNPACK(net)
    leading_op = net.Proto().op[0]
    assert tu.str_compare(leading_op.engine, "NNPACK")
43
def _fuse_nnpack_convrelu(self, net, expected_result_num_ops,
                          expected_activation_arg=True):
    """Apply FuseNNPACKConvRelu to *net* and verify the outcome.

    Args:
        net: a net from ``_base_test_net`` plus the extra ops under test.
        expected_result_num_ops: number of ops the net must have after fusion.
        expected_activation_arg: whether the first op is expected to carry an
            ``activation`` argument; any such argument must equal "Relu".
    """
    # Tag the leading Conv with the NNPACK engine first: the base net's Conv
    # has no engine set, and this pass is named for NNPACK convolutions.
    self._add_nnpack(net)
    transformer.FuseNNPACKConvRelu(net)
    self.assertEqual(tu.numOps(net), expected_result_num_ops)

    activation_args = [
        arg for arg in net.Proto().op[0].arg
        if tu.str_compare(arg.name, "activation")
    ]
    # Every activation argument that is present must name Relu.
    for arg in activation_args:
        assert tu.str_compare(arg.s, "Relu")
    has_activation_arg = bool(activation_args)

    # Exactly one of these expectations applies per call; asserting both
    # unconditionally could never pass.
    if expected_activation_arg:
        assert has_activation_arg
    else:
        assert not has_activation_arg
58
def test_transformer_AddNNPACK(self):
    """AddNNPACK must assign the NNPACK engine to the leading Conv of a
    Conv -> Relu net."""
    net = self._base_test_net()
    net.Relu(["Y"], ["Y2"])
    # Without this call the test only builds a net and checks nothing.
    self._add_nnpack(net)
63
def test_transformer_FuseNNPACKConvRelu(self):
    """Conv (X -> Y) followed by a single Relu (Y -> Y2) must fuse into one
    op that carries the Relu activation argument."""
    conv_relu_net = self._base_test_net()
    conv_relu_net.Relu(["Y"], ["Y2"])
    self._fuse_nnpack_convrelu(conv_relu_net, expected_result_num_ops=1)
68
def test_noFuseNNPACKConvRelu(self):
    """Y feeds two separate Relu ops; the test expects all three ops to
    survive and the conv to gain no activation argument."""
    net = self._base_test_net()
    for relu_output in ("Y2", "Y3"):
        net.Relu(["Y"], [relu_output])
    self._fuse_nnpack_convrelu(
        net, expected_result_num_ops=3, expected_activation_arg=False
    )
74
def test_transformer_FuseNNPACKConvReluNoInplace(self):
    """Relu writes back into the conv's input blob X; the single fused op
    must not read and write the same blob."""
    net = self._base_test_net()
    net.Relu(["Y"], ["X"])
    self._fuse_nnpack_convrelu(net, 1)
    fused_op = net.Proto().op[0]
    assert fused_op.input[0] != fused_op.output[0]
80
def test_transformer_FuseNNPACKConvReluInplaceRelu(self):
    """An in-place Relu (Y -> Y) after the conv still fuses into one op whose
    first input and first output are different blobs."""
    net = self._base_test_net()
    net.Relu(["Y"], ["Y"])
    self._fuse_nnpack_convrelu(net, 1)
    fused_op = net.Proto().op[0]
    assert fused_op.input[0] != fused_op.output[0]
86
def test_transformer_FuseNNPACKConvReluPingPongNaming(self):
    """Conv (X -> Y), Relu (Y -> X), Conv (X -> Y): blob names alternate
    between X and Y. After fusion two ops must remain, and neither may use
    the same blob as its first input and first output."""
    conv_kwargs = dict(stride=1, pad=0, kernel=3, order="NCHW")
    net = self._base_test_net()
    net.Relu(["Y"], ["X"])
    net.Conv(["X", "w", "b"], ["Y"], **conv_kwargs)
    self._fuse_nnpack_convrelu(net, 2)
    ops = net.Proto().op
    for idx in (0, 1):
        assert ops[idx].output[0] != ops[idx].input[0]
94
def test_transformer_FuseNNPACKConvReluFollowedByMultipleInputOp(self):
    """Conv -> Relu (Y -> Y2) -> Conv (Y2, w, b -> Y) -> Relu (Y -> Y2).

    The second conv consumes the first Relu's output alongside other inputs.
    Expect two ops after fusion, neither aliasing first input and output.
    """
    conv_kwargs = dict(stride=1, pad=0, kernel=3, order="NCHW")
    net = self._base_test_net()
    net.Relu(["Y"], ["Y2"])
    net.Conv(["Y2", "w", "b"], ["Y"], **conv_kwargs)
    net.Relu(["Y"], ["Y2"])
    self._fuse_nnpack_convrelu(net, 2)
    ops = net.Proto().op
    for idx in (0, 1):
        assert ops[idx].output[0] != ops[idx].input[0]
103
def test_transformer_FuseNNPACKConvReluInplaceFollowedByMultipleInputOp(self):
    """Conv -> in-place Relu (Y -> Y) -> Conv (Y, w, b -> Y2) -> in-place
    Relu (Y2 -> Y2). Expect two ops after fusion, neither aliasing its first
    input with its first output."""
    conv_kwargs = dict(stride=1, pad=0, kernel=3, order="NCHW")
    net = self._base_test_net()
    net.Relu(["Y"], ["Y"])
    net.Conv(["Y", "w", "b"], ["Y2"], **conv_kwargs)
    net.Relu(["Y2"], ["Y2"])
    self._fuse_nnpack_convrelu(net, 2)
    ops = net.Proto().op
    for idx in (0, 1):
        assert ops[idx].output[0] != ops[idx].input[0]
113
# NOTE(review): the keyword arguments below are presumably the body of a
# hypothesis `@given(...)` decorator for test_transformer_FuseConvBN. Their
# names match its parameters, and `given`/`st` are imported at the top. The
# opening `@given(` and closing `)` lines are missing from this copy (the
# interleaved original numbering skips 112 and 118).
size=st.integers(7, 10),
114
input_channels=st.integers(1, 10),
115
seed=st.integers(0, 65535),
116
order=st.sampled_from(["NCHW", "NHWC"]),
117
epsilon=st.floats(min_value=1e-5, max_value=1e-2),
119
# Test flow: build a Conv (with bias "b") followed by a second op writing
# "Y2", and run it. Then apply transformer.FuseConvBN, require a single op,
# and run again.
# NOTE(review): c, h, w and k are used below but never defined in this copy
# (original lines 122-125 are missing). `np` is used without a visible
# import, and `seed`/`epsilon` are not referenced in the surviving lines.
def test_transformer_FuseConvBN(self, size, input_channels, seed, order, epsilon):
120
workspace.ResetWorkspace()
121
net = core.Net("net")
126
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=k, order=order)
128
# NOTE(review): only the input list of the second op survives. Its call and
# remaining arguments (original lines 127 and 129-136) are missing. Given the
# test name and the FetchBlob("Y2") calls, it is presumably a batch-norm op
# writing "Y2".
["Y", "scale", "bias", "mean", "var"],
137
# NOTE(review): the first X/w pair is shaped (1, c, h, w) and the second
# (1, h, w, c), presumably for NCHW and NHWC respectively. The branch on
# `order` that should pick one pair is missing (original lines 136 and 139).
tu.randBlobFloat32("X", 1, c, h, w)
138
tu.randBlobFloat32("w", c, c, k, k)
140
tu.randBlobFloat32("X", 1, h, w, c)
141
tu.randBlobFloat32("w", c, k, k, c)
142
tu.randBlobsFloat32(["b", "scale", "bias", "mean"], c)
146
# Presumably the offset keeps "var" away from zero. TODO confirm against
# tu.randBlobFloat32.
tu.randBlobFloat32("var", c, offset=0.5)
147
workspace.RunNetOnce(net)
148
preTransformOutput = workspace.FetchBlob("Y2").flatten()
149
# Overwrite Y2 so the post-transform fetch cannot see the pre-transform value.
workspace.FeedBlob("Y2", np.zeros((1, 1)))
150
transformer.FuseConvBN(net)
153
# The transformed net must consist of a single op.
assert tu.numOps(net) == 1
154
workspace.RunNetOnce(net)
155
postTransformOutput = workspace.FetchBlob("Y2").flatten()
# NOTE(review): postTransformOutput is never compared with preTransformOutput
# in this copy. The numeric-equivalence check appears to be among the missing
# lines (the numbering jumps from 155 to 164).
164
# Skipped unconditionally; `unittest` has no visible import in this copy.
@unittest.skip("Test is flaky")
167
# NOTE(review): the kwargs below are presumably a hypothesis `@given(...)`
# decorator whose opening and closing lines (original 165 and 171) are missing.
size=st.integers(7, 10),
167
input_channels=st.integers(1, 10),
168
seed=st.integers(0, 65535),
169
order=st.sampled_from(["NCHW", "NHWC"]),
170
epsilon=st.floats(min_value=1e-5, max_value=1e-2),
172
# Same flow as test_transformer_FuseConvBN, but the Conv has no bias input
# and no "b" blob is fed, so the fusion must work without an existing bias.
# NOTE(review): c, h, w and k are undefined in this copy (original lines
# 175-178 are missing), and `np` has no visible import.
def test_transformer_FuseConvBNNoConvBias(self, size, input_channels, seed, order, epsilon):
173
workspace.ResetWorkspace()
174
net = core.Net("net")
179
net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=k, order=order)
181
# NOTE(review): only the input list of the second op survives. Its call and
# other arguments (original lines 180 and 182-189) are missing.
["Y", "scale", "bias", "mean", "var"],
190
# NOTE(review): the first pair is shaped (1, c, h, w) and the second
# (1, h, w, c), presumably for NCHW and NHWC. The `else:` separating them
# (original line 192) and the `if` selecting them are missing.
tu.randBlobFloat32("X", 1, c, h, w)
191
tu.randBlobFloat32("w", c, c, k, k)
193
tu.randBlobFloat32("X", 1, h, w, c)
194
tu.randBlobFloat32("w", c, k, k, c)
195
tu.randBlobsFloat32(["scale", "bias", "mean"], c)
198
# Presumably the offset keeps "var" away from zero. TODO confirm.
tu.randBlobFloat32("var", c, offset=0.5)
199
workspace.RunNetOnce(net)
200
preTransformOutput = workspace.FetchBlob("Y2").flatten()
201
# Overwrite Y2 so the post-transform fetch cannot see the pre-transform value.
workspace.FeedBlob("Y2", np.zeros((1, 1)))
202
transformer.FuseConvBN(net)
205
# The transformed net must consist of a single op.
assert tu.numOps(net) == 1
206
workspace.RunNetOnce(net)
207
postTransformOutput = workspace.FetchBlob("Y2").flatten()
# NOTE(review): no comparison of preTransformOutput vs postTransformOutput
# survives in this copy; the check appears to be among the missing lines.
217
# NOTE(review): the kwargs below are presumably a hypothesis `@given(...)`
# decorator whose opening and closing lines (original 216 and 222) are missing.
size=st.integers(7, 10),
218
input_channels=st.integers(1, 10),
219
seed=st.integers(0, 65535),
220
order=st.sampled_from(["NCHW", "NHWC"]),
221
epsilon=st.floats(min_value=1e-5, max_value=1e-2),
223
# Bias-free Conv fusion where the second op's bias blob is named "_bias0".
# Going by the test name, this presumably checks that the fusion pass does not
# collide with an existing blob when naming the bias it creates. TODO confirm.
# NOTE(review): c, h, w and k are undefined in this copy (original lines
# 226-229 are missing), and `np` has no visible import.
def test_transformer_FuseConvBNNoConvBiasDuplicatedName(self, size, input_channels, seed, order, epsilon):
224
workspace.ResetWorkspace()
225
net = core.Net("net")
230
net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=k, order=order)
232
# NOTE(review): only the input list of the second op survives (its call and
# other arguments, original lines 231 and 233-240, are missing).
["Y", "scale", "_bias0", "mean", "var"],
241
# NOTE(review): the first pair is shaped (1, c, h, w) and the second
# (1, h, w, c), presumably for NCHW and NHWC. The branch selecting between
# them (original line 243 and the `if`) is missing.
tu.randBlobFloat32("X", 1, c, h, w)
242
tu.randBlobFloat32("w", c, c, k, k)
244
tu.randBlobFloat32("X", 1, h, w, c)
245
tu.randBlobFloat32("w", c, k, k, c)
246
tu.randBlobsFloat32(["scale", "_bias0", "mean"], c)
249
# Presumably the offset keeps "var" away from zero. TODO confirm.
tu.randBlobFloat32("var", c, offset=0.5)
250
workspace.RunNetOnce(net)
251
preTransformOutput = workspace.FetchBlob("Y2").flatten()
252
# Overwrite Y2 so the post-transform fetch cannot see the pre-transform value.
workspace.FeedBlob("Y2", np.zeros((1, 1)))
253
transformer.FuseConvBN(net)
256
# The transformed net must consist of a single op.
assert tu.numOps(net) == 1
257
workspace.RunNetOnce(net)
258
postTransformOutput = workspace.FetchBlob("Y2").flatten()
260
# NOTE(review): these prints look like leftover debugging output, and no
# comparison of the two outputs survives in this copy. Consider replacing them
# with the numeric-equivalence assertion.
print(preTransformOutput)
262
print(postTransformOutput)
272
# NOTE(review): the kwargs below are presumably a hypothesis `@given(...)`
# decorator whose opening and closing lines (original 271 and 279) are missing.
size=st.integers(7, 10),
273
input_channels=st.integers(1, 10),
274
kt=st.integers(3, 5),
275
kh=st.integers(3, 5),
276
kw=st.integers(3, 5),
277
seed=st.integers(0, 65535),
278
epsilon=st.floats(min_value=1e-5, max_value=1e-2),
280
# 3-D variant of the Conv + batch-norm fusion test. X is fed as a 5-D blob
# (1, c, t, h, w), and w as (c, c, kt, kh, kw).
# NOTE(review): the signature's closing line (original 282) is missing, as is
# most of the Conv call and the definitions of c, t, h, w (original lines
# 285-291). `np` has no visible import.
def test_transformer_FuseConv3DBN(
281
self, size, input_channels, kt, kh, kw, seed, epsilon
283
workspace.ResetWorkspace()
284
net = core.Net("net")
292
# Presumably an argument of the missing Conv call (per-axis kernel sizes).
kernels=[kt, kh, kw],
295
# NOTE(review): only the input list of the second op survives (original
# lines 293-294 and 296-301 are missing).
["Y", "scale", "bias", "mean", "var"],
302
tu.randBlobFloat32("X", 1, c, t, h, w)
303
tu.randBlobFloat32("w", c, c, kt, kh, kw)
304
tu.randBlobsFloat32(["b", "scale", "bias", "mean"], c)
307
# Presumably the offset keeps "var" away from zero. TODO confirm.
tu.randBlobFloat32("var", c, offset=0.5)
308
workspace.RunNetOnce(net)
309
preTransformOutput = workspace.FetchBlob("Y2").flatten()
310
# Overwrite Y2 so the post-transform fetch cannot see the pre-transform value.
workspace.FeedBlob("Y2", np.zeros((1, 1)))
311
transformer.FuseConvBN(net)
314
# The transformed net must consist of a single op.
assert tu.numOps(net) == 1
315
workspace.RunNetOnce(net)
316
postTransformOutput = workspace.FetchBlob("Y2").flatten()
# NOTE(review): no comparison of preTransformOutput vs postTransformOutput
# survives in this copy; the check appears to be among the missing lines.
325
def test_converterDontEnforceUnusedInputs(self):
    """AddNNPACK must not reject a net that declares an external input
    ("fake") which no op reads; passing means the call does not raise."""
    relu_net = core.Net("net")
    relu_net.Relu(["X"], ["Y"])
    unused_inputs = ["fake"]
    relu_net.Proto().external_input.extend(unused_inputs)
    transformer.AddNNPACK(relu_net)
332
def test_converterDontEnforceUnusedOutputs(self):
    """AddNNPACK must not reject a net that declares an external output
    ("fake") which no op produces; passing means the call does not raise."""
    relu_net = core.Net("net")
    relu_net.Relu(["X"], ["Y"])
    unused_outputs = ["fake"]
    relu_net.Proto().external_output.extend(unused_outputs)
    transformer.AddNNPACK(relu_net)