# Owner(s): ["module: unknown"]

import logging

import torch
import torch.ao.quantization as tq
import torch.nn as nn
from torch.ao import pruning
from torch.ao.pruning import fqn_to_module
from torch.ao.quantization.quantize_fx import (
    convert_fx,
    convert_to_reference_fx,
    prepare_fx,
    prepare_qat_fx,
)
from torch.testing._internal.common_utils import TestCase

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

sparse_defaults = {
    "sparsity_level": 0.8,
    "sparse_block_shape": (1, 4),
    "zeros_per_block": 4,
}
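# Note: these are the sparsifier-wide defaults; per-tensor entries in a test's
# sparse_config can override them, and entries that only provide a "tensor_fqn"
# fall back to these values.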

def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
    # N.B.: the tests below rely on the Linear at index 0 (fused with the ReLU at 1
    # under fx), the QuantStub at 4, the Linear/ReLU pair at 5/6, and the DeQuantStub
    # at 7; the layers at indices 2/3 are filler and their exact layout is an assumption.
    model = nn.Sequential(
        nn.Linear(4, 4),  # 0
        nn.ReLU(),
        nn.Linear(4, 4),
        nn.ReLU(),
        tq.QuantStub(),  # 4
        nn.Linear(4, 4),  # 5
        nn.ReLU(),  # 6
        tq.DeQuantStub(),  # 7
    )
    if qconfig:
        model[4].qconfig = qconfig
        model[5].qconfig = qconfig

    sparsifier = pruning.WeightNormSparsifier(**sparse_defaults)

    sparse_config = [
        {
            "tensor_fqn": "5.weight",
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        },
        {"tensor_fqn": "0.weight"},
    ]
    return model, sparsifier, sparse_config
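

# The helpers below cover the common post-prepare steps: sparsifier.step() computes the
# sparsity masks, squash_mask() folds the masks into the weights and removes the
# parametrizations, a single forward pass calibrates the observers, and tq.convert swaps
# in quantized modules. _calculate_sparsity reports the fraction of exactly-zero elements
# and is used both on float weights and on weights recovered via _weight_bias().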
def _squash_mask_calibrate_and_convert(model, sparsifier, input):
    sparsifier.step()
    sparsifier.squash_mask()
    model(input)
    tq.convert(model, inplace=True)


def _calculate_sparsity(tensor):
    return ((tensor == 0).sum() / tensor.numel()).item()


# This series of tests checks the composability goals for sparsity and quantization,
# namely that performing quantization and sparsity model manipulations in various
# orderings does not cause problems.
class TestComposability(TestCase):
    # This test checks whether performing quantization prepare before sparse prepare
    # causes any issues and verifies that the correct observers are inserted and that
    # the quantized model works as expected
    def test_q_prep_before_s_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))

        tq.prepare(mod, inplace=True)
        sparsifier.prepare(mod, config=sparse_config)

        # check that the correct modules had parametrizations added
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that the correct observers were inserted
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )

        # check that the final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # This test checks whether performing sparse prepare before quantization prepare
    # causes any issues. In particular, the previous quantization flow was unable to match
    # the post-sparse-prepare module names (adding parametrizations changes the module class
    # name, e.g. Linear becomes ParametrizedLinear), which would result in those parametrized
    # modules not being quantized. This test verifies that the fix for this was successful.
    def test_s_prep_before_q_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # check that the correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )

        # check that the final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # If the sparsified modules have not undergone the final squash_mask operation, it's
    # possible that the problem outlined in test_s_prep_before_q_prep would occur. This test
    # verifies both that the fix to the convert flow avoids this issue and that the resulting
    # quantized module uses the sparse version of the weight value.
    def test_convert_without_squash_mask(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # check that the correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        sparsifier.step()
        sparsity_level = _calculate_sparsity(mod[5].weight)
        mod(torch.randn(1, 4, 4, 4))
        tq.convert(mod, inplace=True)

        # check that the final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that the module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing sparse prepare before fusion causes any issues. The
    # worry was that the link created between the sparsifier and the modules that need to
    # be sparsified would be broken.
    def test_s_prep_before_fusion(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))

        sparsifier.prepare(mod, config=sparse_config)
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # check that the correct modules had parametrizations added and
        # that none were lost during prepare or fusion
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )

        # check that the final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # This tests whether performing fusion before sparse prepare causes any issues. The
    # main worry was that the links to the modules in the sparse config would be broken by fusion.
    def test_fusion_before_s_prep(self):
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)

        # the config from the helper is broken by fusion, but sparsification still
        # works if the correct post-fusion fqn is used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.weight"},
        ]

        sparsifier.prepare(mod, config=sparse_config)
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # check that the correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        sparsifier.step()
        sparsity_level = _calculate_sparsity(mod[5][0].weight)
        mod(torch.randn(1, 4, 4, 4))
        tq.convert(mod, inplace=True)

        # check that the final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that the module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing sparse prepare before qat prepare causes issues.
    # The primary worries were that qat_prep wouldn't recognize the parametrized
    # modules and that the convert step for qat would remove the parametrizations
    # from the modules.
    def test_s_prep_before_qat_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare_qat(mod, inplace=True)
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))
        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )

        # check that the final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that the module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing qat prepare before sparse prepare causes issues.
    def test_qat_prep_before_s_prep(self):
        mod, sparsifier, _ = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        tq.prepare_qat(mod, inplace=True)

        # need to set up sparse_config on the new (qat) modules
        sparse_config = [
            {
                "tensor_fqn": "5.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that the correct modules had parametrizations added and
        # that none were lost during qat prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))

        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )

        # check that the final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that the module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
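

# Helper for the fx tests below: observers can't be found via hasattr on a
# GraphModule, so instead scan the traced graph for an activation_post_process
# node whose input comes from the module with the given fully qualified name.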
def _module_has_activation_post_process(model, fqn_of_module):
    for node in model.graph.nodes:
        # look for an observer whose arg is the target module
        if "activation_post_process" in node.name:
            if node.args[0].target == fqn_of_module:
                return True
    return False
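

# Note on the fx flow: prepare_fx auto-fuses Linear+ReLU pairs, so the original
# module paths pick up an extra level of nesting (e.g. "5.weight" becomes
# "5.0.weight"). The sparse configs in the tests below account for this.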
class TestFxComposability(TestCase):
    r"""This series of tests checks that various steps of the quantization and sparsity flow
    compose cleanly despite variation in sequencing.
    """

    def test_q_prep_fx_before_s_prep(self):
        r"""
        This test checks that the ordering of prepare_fx -> sparse prepare -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask between sparse prepare and convert_fx. This also tests the
        automatic fusion that occurs during prepare_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )

        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # the config from the helper is broken by the auto fusion in fx,
        # but sparsification still works if the correct post-fusion fqn is used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that the correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))

        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod = convert_fx(mod)

        # check that the final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that the module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_q_prep_fx_s_prep_ref_conv(self):
        r"""
        This checks that the ordering: prepare_fx -> sparse prepare -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )

        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # the config from the helper is broken by the auto fusion in fx,
        # but sparsification still works if the correct post-fusion fqn is used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that the correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))

        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod = convert_to_reference_fx(mod)

        # check that the final module is the expected reference quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear)
        )

        # check that the module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_q_prep_fx(self):
        r"""
        This test checks that the ordering of sparse prepare -> prepare_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that the correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))

        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod = convert_fx(mod)

        # check that the final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that the module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_qat_prep_fx(self):
        r"""
        This test checks that the ordering of sparse prepare -> prepare_qat_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qat_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_qat_fx(mod, qconfig_mapping, (example,))

        # check that the correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU)
        )

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))

        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.weight"))
        mod = convert_fx(mod)

        # check that the final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that the module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_q_prep_fx_ref(self):
        r"""
        This checks that the ordering: sparse prepare -> prepare_fx -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that the correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that the correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))

        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod = convert_to_reference_fx(mod)

        # check that the final module is the expected reference quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear)
        )

        # check that the module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])