pytorch

Форк
0
/
transformations_test.py 
336 строк · 11.7 Кб
1
# Copyright (c) 2016-present, Facebook, Inc.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
##############################################################################
15

16

17

18

19

20

21
from hypothesis import given
22
import hypothesis.strategies as st
23
import numpy as np
24
import unittest
25

26
from caffe2.python.transformations import Transformer
27
from caffe2.python import core, workspace
28
from caffe2.python import test_util as tu
29

30
transformer = Transformer()
31

32

33
class TestTransformations(tu.TestCase):
34
    def _base_test_net(self):
35
        net = core.Net("net")
36
        net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
37
        return net
38

39
    def _add_nnpack(self, net):
40
        transformer.AddNNPACK(net)
41
        assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")
42

43
    def _fuse_nnpack_convrelu(self, net, expected_result_num_ops,
44
    expected_activation_arg=True):
45
        self._add_nnpack(net)
46
        transformer.FuseNNPACKConvRelu(net)
47
        self.assertEqual(tu.numOps(net), expected_result_num_ops)
48
        has_activation_arg = False
49
        for arg in net.Proto().op[0].arg:
50
            if tu.str_compare(arg.name, "activation"):
51
                assert tu.str_compare(arg.s, "Relu")
52
                has_activation_arg = True
53
        if expected_activation_arg:
54
            assert has_activation_arg
55
        else:
56
            assert not has_activation_arg
57

58
    def test_transformer_AddNNPACK(self):
59
        net = self._base_test_net()
60
        net.Relu(["Y"], ["Y2"])
61
        self._add_nnpack(net)
62

63
    def test_transformer_FuseNNPACKConvRelu(self):
64
        net = self._base_test_net()
65
        net.Relu(["Y"], ["Y2"])
66
        self._fuse_nnpack_convrelu(net, 1)
67

68
    def test_noFuseNNPACKConvRelu(self):
69
        net = self._base_test_net()
70
        net.Relu(["Y"], ["Y2"])
71
        net.Relu(["Y"], ["Y3"])
72
        self._fuse_nnpack_convrelu(net, 3, expected_activation_arg=False)
73

74
    def test_transformer_FuseNNPACKConvReluNoInplace(self):
75
        net = self._base_test_net()
76
        net.Relu(["Y"], ["X"])
77
        self._fuse_nnpack_convrelu(net, 1)
78
        assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
79

80
    def test_transformer_FuseNNPACKConvReluInplaceRelu(self):
81
        net = self._base_test_net()
82
        net.Relu(["Y"], ["Y"])
83
        self._fuse_nnpack_convrelu(net, 1)
84
        assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
85

86
    def test_transformer_FuseNNPACKConvReluPingPongNaming(self):
87
        net = self._base_test_net()
88
        net.Relu(["Y"], ["X"])
89
        net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
90
        self._fuse_nnpack_convrelu(net, 2)
91
        assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
92
        assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]
93

94
    def test_transformer_FuseNNPACKConvReluFollowedByMultipleInputOp(self):
95
        net = self._base_test_net()
96
        net.Relu(["Y"], ["Y2"])
97
        net.Conv(["Y2", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
98
        net.Relu(["Y"], ["Y2"])
99
        self._fuse_nnpack_convrelu(net, 2)
100
        assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
101
        assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]
102

103
    def test_transformer_FuseNNPACKConvReluInplaceFollowedByMultipleInputOp(self):
104
        net = self._base_test_net()
105
        net.Relu(["Y"], ["Y"])
106
        net.Conv(["Y", "w", "b"], ["Y2"], stride=1, pad=0, kernel=3, order="NCHW")
107
        net.Relu(["Y2"], ["Y2"])
108
        self._fuse_nnpack_convrelu(net, 2)
109
        assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
110
        assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]
111

112
    @given(
113
        size=st.integers(7, 10),
114
        input_channels=st.integers(1, 10),
115
        seed=st.integers(0, 65535),
116
        order=st.sampled_from(["NCHW", "NHWC"]),
117
        epsilon=st.floats(min_value=1e-5, max_value=1e-2),
118
    )
119
    def test_transformer_FuseConvBN(self, size, input_channels, seed, order, epsilon):
120
        workspace.ResetWorkspace()
121
        net = core.Net("net")
122
        c = input_channels
123
        h = size
124
        w = size
