pytorch

Форк
0
/
test_sls_8bit_nnpi_fp32.py 
264 строки · 8.5 Кб
1
import unittest
2

3
# Must happen before importing caffe2.python.*
4
import caffe2.python.fakelowp.init_shared_libs  # noqa
5
import datetime
6
import numpy as np
7
from hypothesis import given, settings
8
from hypothesis import strategies as st
9
from caffe2.proto import caffe2_pb2
10
from caffe2.python import core, workspace
11
from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net
12
from caffe2.python.fakelowp.test_utils import print_test_debug_info
13
import caffe2.python.serialized_test.serialized_test_util as serial
14

15
# Turn off Glow's FP16 options (global FP16, FP16 fused scale/offset, and
# forced FP16 SLS accumulation) so the lowered net presumably runs in FP32,
# like the FakeFP32NNPI reference. Each test below re-applies these flags.
workspace.GlobalInit(
    [
        "caffe2",
        "--glow_global_fp16=0",
        "--glow_global_fused_scale_offset_fp16=0",
        "--glow_global_force_sls_fp16_accum=0",
    ]
)

# Absolute/relative tolerances for approximate Glow comparisons.
# NOTE(review): not referenced by the tests in this file, which require exact
# equality -- confirm whether other modules import them before removing.
GLOW_MATMUL_ATOL = 1e-5
GLOW_MATMUL_RTOL = 1e-3


