# Owner(s): ["module: unknown"]


import logging

import torch
import torch.ao.quantization as tq
from torch import nn
from torch.ao import pruning
from torch.testing._internal.common_utils import TestCase
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, convert_to_reference_fx, prepare_qat_fx
from torch.ao.pruning import fqn_to_module

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

sparse_defaults = {
    "sparsity_level": 0.8,
    "sparse_block_shape": (1, 4),
    "zeros_per_block": 4,
}
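# Roughly: WeightNormSparsifier with these defaults zeroes out ~80% of the (1, 4) blocks
# of each configured weight tensor, and zeros_per_block == 4 means every entry inside a
# selected block is zeroed.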

def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
    model = nn.Sequential(
        nn.Linear(4, 4),  # 0
        nn.ReLU(),
        nn.Linear(4, 4),  # 2
        nn.ReLU(),
        tq.QuantStub(),
        nn.Linear(4, 4),  # 5
        nn.ReLU(),
        tq.DeQuantStub(),
    )
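    # Sequential indices: 0 Linear, 1 ReLU, 2 Linear, 3 ReLU, 4 QuantStub,
    # 5 Linear, 6 ReLU, 7 DeQuantStub; only modules 4 and 5 are given a qconfig
    # below, so only the last Linear is set up for quantization.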
    if qconfig:
        model[4].qconfig = qconfig
        model[5].qconfig = qconfig

    sparsifier = pruning.WeightNormSparsifier(**sparse_defaults)

    sparse_config = [
        {
            "tensor_fqn": "5.weight",
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        },
        {"tensor_fqn": "0.weight"},
    ]
    return model, sparsifier, sparse_config

def _squash_mask_calibrate_and_convert(model, sparsifier, input):
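    # Rough outline: step() updates the sparsity masks, squash_mask() folds the masks into
    # the weights and removes the parametrizations, the forward pass calibrates the
    # observers, and tq.convert swaps in the quantized modules.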
    sparsifier.step()
    sparsifier.squash_mask()
    model(input)
    tq.convert(model, inplace=True)

def _calculate_sparsity(tensor):
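    # e.g. a (4, 4) weight with 12 zero entries has sparsity 12 / 16 == 0.75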
    return ((tensor == 0).sum() / tensor.numel()).item()

# This series of tests checks the composability goals for sparsity and quantization, namely
# that performing quantization and sparsity model manipulations in various orderings
# does not cause problems.
class TestComposability(TestCase):
    # This test checks whether performing quantization prepare before sparse prepare
    # causes any issues and verifies that the correct observers are inserted and that
    # the quantized model works as expected
    def test_q_prep_before_s_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))

        tq.prepare(mod, inplace=True)
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))
        # check that correct observers were inserted
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # This test checks whether performing sparsity prepare before quantization prepare
    # causes any issues. In particular, the previous quantization flow was unable to match
    # the post-sparse-prepare module names (adding parametrizations changes the module class names),
    # which would result in those parametrized modules not being quantized. This test verifies that
    # the fix for this was successful.
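    # Roughly what happens under the hood (assuming torch.nn.utils.parametrize semantics):
    # after sparsifier.prepare, the original tensor is stored at
    # mod[5].parametrizations.weight.original and mod[5].weight is recomputed on access by
    # applying the sparsity mask, while the module's class is swapped for a dynamically
    # generated parametrized subclass, which is what used to break module matching.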
    def test_s_prep_before_q_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # If the sparsified modules have not undergone the final squash mask operation, it's possible
    # that the problem outlined in test_s_prep_before_q_prep would occur. This test verifies
    # both that the fix to the convert flow avoids this issue and that the resulting quantized
    # module uses the sparse version of the weight value.
    def test_convert_without_squash_mask(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(mod[5].weight)
        mod(torch.randn(1, 4, 4, 4))
        tq.convert(mod, inplace=True)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing sparse prepare before fusion causes any issues. The
    # worry was that the link created between the sparsifier and the modules that need to
    # be sparsified would be broken.
    def test_s_prep_before_fusion(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
        sparsifier.prepare(mod, config=sparse_config)
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare or fusion
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # This tests whether performing fusion before sparse prepare causes any issues. The
    # main worry was that the links to the modules in the sparse config would be broken by fusion.
    def test_fusion_before_s_prep(self):
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)

        # the helper's default "5.weight" fqn is broken by fusion (the Linear now lives at
        # mod[5][0]), but sparse prepare still works if the post-fusion fqn is used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.weight"},
        ]

        sparsifier.prepare(mod, config=sparse_config)
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(mod[5][0].weight)
        mod(torch.randn(1, 4, 4, 4))
        tq.convert(mod, inplace=True)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing sparse prepare before qat prepare causes issues.
    # The primary worries were that qat_prep wouldn't recognize the parametrized
    # modules and that the convert step for qat would remove the parametrizations
    # from the modules.
    def test_s_prep_before_qat_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare_qat(mod, inplace=True)
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))
        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )
        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing qat prepare before sparse prepare causes issues.
    def test_qat_prep_before_s_prep(self):
        mod, sparsifier, _ = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        tq.prepare_qat(mod, inplace=True)

        # need to set up sparse_config on the new modules
        sparse_config = [
            {
                "tensor_fqn": "5.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during qat prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))

        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

def _module_has_activation_post_process(model, fqn_of_module):
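    # FX prepare inserts observers as call_module nodes whose names contain
    # "activation_post_process"; walk the traced graph and report whether such an
    # observer consumes the output of the module at fqn_of_module.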
    for node in model.graph.nodes:
        # look for an observer whose arg is the target module
        if "activation_post_process" in node.name:
            if node.args[0].target == fqn_of_module:
                return True
    return False

class TestFxComposability(TestCase):
    r"""This series of tests checks that various steps of the quantization and sparsity flow
    compose cleanly despite variation in sequencing.
    """
    def test_q_prep_fx_before_s_prep(self):
        r"""
        This test checks that the ordering of prepare_fx -> sparse prepare -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask between sparse prepare and convert_fx. This also tests the
        automatic fusion that occurs during prepare_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
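        # only modules "4" (the QuantStub) and "5" (the last Linear) are given a qconfig,
        # mirroring the qconfig assignment in the eager-mode helper above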
        qconfig_mapping = tq.QConfigMapping() \
            .set_module_name("4", qconfig) \
            .set_module_name("5", qconfig)

        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # the original fqns are broken by the automatic fusion in prepare_fx
        # (Linear + ReLU pairs are fused, so the Linears now live at fqns like "0.0" and "5.0"),
        # but sparse prepare still works if the post-fusion fqns are used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_q_prep_fx_s_prep_ref_conv(self):
        r"""
        This checks that the ordering: prepare_fx -> sparse prepare -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = tq.QConfigMapping() \
            .set_module_name("4", qconfig) \
            .set_module_name("5", qconfig)

        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # the original fqns are broken by the automatic fusion in prepare_fx
        # but sparse prepare still works if the post-fusion fqns are used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_to_reference_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU))
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(isinstance(fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_q_prep_fx(self):
        r"""
        This test checks that the ordering of sparse prepare -> prepare_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = tq.QConfigMapping() \
            .set_module_name("4", qconfig) \
            .set_module_name("5", qconfig)
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_qat_prep_fx(self):
        r"""
        This test checks that the ordering of sparse prepare -> prepare_qat_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qat_qconfig("fbgemm")
        qconfig_mapping = tq.QConfigMapping() \
            .set_module_name("4", qconfig) \
            .set_module_name("5", qconfig)
        mod = prepare_qat_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_q_prep_fx_ref(self):
        r"""
        This checks that the ordering: sparse prepare -> prepare_fx -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = tq.QConfigMapping() \
            .set_module_name("4", qconfig) \
            .set_module_name("5", qconfig)
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_to_reference_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU))
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(isinstance(fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])