pytorch

Форк
0
/
test_qlinear_packed_params.py 
262 строки · 9.8 Кб
1
#!/usr/bin/env python3
2
# Owner(s): ["oncall: mobile"]
3

4
import tempfile
5
import torch
6
from torch.ao.nn.sparse.quantized.dynamic.linear import Linear
7
from torch.testing._internal.common_quantization import (
8
    skipIfNoFBGEMM,
9
    skipIfNoQNNPACK,
10
)
11
from torch.testing._internal.common_quantized import (
12
    qengine_is_qnnpack,
13
    override_quantized_engine,
14
    override_cpu_allocator_for_qnnpack
15
)
16
from torch.testing._internal.common_utils import TestCase
17

18
class TestQlinearPackedParams(TestCase):
19
    def qlinear_packed_params_test(self, allow_non_zero_zero_points=False):
20
        # copied from https://pytorch.org/docs/stable/sparse.html#csr-tensor-operations,
21
        # so row/col block indices match that example, but with blocks and
22
        # scaled rows
23
        weight_fp32 = torch.Tensor([
24
            [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
25
            [6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
26
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
27
        ])
28

29
        row_block_size = 1
30
        col_block_size = 4
31
        out_features = weight_fp32.shape[0]
32
        in_features = weight_fp32.shape[1]
33

34
        scales = [2.0, 6.0, 12.0]
35
        zero_points = [
36
            ((i + 1) if allow_non_zero_zero_points else 0) for i in range(out_features)
37
        ]
38
        dtype = torch.qint8
39

40
        wide_weight_fp32 = torch.zeros((3, 4008))  # 4000 is tile width for Fbgemm
41
        wide_weight_fp32[0][0] = 4
42
        wide_weight_fp32[0][4004] = 6
43
        wide_weight_fp32[1][0] = 8
44

45
        per_tensor_small = (
46
            torch.quantize_per_tensor(
47
                weight_fp32,
48
                scales[0],
49
                zero_points[0],
50
                dtype
51
            ),
52
            True,
53
            [0, 1, 3, 3],
54
            [2, 0, 1],
55
            [x + (1 if allow_non_zero_zero_points else 0) for x in [
56
                1, 1, 1, 1, 3, 3, 3, 3, 6, 6, 6, 6
57
            ]],
58
        )
59

60
        per_channel_small = (
61
            torch.quantize_per_channel(
62
                weight_fp32,
63
                torch.Tensor(scales),
64
                torch.Tensor(zero_points).to(torch.int),
65
                0,  # axis = 0
66
                dtype,
67
            ),
68
            False,
69
            [0, 1, 3, 3],
70
            [2, 0, 1],
71
            [x + ([1, 2, 2][i // 4] if allow_non_zero_zero_points else 0) for (i, x) in enumerate([
72
                1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2
73
            ])],
74
        )
75

76
        per_tensor_large = (
77
            torch.quantize_per_tensor(
78
                wide_weight_fp32,
79
                scales[0],
80
                zero_points[0],
81
                dtype,
82
            ),
83
            True,
84
            [0, 2, 3, 3],
85
            [0, 1001, 0],
86
            [x + (1 if allow_non_zero_zero_points else 0) for x in [
87
                2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0
88
            ]],
89
        )
90

91
        for (weight, is_per_tensor_quantized, expected_row_block_indices, expected_col_block_indices, expected_weights) in [
92
            per_tensor_small, per_channel_small, per_tensor_large
93
        ]:
94
            lin = Linear(
95
                out_features=weight.shape[0],
96
                in_features=weight.shape[1],
97
                row_block_size=row_block_size,
98
                col_block_size=col_block_size,
99
                bias=True,
100
                dtype=dtype,
101
            )
102

103
            bias = torch.ones(size=(weight.shape[0],))
104

105
            lin.set_weight_bias(weight, bias, row_block_size, col_block_size)
106

107
            serialized = lin._packed_params._packed_params.__getstate__()
108

109
            (
110
                _,  # version
111
                bias_,
112
                out_features_block_size_,
113
                in_features_block_size_,
114
                weight_scales_,
115
                weight_zero_points_,
116
                quantization_scheme_,
117
                row_block_indices_,
118
                col_block_indices_,
119
                weights_,
120
                output_channels_,
121
                input_channels_
122
            ) = serialized[0]
123

124
            # Test Serialization
125
            self.assertEqual(bias_, bias)
126
            self.assertEqual(out_features_block_size_, row_block_size)
127
            self.assertEqual(in_features_block_size_, col_block_size)
128
            self.assertEqual(weight_scales_, [scales[0]] if is_per_tensor_quantized else scales)
129
            self.assertEqual(weight_zero_points_, [zero_points[0]] if is_per_tensor_quantized else zero_points)
130
            self.assertEqual(quantization_scheme_, is_per_tensor_quantized)
131
            self.assertEqual(row_block_indices_, expected_row_block_indices)
132
            self.assertEqual(col_block_indices_, expected_col_block_indices)
133
            self.assertEqual(weights_.tolist(), [v + 128 for v in expected_weights])  # weights are serialized as +128
134
            self.assertEqual(output_channels_, weight.shape[0])
135
            self.assertEqual(input_channels_, weight.shape[1])
136

137
            # Test Unpacking
138
            (weights_, bias_, out_features_block_size_, in_features_block_size_) = lin._weight_bias()
139
            self.assertEqual(torch.dequantize(weights_), torch.dequantize(weight))
140
            self.assertEqual(bias_, bias)
141
            self.assertEqual(out_features_block_size_, row_block_size)
142
            self.assertEqual(in_features_block_size_, col_block_size)
143

144
            # Test Deserialization
145
            with tempfile.TemporaryFile() as file_buff:
146
                torch.save(lin, file_buff)
147
                file_buff.seek(0)
148
                lin2 = torch.load(file_buff)
149
                self.assertEqual(lin._weight_bias(), lin2._weight_bias())
150
                # Serialize -> Deserialize -> Serialize should match Serialize
151
                self.assertEqual(serialized, lin2._packed_params._packed_params.__getstate__())
152

153
                # Test that op output is preserved by serialize -> deserialize
154
                if qengine_is_qnnpack():
155
                    x = torch.rand(size=(1, weight.shape[1]))
156
                    y1 = lin(x)
157
                    y2 = lin2(x)
158
                    self.assertEqual(y1, y2)
159

160

161
    @skipIfNoFBGEMM
162
    def test_qlinear_packed_params_fbgemm(self):
163
        torch.manual_seed(0)
164
        with override_quantized_engine('fbgemm'):
165
            self.qlinear_packed_params_test(allow_non_zero_zero_points=False)
166

167

168
    @skipIfNoQNNPACK
169
    def test_qlinear_packed_params_qnnpack(self):
170
        torch.manual_seed(0)
171
        with override_quantized_engine('qnnpack'):
172
            with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
173
                self.qlinear_packed_params_test(allow_non_zero_zero_points=True)
174

175
    def test_qlinear_packed_params_fbgemm_qnnpack_cross_compatibility(self):
176
        torch.manual_seed(0)
177

178
        weight_fp32 = torch.Tensor([
179
            [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
180
            [6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
181
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
182
        ])
183

184
        row_block_size = 1
185
        col_block_size = 4
186
        out_features = weight_fp32.shape[0]
187
        in_features = weight_fp32.shape[1]
188

189
        scales = [2.0, 3.0, 7.0]
190
        zero_points = [0 for _ in range(out_features)]
191
        dtype = torch.qint8
192

193
        x = torch.rand(size=(1, weight_fp32.shape[1]))
194

195
        def make_lin_get_state_weight_bias_and_save():
196
            weight = torch.quantize_per_tensor(
197
                weight_fp32,
198
                scales[0],
199
                zero_points[0],
200
                dtype,
201
            )
202
            lin = Linear(
203
                out_features=weight.shape[0],
204
                in_features=weight.shape[1],
205
                row_block_size=row_block_size,
206
                col_block_size=col_block_size,
207
                bias=True,
208
                dtype=dtype,
209
            )
210
            bias = torch.ones(size=(weight.shape[0],))
211
            lin.set_weight_bias(weight, bias, row_block_size, col_block_size)
212

213
            state = lin._packed_params._packed_params.__getstate__()
214
            weight_bias = lin._weight_bias()
215

216
            file_buff = tempfile.TemporaryFile()
217
            torch.save(lin, file_buff)
218
            file_buff.seek(0)
219

220
            return ((state, weight_bias), file_buff)
221

222
        def load_get_state_weight_bias(f_b):
223
            lin2 = torch.load(f_b)
224
            state = lin2._packed_params._packed_params.__getstate__()
225
            weight_bias = lin2._weight_bias()
226
            f_b.close()
227
            return (state, weight_bias)
228

229
        def packed_params_data_with_int32_indices(data_as_state_and_weight_bias):
230
            (st, weight_bias) = data_as_state_and_weight_bias
231
            (s0, s1) = st
232
            s0_updated = tuple([
233
                # 7 and 8 are row and col block indices respectively
234
                v if (i != 7 and i != 8) else v.to(torch.int32) for (i, v) in enumerate(list(s0))
235
            ])
236
            return ((s0_updated, s1), weight_bias)
237

238
        # Test Fbgemm -> Qnnpack
239
        with override_quantized_engine('fbgemm'):
240
            packed_params_data_1a, file_buff_1 = make_lin_get_state_weight_bias_and_save()
241

242
        with override_quantized_engine('qnnpack'):
243
            with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
244
                packed_params_data_1b = load_get_state_weight_bias(file_buff_1)
245

246
        self.assertEqual(
247
            packed_params_data_with_int32_indices(packed_params_data_1a),
248
            packed_params_data_with_int32_indices(packed_params_data_1b),
249
        )
250

251
        # Test Qnnpack -> Fbgemm
252
        with override_quantized_engine('qnnpack'):
253
            with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
254
                packed_params_data_2a, file_buff_2 = make_lin_get_state_weight_bias_and_save()
255

256
        with override_quantized_engine('fbgemm'):
257
            packed_params_data_2b = load_get_state_weight_bias(file_buff_2)
258

259
        self.assertEqual(
260
            packed_params_data_with_int32_indices(packed_params_data_2a),
261
            packed_params_data_with_int32_indices(packed_params_data_2b),
262
        )
263

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.