125
        k = 3
126
        net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=k, order=order)
127
        net.SpatialBN(
128
            ["Y", "scale", "bias", "mean", "var"],
129
            ["Y2"],
130
            is_test=True,
131
            order=order,
132
            epsilon=epsilon,
133
        )
134

135
        np.random.seed(seed)
136
        if order == "NCHW":
137
            tu.randBlobFloat32("X", 1, c, h, w)
138
            tu.randBlobFloat32("w", c, c, k, k)
139
        else:
140
            tu.randBlobFloat32("X", 1, h, w, c)
141
            tu.randBlobFloat32("w", c, k, k, c)
142
        tu.randBlobsFloat32(["b", "scale", "bias", "mean"], c)
143

144
        # This is necessary because 1/sqrt(var) is used and if var is too small
145
        # we get floating point artifacts that cause test failures
146
        tu.randBlobFloat32("var", c, offset=0.5)
147
        workspace.RunNetOnce(net)
148
        preTransformOutput = workspace.FetchBlob("Y2").flatten()
149
        workspace.FeedBlob("Y2", np.zeros((1, 1)))
150
        transformer.FuseConvBN(net)
151

152
        # Ensure fusion
153
        assert tu.numOps(net) == 1
154
        workspace.RunNetOnce(net)
155
        postTransformOutput = workspace.FetchBlob("Y2").flatten()
156
        # Check that there is no numerical difference
157
        assert np.allclose(
158
            preTransformOutput,
159
            postTransformOutput,
160
            rtol=5e-02,
161
            atol=1e-03
162
        )
163

164
    @unittest.skip("Test is flaky")
165
    @given(
166
        size=st.integers(7, 10),
167
        input_channels=st.integers(1, 10),
168
        seed=st.integers(0, 65535),
169
        order=st.sampled_from(["NCHW", "NHWC"]),
170
        epsilon=st.floats(min_value=1e-5, max_value=1e-2),
171
    )
172
    def test_transformer_FuseConvBNNoConvBias(self, size, input_channels, seed, order, epsilon):
173
        workspace.ResetWorkspace()
174
        net = core.Net("net")
175
        c = input_channels
176
        h = size
177
        w = size
178
        k = 3
179
        net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=k, order=order)
180
        net.SpatialBN(
181
            ["Y", "scale", "bias", "mean", "var"],
182
            ["Y2"],
183
            is_test=True,
184
            order=order,
185
            epsilon=epsilon,
186
        )
187

188
        np.random.seed(seed)
189
        if order == "NCHW":
190
            tu.randBlobFloat32("X", 1, c, h, w)
191
            tu.randBlobFloat32("w", c, c, k, k)
192
        else:
193
            tu.randBlobFloat32("X", 1, h, w, c)
194
            tu.randBlobFloat32("w", c, k, k, c)
195
        tu.randBlobsFloat32(["scale", "bias", "mean"], c)
196
        # This is necessary because 1/sqrt(var) is used and if var is too small
197
        # we get floating point artifacts that cause test failures
198
        tu.randBlobFloat32("var", c, offset=0.5)
199
        workspace.RunNetOnce(net)
200
        preTransformOutput = workspace.FetchBlob("Y2").flatten()
201
        workspace.FeedBlob("Y2", np.zeros((1, 1)))
202
        transformer.FuseConvBN(net)
203

204
        # Ensure fusion
205
        assert tu.numOps(net) == 1
206
        workspace.RunNetOnce(net)
207
        postTransformOutput = workspace.FetchBlob("Y2").flatten()
208
        # Check that there is no numerical difference
209
        assert np.allclose(
210
            preTransformOutput,
211
            postTransformOutput,
212
            rtol=5e-02,
213
            atol=1e-03
214
        )
215

216
    @given(
217
        size=st.integers(7, 10),
218
        input_channels=st.integers(1, 10),
219
        seed=st.integers(0, 65535),
220
        order=st.sampled_from(["NCHW", "NHWC"]),
221
        epsilon=st.floats(min_value=1e-5, max_value=1e-2),
222
    )
223
    def test_transformer_FuseConvBNNoConvBiasDuplicatedName(self, size, input_channels, seed, order, epsilon):
224
        workspace.ResetWorkspace()
225
        net = core.Net("net")
226
        c = input_channels
227
        h = size
228
        w = size
229
        k = 3
230
        net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=k, order=order)
