# pytorch · Форк 0 / test_sls_4bit_nnpi_fp16.py · 215 строк · 7.6 Кб
1
import numpy as np
2
import unittest
3

4
# Must happen before importing caffe2.python.*
5
import caffe2.python.fakelowp.init_shared_libs  # noqa
6

7
from hypothesis import given, settings
8
from hypothesis import strategies as st
9
from caffe2.proto import caffe2_pb2
10
from caffe2.python import core, workspace
11
from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net
12
from caffe2.python.fakelowp.test_utils import print_test_debug_info
13
import caffe2.python.serialized_test.serialized_test_util as serial
14
import datetime
15

16
# Process-wide Caffe2/Glow flags for the onnxifi path used by the tests below.
# Judging by their names they enable fp16 execution globally, fp16 fused
# scale/offset, and forced fp16 accumulation for SLS ops -- NOTE(review):
# confirm against Glow's flag definitions.
workspace.GlobalInit(["caffe2", "--glow_global_fp16=1",
                      "--glow_global_fused_scale_offset_fp16=1",
                      "--glow_global_force_sls_fp16_accum=1"])
19

20

21
class SparseLengthsSum4BitFakeNNPIFp16Test(serial.SerializedTestCase):
    """Compare Glow's fused 4-bit SparseLengthsWeightedSum with a reference.

    Each test quantizes a float table with FloatToFused4BitRowwiseQuantized,
    runs SparseLengthsWeightedSumFused4BitRowwise through onnxifi (Glow) and
    SparseLengthsWeightedSumFused4BitRowwiseFakeFP16NNPI as a plain Caffe2
    net on the same inputs, and fails if the two "Y" outputs are not
    np.allclose.
    """

    def _make_net(self, name, op_type):
        """Return a one-op NetDef named `name` running `op_type` -> "Y"."""
        inputs = ["quantized_data", "weights", "indices", "lengths"]
        net = caffe2_pb2.NetDef()
        net.name = name
        net.external_input.extend(inputs)
        net.external_output.append("Y")
        net.op.add().CopyFrom(core.CreateOperator(op_type, inputs, ["Y"]))
        return net

    def _run_glow_and_reference(
        self, data, indices, lengths, weights, max_batch_size, max_seq_size
    ):
        """Quantize `data`, run the Glow and reference nets; return outputs.

        Returns:
            (Y_glow, Y_c2): the "Y" blob fetched after the onnxified net
            and after the reference net, respectively.

        The step order is deliberate: "quantized_data" exists in the
        workspace before onnxifi_caffe2_net is called, while indices,
        lengths and weights are fed only afterwards. NOTE(review):
        presumably this keeps onnxifi from treating the per-call inputs as
        workspace-resident weights -- confirm before reordering.
        """
        workspace.FeedBlob("data", data)
        workspace.RunOperatorOnce(
            core.CreateOperator(
                "FloatToFused4BitRowwiseQuantized",
                ["data"],
                ["quantized_data"],
            )
        )

        pred_net = self._make_net(
            "pred", "SparseLengthsWeightedSumFused4BitRowwise")
        ref_net = self._make_net(
            "ref", "SparseLengthsWeightedSumFused4BitRowwiseFakeFP16NNPI")

        pred_net_onnxified = onnxifi_caffe2_net(
            pred_net,
            {},
            max_batch_size=max_batch_size,
            max_seq_size=max_seq_size,
            debug=True,
            adjust_batch=True,
            use_onnx=False,
        )
        # Expect exactly one Onnxifi op in the lowered net.
        num_onnxified_ops = sum(
            1 for o in pred_net_onnxified.op if o.type == "Onnxifi")
        np.testing.assert_equal(num_onnxified_ops, 1)

        workspace.FeedBlob("indices", indices)
        workspace.FeedBlob("lengths", lengths)
        workspace.FeedBlob("weights", weights)

        workspace.CreateNet(pred_net_onnxified)
        workspace.CreateNet(ref_net)

        # Both nets write the same "Y" blob, so fetch each result right
        # after its own run.
        workspace.RunNet(pred_net_onnxified.name)
        Y_glow = workspace.FetchBlob("Y")
        workspace.RunNet(ref_net.name)
        Y_c2 = workspace.FetchBlob("Y")
        return Y_glow, Y_c2

    @given(seed=st.integers(0, 65535))
    @settings(deadline=datetime.timedelta(seconds=10))
    def test_slws_fused_4bit_rowwise_all_same(self, seed):
        """Single constant row, every index 0, every weight 1."""
        np.random.seed(seed)
        workspace.ResetWorkspace()
        n = 1
        m = 2
        # Constant n x m table: every entry is 1 * 0.2 - 0.1.
        data = np.ones((n, m)).astype(np.float32) * 0.2 - 0.1
        max_segments = 5
        max_segment_length = 100
        # Number of segments to run.
        num_lengths = np.random.randint(1, max_segments + 1)
        lengths = np.random.randint(0, max_segment_length + 1,
                                    size=num_lengths).astype(np.int32)
        num_indices = np.sum(lengths)
        # The table has one row, so every lookup hits row 0 with weight 1.
        indices = np.zeros(num_indices, dtype=np.int64)
        weights = np.ones(len(indices)).astype(np.float32)

        Y_glow, Y_c2 = self._run_glow_and_reference(
            data, indices, lengths, weights,
            max_batch_size=max_segments,
            max_seq_size=max_segment_length,
        )

        if not np.allclose(Y_c2, Y_glow):
            print_test_debug_info(
                "slws_fused_4bit_rowwise",
                {"seed": seed,
                 "indices": indices,
                 "data": data,
                 "quantized_data": workspace.FetchBlob("quantized_data"),
                 "lengths": lengths,
                 "weights": weights,
                 "Y_c2": Y_c2,
                 "Y_glow": Y_glow,
                 "diff": Y_glow - Y_c2,
                 "rowwise_diff": (Y_glow - Y_c2)[:, 0]})
            # self.fail raises AssertionError even under `python -O`,
            # unlike a bare `assert`.
            self.fail("Glow output differs from the fake-NNPI fp16 reference")

    @given(
        seed=st.integers(0, 65535),
        num_rows=st.integers(2, 20),
        embedding_dim=st.sampled_from([8, 12, 16, 24, 32, 54, 64, 72, 128]),
        batch_size=st.integers(1, 32),
        max_weight=st.integers(0, 1),
    )
    @settings(deadline=datetime.timedelta(seconds=10))
    def test_slws_fused_4bit_rowwise(self, seed, num_rows, embedding_dim, batch_size, max_weight):
        """Random table, random non-empty segments, random weights."""
        workspace.ResetWorkspace()
        np.random.seed(seed)
        data = np.random.rand(num_rows, embedding_dim).astype(np.float32)
        data = data * 1e-3

        # Segment lengths come from [1, num_rows - 1]: no segment is empty.
        lengths = np.random.choice(np.arange(1, num_rows), batch_size).astype(np.int32)
        # Indices cover every row in [0, num_rows), including row 0.
        indices = np.concatenate(
            [np.random.choice(num_rows, length) for length in lengths]
        ).astype(np.int64)

        # Uniform in [-max_weight / 2, max_weight / 2]; all zeros when
        # max_weight == 0.
        weights = np.random.uniform(
            low=0,
            high=max_weight,
            size=[len(indices)]
        ).astype(np.float32) - max_weight / 2.0

        Y_glow, Y_c2 = self._run_glow_and_reference(
            data, indices, lengths, weights,
            max_batch_size=batch_size,
            # Plain Python int, matching the other test, not a NumPy scalar.
            max_seq_size=int(np.max(lengths)),
        )

        if not np.allclose(Y_c2, Y_glow):
            print_test_debug_info(
                "slws_fused_4bit_rowwise",
                {
                    "seed": seed,
                    "indices": indices,
                    "data": data.shape,
                    "lengths": lengths,
                    "weights": weights,
                    "Y_c2": Y_c2.shape,
                    "Y_glow": Y_glow.shape,
                    "diff": Y_glow - Y_c2,
                    "rowwise_diff": (Y_glow - Y_c2)[:, 0]
                }
            )
            # self.fail raises AssertionError even under `python -O`.
            self.fail("Glow output differs from the fake-NNPI fp16 reference")
213

214
# Allow running this file directly as a unittest script.
if __name__ == '__main__':
    unittest.main()
216

# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.