6
from torch.ao.nn.sparse.quantized.dynamic.linear import Linear
7
from torch.testing._internal.common_quantization import (
11
from torch.testing._internal.common_quantized import (
13
override_quantized_engine,
14
override_cpu_allocator_for_qnnpack
16
from torch.testing._internal.common_utils import TestCase
18
class TestQlinearPackedParams(TestCase):
19
def qlinear_packed_params_test(self, allow_non_zero_zero_points=False):
# Shared check driven by the engine-specific tests in this class. For each
# weight case it calls set_weight_bias on a module `lin` (presumably the
# imported sparse dynamic `Linear`; its construction is only partly visible),
# then verifies the packed state, the unpacked weight/bias, and a
# torch.save / torch.load round trip.
#
# allow_non_zero_zero_points: when true, per-row zero points and the expected
# packed values below are shifted away from zero; when false the shift is 0.
#
# NOTE(review): the bare integer lines in this block look like line-number
# residue from extraction, and the gaps in their sequence suggest that source
# lines are missing here -- confirm against the original file.
23
# Small dense fp32 weight; every visible row has 16 columns.
weight_fp32 = torch.Tensor([
24
[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
25
[6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
26
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
31
out_features = weight_fp32.shape[0]
32
in_features = weight_fp32.shape[1]
34
scales = [2.0, 6.0, 12.0]
36
# One value per output row: i + 1 when non-zero zero points are allowed,
# otherwise 0. Presumably the elements of `zero_points` (the assignment
# target is not visible here) -- confirm.
((i + 1) if allow_non_zero_zero_points else 0) for i in range(out_features)
40
# Wide, almost entirely zero weight with three non-zero entries.
wide_weight_fp32 = torch.zeros((3, 4008))
41
wide_weight_fp32[0][0] = 4
42
wide_weight_fp32[0][4004] = 6
43
wide_weight_fp32[1][0] = 8
46
# Case data (presumably `per_tensor_small`, which the loop below consumes):
# a per-tensor quantized weight followed by the expected packed values, each
# increased by 1 when non-zero zero points are allowed.
torch.quantize_per_tensor(
55
[x + (1 if allow_non_zero_zero_points else 0) for x in [
56
1, 1, 1, 1, 3, 3, 3, 3, 6, 6, 6, 6
61
# Case data (presumably `per_channel_small`): per-channel quantization with
# the zero points converted to an int tensor.
torch.quantize_per_channel(
64
torch.Tensor(zero_points).to(torch.int),
71
# Expected values get an offset chosen per group of four entries from
# [1, 2, 2]; presumably the zero point of the row owning each 4-value block
# (row 0 once, row 1 twice) -- confirm.
[x + ([1, 2, 2][i // 4] if allow_non_zero_zero_points else 0) for (i, x) in enumerate([
72
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2
77
# Case data (presumably `per_tensor_large`, built from the wide weight).
torch.quantize_per_tensor(
86
[x + (1 if allow_non_zero_zero_points else 0) for x in [
87
2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0
91
# Run the same checks for each of the three cases.
for (weight, is_per_tensor_quantized, expected_row_block_indices, expected_col_block_indices, expected_weights) in [
92
per_tensor_small, per_channel_small, per_tensor_large
95
# Keyword arguments sized from the current weight; presumably passed to the
# `Linear` constructor that produces `lin` (the call line is not visible).
out_features=weight.shape[0],
96
in_features=weight.shape[1],
97
row_block_size=row_block_size,
98
col_block_size=col_block_size,
103
# All-ones bias with one entry per output row.
bias = torch.ones(size=(weight.shape[0],))
105
lin.set_weight_bias(weight, bias, row_block_size, col_block_size)
107
# Raw state of the packed params object; compared again after reloading.
serialized = lin._packed_params._packed_params.__getstate__()
112
# Names unpacked from the serialized state (only some are visible here).
out_features_block_size_,
113
in_features_block_size_,
116
quantization_scheme_,
125
# Serialized fields must match what was passed to set_weight_bias.
self.assertEqual(bias_, bias)
126
self.assertEqual(out_features_block_size_, row_block_size)
127
self.assertEqual(in_features_block_size_, col_block_size)
128
# Per-tensor cases expect a single scale / zero point (the first entry);
# per-channel cases expect the full lists.
self.assertEqual(weight_scales_, [scales[0]] if is_per_tensor_quantized else scales)
129
self.assertEqual(weight_zero_points_, [zero_points[0]] if is_per_tensor_quantized else zero_points)
130
self.assertEqual(quantization_scheme_, is_per_tensor_quantized)
131
self.assertEqual(row_block_indices_, expected_row_block_indices)
132
self.assertEqual(col_block_indices_, expected_col_block_indices)
133
# Serialized weights are compared against the expected values shifted by
# +128 (presumably because the packed buffer stores them as unsigned bytes
# -- confirm).
self.assertEqual(weights_.tolist(), [v + 128 for v in expected_weights])
134
self.assertEqual(output_channels_, weight.shape[0])
135
self.assertEqual(input_channels_, weight.shape[1])
138
# Unpacking must return the original weight (compared after dequantizing),
# the bias, and both block sizes.
(weights_, bias_, out_features_block_size_, in_features_block_size_) = lin._weight_bias()
139
self.assertEqual(torch.dequantize(weights_), torch.dequantize(weight))
140
self.assertEqual(bias_, bias)
141
self.assertEqual(out_features_block_size_, row_block_size)
142
self.assertEqual(in_features_block_size_, col_block_size)
145
# Round trip through torch.save / torch.load using a temporary file.
# NOTE(review): the file has to be rewound between save and load; the line
# between them is not visible here -- confirm it seeks back to the start.
with tempfile.TemporaryFile() as file_buff:
146
torch.save(lin, file_buff)
148
lin2 = torch.load(file_buff)
149
self.assertEqual(lin._weight_bias(), lin2._weight_bias())
151
# Serializing the reloaded module must reproduce the original state.
self.assertEqual(serialized, lin2._packed_params._packed_params.__getstate__())
154
# Only when the active engine is qnnpack: compare outputs on a random input.
# The definitions of y1 / y2 are not visible; presumably lin(x) and lin2(x).
if qengine_is_qnnpack():
155
x = torch.rand(size=(1, weight.shape[1]))
158
self.assertEqual(y1, y2)
162
def test_qlinear_packed_params_fbgemm(self):
# Runs the shared packed-params check with the quantized engine overridden
# to 'fbgemm' and with non-zero zero points disabled.
# NOTE(review): the bare integer lines look like line-number residue from
# extraction; a source line between the def and the `with` appears to be
# missing -- confirm against the original file.
164
with override_quantized_engine('fbgemm'):
165
self.qlinear_packed_params_test(allow_non_zero_zero_points=False)
169
def test_qlinear_packed_params_qnnpack(self):
# Runs the shared packed-params check with the quantized engine overridden
# to 'qnnpack', inside the qnnpack CPU-allocator override, and with non-zero
# zero points enabled.
# NOTE(review): the bare integer lines look like line-number residue from
# extraction; a source line between the def and the `with` appears to be
# missing -- confirm against the original file.
171
with override_quantized_engine('qnnpack'):
172
# The allocator override receives the result of qengine_is_qnnpack(),
# evaluated after the engine override above has been entered.
with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
173
self.qlinear_packed_params_test(allow_non_zero_zero_points=True)
175
def test_qlinear_packed_params_fbgemm_qnnpack_cross_compatibility(self):
# Saves a module under one quantized engine, loads it under the other, and
# compares packed state and unpacked weight/bias; done in both directions
# (fbgemm -> qnnpack, then qnnpack -> fbgemm).
# NOTE(review): the bare integer lines look like line-number residue from
# extraction, and the gaps in their sequence suggest that source lines are
# missing (including the calls that perform the comparisons) -- confirm
# against the original file.
178
weight_fp32 = torch.Tensor([
179
[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
180
[6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
181
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
186
out_features = weight_fp32.shape[0]
187
in_features = weight_fp32.shape[1]
189
scales = [2.0, 3.0, 7.0]
190
# All zero points are 0 in this test.
zero_points = [0 for _ in range(out_features)]
193
# Random input with one row; its use is not visible in this view.
x = torch.rand(size=(1, weight_fp32.shape[1]))
195
# Helper: under the currently active engine, quantize the weight per tensor,
# load it into `lin`, capture the packed state and unpacked weight/bias, and
# save the module to a temporary file. The still-open file object is
# returned so the caller can load from it later.
def make_lin_get_state_weight_bias_and_save():
196
weight = torch.quantize_per_tensor(
203
out_features=weight.shape[0],
204
in_features=weight.shape[1],
205
row_block_size=row_block_size,
206
col_block_size=col_block_size,
210
bias = torch.ones(size=(weight.shape[0],))
211
lin.set_weight_bias(weight, bias, row_block_size, col_block_size)
213
state = lin._packed_params._packed_params.__getstate__()
214
weight_bias = lin._weight_bias()
216
file_buff = tempfile.TemporaryFile()
217
torch.save(lin, file_buff)
220
return ((state, weight_bias), file_buff)
222
# Helper: load a module from the given file object and return its packed
# state and unpacked weight/bias, in the same shape as the first element
# returned by the helper above.
def load_get_state_weight_bias(f_b):
223
lin2 = torch.load(f_b)
224
state = lin2._packed_params._packed_params.__getstate__()
225
weight_bias = lin2._weight_bias()
227
return (state, weight_bias)
229
# Helper: return the same data with entries 7 and 8 of the first state
# element converted to int32, so both sides compare with the same dtype.
# Presumably those entries are the row/col block index tensors -- confirm.
def packed_params_data_with_int32_indices(data_as_state_and_weight_bias):
230
(st, weight_bias) = data_as_state_and_weight_bias
234
v if (i != 7 and i != 8) else v.to(torch.int32) for (i, v) in enumerate(list(s0))
236
return ((s0_updated, s1), weight_bias)
239
# Direction 1: create and save under fbgemm, load under qnnpack.
with override_quantized_engine('fbgemm'):
240
packed_params_data_1a, file_buff_1 = make_lin_get_state_weight_bias_and_save()
242
with override_quantized_engine('qnnpack'):
243
with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
244
packed_params_data_1b = load_get_state_weight_bias(file_buff_1)
247
# Normalized saved and loaded data; presumably the two arguments of an
# equality assertion whose call line is not visible.
packed_params_data_with_int32_indices(packed_params_data_1a),
248
packed_params_data_with_int32_indices(packed_params_data_1b),
252
# Direction 2: create and save under qnnpack, load under fbgemm.
with override_quantized_engine('qnnpack'):
253
with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
254
packed_params_data_2a, file_buff_2 = make_lin_get_state_weight_bias_and_save()
256
with override_quantized_engine('fbgemm'):
257
packed_params_data_2b = load_get_state_weight_bias(file_buff_2)
260
# Normalized saved and loaded data; presumably the two arguments of an
# equality assertion whose call line is not visible.
packed_params_data_with_int32_indices(packed_params_data_2a),
261
packed_params_data_with_int32_indices(packed_params_data_2b),