# pytorch — test_int8_ops_nnpi.py (322 lines, 10.2 KB)
# NOTE: this header line was recovered from scraped repository-page chrome
# ("pytorch / Fork / 0 / test_int8_ops_nnpi.py") and is not program code.
1
import caffe2.python.fakelowp.init_shared_libs  # noqa
2
import numpy as np
3
from caffe2.python import core, workspace
4
from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net
5
from hypothesis import given, strategies as st, settings
6
from caffe2.python.fakelowp.test_utils import print_test_debug_info
7
import caffe2.python.serialized_test.serialized_test_util as serial
8
import datetime
9

10
core.GlobalInit(["caffe2",
11
                 "--caffe2_log_level=-3",
12
                 "--glow_global_fp16=1",
13
                 "--glow_clip_quant_range_to_fp16=1",
14
                 "--glow_global_fp16_constants=1"
15
                 ])
16

17

18
class Int8OpsTest(serial.SerializedTestCase):
    """End-to-end checks of int8 FC on NNPI (Glow) against the fbgemm reference.

    Every test builds the same three-op net — quantize -> int8 FC ->
    dequantize — runs it once with the NNPI fake-lowp reference operators
    and once lowered through ONNXIFI, and requires both outputs to match
    exactly (``np.allclose``).
    """

    def _get_scale_zp(self, tensor):
        """Return a uint8 affine quantization ``(scale, zero_point)`` for *tensor*.

        The scale is rounded through fp16 (Glow keeps constants in fp16)
        and clamped below at 1e-6 so the zero-point division is safe; the
        zero point is clipped into [0, 255].
        """
        tensor_max = np.max(tensor)
        # Anchor the range at 0 so 0.0 stays exactly representable.
        tensor_min = min(0, np.min(tensor))
        scale = np.float32(np.float16((tensor_max - tensor_min) / 255.0))
        if scale < 1e-6:
            scale = np.float32(1e-6)
        zero_point = 0 - tensor_min / scale
        zero_point = int(round(np.clip(zero_point, 0, 255.0)))
        return (scale, zero_point)

    def _pack_weight(self, in_scale, quantize_bias=False):
        """Pack workspace blob "W" (and "b" when *quantize_bias*) for DNNLOWP.

        Produces "W_int8" (and "b_int32") in the workspace.
        """
        workspace.RunOperatorOnce(
            core.CreateOperator(
                "Int8FCPackWeight",
                ["W", "b"] if quantize_bias else ["W"],
                ["W_int8", "b_int32"] if quantize_bias else ["W_int8"],
                engine="DNNLOWP",
                save_unpacked_weights=True,
                in_scale=in_scale,
            )
        )

    def _build_ref_net(
        self, X_scale, X_zero_point, Y_scale, Y_zero_point, bias_name="b"
    ):
        """Build the quantize -> int8 FC -> dequantize reference net."""
        ref_net = core.Net("net")
        ref_net.Int8QuantizeNNPI(
            ["X"],
            ["X_int8"],
            Y_scale=X_scale,
            Y_zero_point=X_zero_point,
        )
        ref_net.Int8FCFakeAcc32NNPI(
            ["X_int8", "W_int8", bias_name],
            ["Y_int8"],
            Y_scale=Y_scale,
            Y_zero_point=Y_zero_point,
        )
        ref_net.Int8DequantizeNNPI(
            ["Y_int8"],
            ["Y"],
        )
        ref_net.Proto().external_output.append("Y")
        return ref_net

    def _run_onnxified(self, ref_net, weight_names):
        """Lower *ref_net* through ONNXIFI, run it, and return blob "Y".

        Rewrites the three NNPI op types to their generic Int8 equivalents
        in place, so the reference run must happen before this is called.
        """
        ref_net.Proto().op[0].type = "Int8Quantize"
        ref_net.Proto().op[1].type = "Int8FC"
        ref_net.Proto().op[2].type = "Int8Dequantize"
        net_onnxified = onnxifi_caffe2_net(
            ref_net.Proto(),
            {},
            debug=True,
            adjust_batch=False,
            use_onnx=False,
            weight_names=weight_names,
        )
        # The whole three-op net must collapse into exactly one Onnxifi op.
        num_onnxified_ops = sum(
            1 if o.type == "Onnxifi" else 0 for o in net_onnxified.op
        )
        np.testing.assert_equal(num_onnxified_ops, 1)
        workspace.CreateNet(net_onnxified)
        workspace.RunNet(net_onnxified.name)
        return workspace.FetchBlob("Y")

    def _check_results(self, Y_glow, Y_fbgemm, debug_info):
        """Fail with full debug output unless the two results match."""
        if not np.allclose(Y_glow, Y_fbgemm):
            diff_Y = np.abs(Y_glow - Y_fbgemm)
            info = dict(debug_info)
            info.update(
                {
                    "Y_fbgemm": Y_fbgemm,
                    "Y_glow": Y_glow,
                    "diff": diff_Y,
                    "maxdiff": diff_Y.max(axis=1),
                }
            )
            print_test_debug_info("int8_fc", info)
            # Raise explicitly: a bare `assert 0` is stripped under -O.
            raise AssertionError("Glow and fbgemm int8 results differ")

    @given(
        n=st.integers(2, 1024),
        rand_seed=st.integers(0, 65534),
        non_zero_offset=st.booleans()
    )
    @settings(deadline=datetime.timedelta(seconds=50))
    def test_int8_quantize(self, n, rand_seed, non_zero_offset):
        """Round-trip X through quantize -> identity FC -> dequantize."""
        # Print the full hypothesis example so failures are reproducible
        # (the original omitted non_zero_offset).
        print("n={}, rand_seed={}, non_zero_offset={}".format(
            n, rand_seed, non_zero_offset))
        np.random.seed(rand_seed)
        workspace.ResetWorkspace()

        # Route X through fp16 so every value is exactly fp16-representable.
        if non_zero_offset:
            X_fp32 = np.random.uniform(-1, 1, size=(n, n)).astype(np.float16) \
                .astype(np.float32)
        else:
            X_fp32 = np.random.rand(n, n).astype(np.float16).astype(np.float32)

        W_fp32 = np.identity(n, dtype=np.float32)
        b_fp32 = np.zeros((n,), dtype=np.float32)

        X_scale, X_zero_point = self._get_scale_zp(X_fp32)

        workspace.FeedBlob("X", X_fp32)
        workspace.FeedBlob("W", W_fp32)
        workspace.FeedBlob("b", b_fp32)

        self._pack_weight(in_scale=X_scale)

        # Identity weight and zero bias: Y shares X's quantization params.
        ref_net = self._build_ref_net(
            X_scale, X_zero_point, X_scale, X_zero_point
        )
        workspace.RunNetOnce(ref_net)
        Y_fbgemm = workspace.FetchBlob("Y")

        Y_glow = self._run_onnxified(ref_net, weight_names=["W_int8", "b"])
        self._check_results(
            Y_glow,
            Y_fbgemm,
            {
                "seed": rand_seed,
                "n": n,
                "X": X_fp32,
                "W": W_fp32,
                "b": b_fp32,
            },
        )

    @given(
        n=st.integers(1, 1024),
        m=st.integers(1, 1024),
        k=st.integers(1, 1024),
        f=st.integers(1, 1),  # TODO: figure a safe number to increase
        rand_seed=st.integers(0, 65534),
        quantize_bias=st.sampled_from([False]),
    )
    @settings(deadline=datetime.timedelta(seconds=50))
    def test_int8_fc(
        self, n, m, k, rand_seed, quantize_bias, f
    ):
        """Random (m, k) x (n, k)^T int8 FC with per-tensor quantization."""
        print(
            f"n={n}, m={m}, k={k}, rand_seed={rand_seed}, quantize_bias={quantize_bias}"
        )
        np.random.seed(rand_seed)
        workspace.ResetWorkspace()

        ff = float(f)
        X_fp32 = np.random.uniform(-ff, ff, size=(m, k)).astype(np.float32)
        W_fp32 = np.random.uniform(-ff, ff, size=(n, k)).astype(np.float32)
        b_fp32 = np.random.uniform(-ff, ff, size=(n)).astype(np.float32)

        X_scale, X_zero_point = self._get_scale_zp(X_fp32)
        # Output quantization params are derived from the fp32 result.
        Y_fp32 = np.dot(X_fp32, W_fp32.T) + b_fp32
        Y_scale, Y_zero_point = self._get_scale_zp(Y_fp32)

        workspace.FeedBlob("X", X_fp32)
        workspace.FeedBlob("W", W_fp32)
        workspace.FeedBlob("b", b_fp32)

        self._pack_weight(in_scale=X_scale, quantize_bias=quantize_bias)

        bias_name = "b_int32" if quantize_bias else "b"
        ref_net = self._build_ref_net(
            X_scale, X_zero_point, Y_scale, Y_zero_point, bias_name=bias_name
        )
        workspace.RunNetOnce(ref_net)
        Y_fbgemm = workspace.FetchBlob("Y")

        Y_glow = self._run_onnxified(
            ref_net, weight_names=["W_int8", bias_name]
        )
        self._check_results(
            Y_glow,
            Y_fbgemm,
            {
                "seed": rand_seed,
                "n": n,
                "m": m,
                "k": k,
                "X": X_fp32,
                "W": W_fp32,
                "b": b_fp32,
            },
        )

    @given(
        n=st.integers(1, 4),
        rand_seed=st.integers(0, 65534)
    )
    @settings(deadline=datetime.timedelta(seconds=10))
    def test_int8_small_input(self, n, rand_seed):
        """Small positive inputs through an identity FC round-trip."""
        print("n={}, rand_seed={}".format(n, rand_seed))
        np.random.seed(rand_seed)
        workspace.ResetWorkspace()

        X_fp32 = np.random.uniform(0.01, 0.03, size=(n, n)).astype(np.float32)
        W_fp32 = np.identity(n, dtype=np.float32)
        b_fp32 = np.zeros((n,), dtype=np.float32)

        X_scale, X_zero_point = self._get_scale_zp(X_fp32)

        workspace.FeedBlob("X", X_fp32)
        workspace.FeedBlob("W", W_fp32)
        workspace.FeedBlob("b", b_fp32)

        self._pack_weight(in_scale=X_scale)

        # Identity weight and zero bias: Y shares X's quantization params.
        ref_net = self._build_ref_net(
            X_scale, X_zero_point, X_scale, X_zero_point
        )
        workspace.RunNetOnce(ref_net)
        Y_fbgemm = workspace.FetchBlob("Y")

        Y_glow = self._run_onnxified(ref_net, weight_names=["W_int8", "b"])
        self._check_results(
            Y_glow,
            Y_fbgemm,
            {
                "seed": rand_seed,
                "n": n,
                "X": X_fp32,
                "W": W_fp32,
                "b": b_fp32,
            },
        )
# NOTE: the scraped page ended with a GitVerse cookie-consent banner
# ("We use cookies in accordance with the Privacy Policy and the Cookie
# Policy...") — web page chrome, not part of this source file.