231
        net.SpatialBN(
232
            ["Y", "scale", "_bias0", "mean", "var"],
233
            ["Y2"],
234
            is_test=True,
235
            order=order,
236
            epsilon=epsilon,
237
        )
238

239
        np.random.seed(seed)
240
        if order == "NCHW":
241
            tu.randBlobFloat32("X", 1, c, h, w)
242
            tu.randBlobFloat32("w", c, c, k, k)
243
        else:
244
            tu.randBlobFloat32("X", 1, h, w, c)
245
            tu.randBlobFloat32("w", c, k, k, c)
246
        tu.randBlobsFloat32(["scale", "_bias0", "mean"], c)
247
        # This is necessary because 1/sqrt(var) is used and if var is too small
248
        # we get floating point artifacts that cause test failures
249
        tu.randBlobFloat32("var", c, offset=0.5)
250
        workspace.RunNetOnce(net)
251
        preTransformOutput = workspace.FetchBlob("Y2").flatten()
252
        workspace.FeedBlob("Y2", np.zeros((1, 1)))
253
        transformer.FuseConvBN(net)
254

255
        # Ensure fusion
256
        assert tu.numOps(net) == 1
257
        workspace.RunNetOnce(net)
258
        postTransformOutput = workspace.FetchBlob("Y2").flatten()
259
        print("pre")
260
        print(preTransformOutput)
261
        print("after")
262
        print(postTransformOutput)
263
        # Check that there is no numerical difference
264
        assert np.allclose(
265
            preTransformOutput,
266
            postTransformOutput,
267
            rtol=5e-02,
268
            atol=1e-03
269
        )
270

271
    @given(
272
        size=st.integers(7, 10),
273
        input_channels=st.integers(1, 10),
274
        kt=st.integers(3, 5),
275
        kh=st.integers(3, 5),
276
        kw=st.integers(3, 5),
277
        seed=st.integers(0, 65535),
278
        epsilon=st.floats(min_value=1e-5, max_value=1e-2),
279
    )
280
    def test_transformer_FuseConv3DBN(
281
        self, size, input_channels, kt, kh, kw, seed, epsilon
282
    ):
283
        workspace.ResetWorkspace()
284
        net = core.Net("net")
285
        c = input_channels
286
        t = size
287
        h = size
288
        w = size
289
        net.Conv(
290
            ["X", "w", "b"],
291
            ["Y"],
292
            kernels=[kt, kh, kw],
293
        )
294
        net.SpatialBN(
295
            ["Y", "scale", "bias", "mean", "var"],
296
            ["Y2"],
297
            is_test=True,
298
            epsilon=epsilon,
299
        )
300

301
        np.random.seed(seed)
302
        tu.randBlobFloat32("X", 1, c, t, h, w)
303
        tu.randBlobFloat32("w", c, c, kt, kh, kw)
304
        tu.randBlobsFloat32(["b", "scale", "bias", "mean"], c)
305
        # This is necessary because 1/sqrt(var) is used and if var is too small
306
        # we get floating point artifacts that cause test failures
307
        tu.randBlobFloat32("var", c, offset=0.5)
308
        workspace.RunNetOnce(net)
309
        preTransformOutput = workspace.FetchBlob("Y2").flatten()
310
        workspace.FeedBlob("Y2", np.zeros((1, 1)))
311
        transformer.FuseConvBN(net)
312

313
        # Ensure fusion
314
        assert tu.numOps(net) == 1
315
        workspace.RunNetOnce(net)
316
        postTransformOutput = workspace.FetchBlob("Y2").flatten()
317
        # Check that there is no numerical difference
318
        assert np.allclose(
319
            preTransformOutput,
320
            postTransformOutput,
321
            rtol=1e-02,
322
            atol=1e-04
323
        )
324

325
    def test_converterDontEnforceUnusedInputs(self):
326
        net = core.Net("net")
327
        net.Relu(["X"], ["Y"])
328
        net.Proto().external_input.extend(["fake"])
329
        # This should now work
330
        transformer.AddNNPACK(net)  # just testing the converter
331

332
    def test_converterDontEnforceUnusedOutputs(self):
333
        net = core.Net("net")
334
        net.Relu(["X"], ["Y"])
335
        net.Proto().external_output.extend(["fake"])
336
        transformer.AddNNPACK(net)  # just testing the converter
337

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.