class SparseLengthsSum8BitFakeNNPIFp32Test(serial.SerializedTestCase):
    """Compare onnxifi-lowered fused 8-bit rowwise SLWS with the FP32 NNPI fake op.

    Every test quantizes a random float table with
    ``FloatToFused8BitRowwiseQuantized``. It then runs
    ``SparseLengthsWeightedSumFused8BitRowwise`` through ``onnxifi_caffe2_net``
    and ``SparseLengthsWeightedSumFused8BitRowwiseFakeFP32NNPI`` as a plain
    Caffe2 net. The test fails if any element of the two ``Y`` outputs differs.
    """

    # Arguments for workspace.GlobalInit that switch off Glow's FP16 options.
    _GLOBAL_INIT_ARGS = (
        "caffe2",
        "--glow_global_fp16=0",
        "--glow_global_fused_scale_offset_fp16=0",
        "--glow_global_force_sls_fp16_accum=0",
    )

    # External inputs shared by the lowered net and the reference net.
    _NET_INPUTS = ("quantized_data", "weights", "indices", "lengths")

    @classmethod
    def _global_init(cls):
        """Re-apply the FP32 Glow flags before a test case runs."""
        # Pass a fresh list so GlobalInit can never mutate the class constant.
        workspace.GlobalInit(list(cls._GLOBAL_INIT_ARGS))

    @classmethod
    def _make_net(cls, name, op_type):
        """Return a one-operator NetDef called ``name`` that writes blob ``Y``."""
        net = caffe2_pb2.NetDef()
        net.name = name
        net.external_input.extend(cls._NET_INPUTS)
        net.external_output.append("Y")
        net.op.add().CopyFrom(
            core.CreateOperator(op_type, list(cls._NET_INPUTS), ["Y"])
        )
        return net

    @staticmethod
    def _relative_diff(y_ref, y_glow):
        """Elementwise ``|y_ref - y_glow| / (y_ref + 1e-8)``.

        The epsilon avoids 0/0 when both outputs are zero.
        """
        return np.abs((y_ref - y_glow) / (y_ref + 1e-8))

    def _run_glow_and_ref(
        self, data, indices, lengths, weights, max_batch_size, max_seq_size
    ):
        """Quantize ``data``, then run the onnxified net and the reference net.

        Args:
            data: float32 table fed as blob ``data`` and quantized into
                ``quantized_data``.
            indices, lengths, weights: SLWS inputs fed under the same names.
            max_batch_size, max_seq_size: forwarded to ``onnxifi_caffe2_net``.

        Returns:
            ``(quantized_data, Y_glow, Y_ref)`` fetched from the workspace.

        Raises:
            AssertionError: if onnxifi did not produce exactly one Onnxifi op.
        """
        pred_net = self._make_net(
            "pred", "SparseLengthsWeightedSumFused8BitRowwise"
        )
        ref_net = self._make_net(
            "ref", "SparseLengthsWeightedSumFused8BitRowwiseFakeFP32NNPI"
        )

        workspace.FeedBlob("data", data)
        workspace.RunOperatorOnce(
            core.CreateOperator(
                "FloatToFused8BitRowwiseQuantized", ["data"], ["quantized_data"]
            )
        )
        quantized_data = workspace.FetchBlob("quantized_data")

        onnxified_net = onnxifi_caffe2_net(
            pred_net,
            {},
            max_batch_size=max_batch_size,
            max_seq_size=max_seq_size,
            debug=True,
            adjust_batch=True,
            use_onnx=False,
        )
        # Expect onnxifi to have produced exactly one Onnxifi op.
        num_onnxified_ops = sum(op.type == "Onnxifi" for op in onnxified_net.op)
        self.assertEqual(num_onnxified_ops, 1)

        workspace.FeedBlob("indices", indices)
        workspace.FeedBlob("lengths", lengths)
        workspace.FeedBlob("weights", weights)

        workspace.CreateNet(onnxified_net)
        workspace.CreateNet(ref_net)

        # Both nets write the same "Y" blob, so fetch right after each run.
        workspace.RunNet(onnxified_net.name)
        y_glow = workspace.FetchBlob("Y")
        workspace.RunNet(ref_net.name)
        y_ref = workspace.FetchBlob("Y")
        return quantized_data, y_glow, y_ref

    @given(
        seed=st.integers(0, 65535),
        num_rows=st.integers(2, 20),
        embedding_dim=st.sampled_from([8, 12, 16, 24, 32, 54, 64, 128]),
        batch_size=st.integers(1, 5),
        max_weight=st.integers(0, 100),
    )
    @settings(deadline=datetime.timedelta(seconds=10))
    def test_slws_fused_8bit_rowwise_acc32_nnpi(
        self, seed, num_rows, embedding_dim, batch_size, max_weight
    ):
        """Random tables and segments must produce bit-identical outputs."""
        self._global_init()
        workspace.ResetWorkspace()
        np.random.seed(seed)

        data = np.random.rand(num_rows, embedding_dim).astype(np.float32)
        # Each of the batch_size segments has a length in [1, num_rows - 1].
        lengths = np.random.choice(
            np.arange(1, num_rows), batch_size
        ).astype(np.int32)
        # Row ids span the whole table, including row 0.
        indices = np.concatenate(
            [np.random.choice(num_rows, length) for length in lengths]
        ).astype(np.int64)
        weights = np.random.uniform(
            low=0, high=max_weight, size=[len(indices)]
        ).astype(np.float32)

        _, Y_glow, Y_ref = self._run_glow_and_ref(
            data,
            indices,
            lengths,
            weights,
            max_batch_size=batch_size,
            max_seq_size=np.max(lengths),
        )

        diff = self._relative_diff(Y_ref, Y_glow)
        max_err = np.max(diff, axis=1)
        num_offenders = int((max_err > 0).sum())
        if num_offenders > 0:
            print_test_debug_info(
                "test_slws_fused_8bit_rowwise_acc32_nnpi",
                {
                    "seed": seed,
                    "num_rows": num_rows,
                    "embedding_dim": embedding_dim,
                    "batch_size": batch_size,
                    "max_weight": max_weight,
                    "indices": indices,
                    "data": data.shape,
                    "lengths": lengths,
                    "weights": weights,
                    "Y_glow": Y_glow,
                    "Y_ref": Y_ref,
                    "diff": diff,
                    "rowwise_diff": max_err,
                },
            )
            self.fail(
                "{} of {} output rows differ from the reference".format(
                    num_offenders, len(max_err)
                )
            )

    @given(seed=st.integers(0, 65535))
    @settings(deadline=datetime.timedelta(seconds=10))
    def test_small_sls_acc32(self, seed):
        """A 2x3 table summed as one segment must match the reference exactly."""
        self._global_init()
        np.random.seed(seed)
        workspace.ResetWorkspace()

        n = 2
        dim = 3
        data = 4 * (np.random.random_sample((n, dim)) + 1).astype(np.float32)

        # One segment that looks up every row exactly once.
        lengths = np.array([n], dtype=np.int32)
        indices = np.arange(n, dtype=np.int64)
        weights = np.random.uniform(
            low=0.01, high=0.5, size=[n]
        ).astype(np.float32)

        quantized_data, Y_glow, Y_ref = self._run_glow_and_ref(
            data, indices, lengths, weights, max_batch_size=1, max_seq_size=n
        )

        diff = self._relative_diff(Y_ref, Y_glow)
        max_err = np.max(diff, axis=1)
        num_offenders = int((max_err > 0).sum())
        if num_offenders > 0:
            # Scoped so the higher print precision does not leak into other tests.
            with np.printoptions(precision=12):
                print(
                    "ref",
                    Y_ref.astype(np.float16).astype(np.float32),
                    "glow",
                    Y_glow.astype(np.float16).astype(np.float32),
                )
                print_test_debug_info(
                    "test_small_sls_acc32",
                    {
                        "seed": seed,
                        "indices": indices,
                        "data": data,
                        "quantized_data": quantized_data,
                        "lengths": lengths,
                        "weights": weights,
                        "Y_glow": Y_glow,
                        "Y_ref": Y_ref,
                        "diff": diff,
                        "rowwise_diff": max_err,
                    },
                )
            self.fail(
                "{} of {} output rows differ from the reference".format(
                    num_offenders, len(max_err)
                )
            )


# Allow running this file directly with the stdlib unittest runner.
if __name__ == '__main__':
    unittest.main()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.