# Owner(s): ["module: unknown"]
import copy
import logging
import random

import torch
from torch import nn
from torch.ao.pruning._experimental.pruner import (
    SaliencyPruner,
    LSTMSaliencyPruner,
    BaseStructuredSparsifier,
    FakeStructuredSparsity,
    FPGMPruner,
)
from torch.nn.utils import parametrize

from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
from torch.testing._internal.common_pruning import (
    SimpleLinear,
    LinearBias,
    LinearActivation,
    LinearActivationFunctional,
    SimpleConv2d,
    Conv2dBias,
    Conv2dActivation,
    Conv2dPadBias,
    Conv2dPool,
    Conv2dPoolFlatten,
    Conv2dPoolFlattenFunctional,
    LSTMLinearModel,
    LSTMLayerNormLinearModel,
    rows_are_subset,
)


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

DEVICES = {
    torch.device("cpu"),
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
}
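# Note: DEVICES is a set, so on a machine without CUDA the two entries collapse
# to a single torch.device("cpu") and each test body runs only once.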


class SimplePruner(BaseStructuredSparsifier):
    def update_mask(self, module, tensor_name, **kwargs):
        getattr(module.parametrizations, tensor_name)[0].mask[1] = False


class ImplementedPruner(BaseStructuredSparsifier):
    def update_mask(self, module, tensor_name, **kwargs):
        """Prunes 1/3 of the weight output channels, so the resulting module has 33.3% of its channels pruned."""
        num_rows = len(module.parametrizations[tensor_name][0].mask)
        prune = random.sample(list(range(num_rows)), num_rows // 3)
        module.parametrizations[tensor_name][0].mask[prune] = False


class BottomHalfLSTMPruner(BaseStructuredSparsifier):
    """
    Pruner that will remove the bottom half of the rows.
    This is primarily meant for testing purposes.
    """

    def update_mask(self, module, tensor_name, **kwargs):
        for p in getattr(module.parametrizations, tensor_name):
            if isinstance(p, FakeStructuredSparsity):
                mask = p.mask
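                # nn.LSTM stacks its four gates (input, forget, cell, output)
                # row-wise in each weight_ih_l*/weight_hh_l*, giving
                # 4 * hidden_size rows; splitting the mask into 4 chunks lets
                # us prune the bottom half of each gate's rows independently.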
                masks = torch.split(mask, len(mask) // 4)
                for small in masks:
                    num = len(small)
                    small[num // 2 :] = False
                new_mask = torch.cat(masks)
                mask.data = new_mask.data


class TestSaliencyPruner(TestCase):
    def test_saliency_pruner_update_mask(self):
        """Test that we prune out the rows with the lowest saliency (the first two rows)"""
        model = SimpleLinear()
        with torch.no_grad():
            model.linear1.weight = nn.Parameter(
                torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])
            )
        pruning_config = [{"tensor_fqn": "linear1.weight", "sparsity_level": 0.5}]
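        # SaliencyPruner scores each output row by its L1 norm; with
        # sparsity_level=0.5, the two smallest-norm rows ([1, 1, 1, 1] and
        # [2, 2, 2, 2]) should be masked out.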
        pruner = SaliencyPruner({})

        pruner.prepare(model, pruning_config)
        pruner.enable_mask_update = True
        pruner.step()
        pruned_model = pruner.prune()

        expected = torch.Tensor([[3, 3, 3, 3], [4, 4, 4, 4]])
        pruned = pruned_model.linear1.weight

        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

    def test_lstm_saliency_pruner_update_mask(self):
        model = LSTMLinearModel(
            input_dim=2,
            hidden_dim=2,
            output_dim=2,
            num_layers=1,
        )

        manual_weights = torch.Tensor([[1, 1],
                                       [2, 2],
                                       [2, 2],
                                       [1, 1],
                                       [-1, -1],
                                       [-2, -2],
                                       [-2, -2],
                                       [-1, -1]])

        with torch.no_grad():
            model.lstm.weight_ih_l0 = nn.Parameter(manual_weights)
            model.lstm.weight_hh_l0 = nn.Parameter(torch.Tensor(manual_weights))
            model.lstm.bias_ih_l0 = nn.Parameter(manual_weights[:, 0])
            model.lstm.bias_hh_l0 = nn.Parameter(manual_weights[:, 0])

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]
        lstm_input = torch.ones((1, 2))
        fx_pruner = LSTMSaliencyPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
        pruned_model = fx_pruner.prune()
        pruned_model.eval()

        # make sure both models run
        model(lstm_input)
        pruned_model(lstm_input)

        # make sure lowest saliency rows are pruned
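        # LSTMSaliencyPruner computes saliency per gate chunk (2 rows each for
        # hidden_dim=2) and prunes the lowest-norm row within each chunk, so
        # the [2, 2] and [-2, -2] rows survive.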
        expected = torch.Tensor([[2, 2],
                                 [2, 2],
                                 [-2, -2],
                                 [-2, -2]])
        pruned = model.lstm.weight_ih_l0
        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
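
        # Pruning halves the hidden dimension (2 -> 1), so weight_hh_l0 shrinks
        # along both axes: from (4 * hidden_size, hidden_size) = (8, 2) to (4, 1).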
        expected = torch.Tensor([[2],
                                 [2],
                                 [-2],
                                 [-2]])
        pruned = model.lstm.weight_hh_l0
        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

        expected = torch.Tensor([2, 2, -2, -2])
        for pruned in [model.lstm.bias_ih_l0, model.lstm.bias_hh_l0]:
            assert expected.shape == pruned.shape
            assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()


class TestBaseStructuredSparsifier(TestCase):
    def _check_pruner_prepared(self, model, pruner, device):
        for config in pruner.groups:
            module = config["module"]
            assert module.weight.device.type == device.type
            # Check mask exists
            assert config["tensor_fqn"] in pruner.state
            # Check parametrization exists and is correct
            assert parametrize.is_parametrized(module)
            assert hasattr(module, "parametrizations")
            # Assume that this is the 1st/only parametrization
            assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity

    def _check_pruner_valid_before_step(self, model, pruner, device):
        for config in pruner.groups:
            modules = []
            if type(config["module"]) is tuple:
                modules.extend(config["module"])
            else:
                module = config["module"]
                modules.append(module)
            for module in modules:
                assert module.weight.device.type == device.type
                assert module.parametrizations.weight[0].mask.dtype == torch.bool

    def _check_pruner_valid_after_step(self, model, pruner, mask, device):
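        # Here `mask` is the expected number of rows zeroed out in each module's mask.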
        for config in pruner.groups:
            modules = []
            if type(config["module"]) is tuple:
                modules.extend(config["module"])
            else:
                module = config["module"]
                modules.append(module)
            for module in modules:
                assert module.weight.device.type == device.type
                total = module.parametrizations.weight[0].mask.numel()
                assert (
                    module.parametrizations.weight[0].mask.count_nonzero()
                    == total - mask
                )

    def _test_constructor_on_device(self, model, device):
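        # BaseStructuredSparsifier is abstract: instantiating it directly must
        # fail because update_mask is not implemented.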
        self.assertRaisesRegex(
            TypeError,
            "BaseStructuredSparsifier.*update_mask",
            BaseStructuredSparsifier,
        )
        model1 = copy.deepcopy(model).to(device)
        pruner = SimplePruner(None)
        pruner.prepare(model1, None)
        pruner.enable_mask_update = True
        for g in pruner.groups:
            module = g["module"]
            assert module.weight.device.type == device.type
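        # config=None makes prepare() target every supported weight tensor,
        # presumably one pruning group per prunable Linear in SimpleLinear.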
        assert len(pruner.groups) == 5
        pruner.step()
        # Can instantiate the model with configs
        model2 = copy.deepcopy(model).to(device)
        pruner = SimplePruner({"test": 3})
        pruner.prepare(model2, [{"tensor_fqn": "seq.0.weight"}])
        assert len(pruner.groups) == 1
        assert pruner.groups[0]["module_fqn"] == "seq.0"
        assert "test" in pruner.groups[0]
        assert pruner.groups[0]["test"] == 3

    def test_constructor(self):
        model = SimpleLinear()
        for device in DEVICES:
            self._test_constructor_on_device(model, torch.device(device))

    def _test_prepare_linear_on_device(self, model, device):
        model = copy.deepcopy(model).to(device)
        x = torch.ones(128, 7, device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, None)
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == (128, 10)

    def test_prepare_linear(self):
        models = [
            SimpleLinear(),
            LinearBias(),
            LinearActivation(),
            LinearActivationFunctional(),
        ]  # without and with bias
        for device in DEVICES:
            for model in models:
                self._test_prepare_linear_on_device(model, torch.device(device))

    def _test_prepare_conv2d_on_device(self, model, expected_shape, config, device):
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == expected_shape

    def test_prepare_conv2d(self):
        models = [
            SimpleConv2d(),
            Conv2dBias(),
            Conv2dActivation(),
            Conv2dPadBias(),
            Conv2dPool(),
        ]
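        # Expected output shapes for a (1, 1, 28, 28) input, matched by
        # position to the models above.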
        shapes = [
            (1, 52, 20, 20),
            (1, 52, 18, 18),
            (1, 52, 18, 18),
            (1, 52, 24, 24),
            (1, 52, 3, 3),
        ]
        configs = [None, None, None, None, None]
        for device in DEVICES:
            for model, shape, config in zip(models, shapes, configs):
                model = model.to(device)
                self._test_prepare_conv2d_on_device(
                    model, shape, config, torch.device(device)
                )

    def _test_step_linear_on_device(self, model, device):
        model = model.to(device)
        x = torch.ones(7, 7, device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, None)
        pruner.enable_mask_update = True
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
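        # SimplePruner masks exactly one row (index 1) per weight, hence the
        # expected pruned-row count of 1 below.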
        self._check_pruner_valid_after_step(model, pruner, 1, device)

    def test_step_linear(self):
        models = [
            SimpleLinear(),
            LinearBias(),
            LinearActivation(),
            LinearActivationFunctional(),
        ]
        for device in DEVICES:
            for model in models:
                self._test_step_linear_on_device(model, torch.device(device))

    def _test_step_conv2d_on_device(self, model, expected_shape, config, device):
        model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        self._check_pruner_valid_after_step(model, pruner, 1, device)
        assert model(x).shape == expected_shape

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_step_conv2d(self):
        models = [
            SimpleConv2d(),
            Conv2dBias(),
            Conv2dActivation(),
            Conv2dPadBias(),
            Conv2dPool(),
        ]
        shapes = [
            (1, 52, 20, 20),
            (1, 52, 18, 18),
            (1, 52, 18, 18),
            (1, 52, 24, 24),
            (1, 52, 3, 3),
        ]
        configs = [None, None, None, None, None]
        for device in DEVICES:
            for model, shape, config in zip(models, shapes, configs):
                self._test_step_conv2d_on_device(
                    model, shape, config, torch.device(device)
                )

    def _check_pruner_pruned(self, model, pruner, device):
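        # prune() should have removed the parametrizations, leaving a plain,
        # resized weight tensor behind.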
        for config in pruner.groups:
            module = config["module"]
            assert not hasattr(module, "parametrizations")
            assert not hasattr(module, "mask")

    def _test_linear_on_device(
        self, model, config, expected_shape, device, also_prune_bias
    ):
        model = model.to(device)
        model.eval()
        num_original_params = sum(p.numel() for p in model.parameters())
        x = torch.ones(128, 7, device=device)

        pruner = ImplementedPruner({"prune_bias": also_prune_bias})
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()

        y_expected = model(x)

        assert y_expected.shape == (128, 10)
        self._check_pruner_prepared(model, pruner, device)

        # Pruning step
        pruned = pruner.prune()
        y_pruned = pruned(x)
        num_pruned_params = sum(p.numel() for p in pruned.parameters())

        assert y_pruned.shape == expected_shape
        self._check_pruner_pruned(model, pruner, device)
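        # If the final layer was not pruned, the output shape is unchanged and
        # pruning only removes channels that were already masked to zero, so
        # both outputs should agree within tolerance.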
        if y_pruned.shape == y_expected.shape:
            assert torch.isclose(y_expected, y_pruned, rtol=1e-05, atol=1e-07).all()
            assert num_pruned_params < num_original_params

    def test_prune_linear_linear(self):
        r"""test pruning linear -> linear modules"""
        configs, shapes = [], []
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))

        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "linear1.weight"},
            ]
        )
        shapes.append((128, 10))

        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))
        for device in DEVICES:
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_linear_on_device(
                        SimpleLinear(),
                        config,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_linear_bias_linear(self):
        # linear(bias) -> linear(no bias)
        configs, shapes = [], []
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
            ]
        )
        shapes.append((128, 10))

        # linear(bias) -> linear(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "seq.3.weight"},
            ]
        )
        shapes.append((128, 10))

        # linear(no bias) -> linear(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))

        for device in DEVICES:
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_linear_on_device(
                        LinearBias(),
                        config,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_linear_activation_linear(self):
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.2.weight"},
            {"tensor_fqn": "seq.4.weight"},
            {"tensor_fqn": "linear1.weight"},
        ]
        shape = (128, 10)

        for device in DEVICES:
            for also_prune_bias in [True, False]:
                # test version with nn.Modules
                self._test_linear_on_device(
                    LinearActivation(),
                    config,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )
                # test functional version
                self._test_linear_on_device(
                    LinearActivationFunctional(),
                    config,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    def _test_conv2d_on_device(
        self, model, config, x, expected_shape, device, also_prune_bias
    ):
        model = model.to(device)
        num_original_params = sum(p.numel() for p in model.parameters())
        model.eval()

        pruner = ImplementedPruner({"prune_bias": also_prune_bias})
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()

        y_expected = model(x)
        assert y_expected.shape == expected_shape

        self._check_pruner_prepared(model, pruner, device)

        # Pruning step
        pruned = pruner.prune()
        y_pruned = pruned(x)
        num_pruned_params = sum(p.numel() for p in pruned.parameters())

        assert y_pruned.shape == expected_shape
        self._check_pruner_pruned(model, pruner, device)
        if y_pruned.shape == y_expected.shape:
            # TODO This rtol is a little high, need to double check if something specific is causing this to fail
            assert torch.isclose(
                y_expected,
                y_pruned,
                rtol=1e-3,
                atol=1e-3,
            ).all(), f"fail for {type(model)}"
            # the only time the counts should be equal is when all layers have
            # padding and we can't prune anything
            assert num_pruned_params <= num_original_params

    def test_prune_conv2d_conv2d(self):
        configs, shapes = [], []
        # all within sequential blocks
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
            ]
        )
        shapes.append((1, 52, 20, 20))
        # prune across sequential blocks
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 20, 20))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        SimpleConv2d(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_bias_conv2d(self):
        # Conv2d with Bias and no Activation
        configs, shapes = [], []
        # conv2d(bias) -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(no bias) -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dBias(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_activation_conv2d(self):
        # Conv2d with Activation and no Bias
        configs, shapes = [], []

        # conv2d(no bias) -> activation -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> activation -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> activation -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(no bias) -> activation -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dActivation(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_padding_conv2d(self):
        # Conv2d with Padded layers after Bias layers
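        # With padding, a pruned channel's bias still contributes to the padded
        # border of the next layer's input, so such channels cannot simply be
        # removed; num_pruned_params may then equal num_original_params.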
        configs, shapes = [], []

        # conv(padded, bias) -> conv(padded, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        # conv(no bias, no pad) -> conv(padded, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        # conv(padded, bias) -> conv(no bias, no pad)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))
        # conv(pad, bias) -> conv(no pad, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.6.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))
        # conv(no pad, bias) -> conv(pad, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.8.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dPadBias(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_pool_conv2d(self):
        # Conv2d with Pooling layers
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.3.weight"},
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight"},
        ]
        shape = (1, 52, 3, 3)

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                self._test_conv2d_on_device(
                    Conv2dPool(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_complex_conv2d(self):
        """Test fusion for models that contain Conv2d & Linear modules.
        Currently supports: Conv2d-Pool2d-Flatten-Linear, Skip-add"""
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.3.weight"},
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight"},
        ]
        shape = (1, 13)

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                self._test_conv2d_on_device(
                    Conv2dPoolFlattenFunctional(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )
                self._test_conv2d_on_device(
                    Conv2dPoolFlatten(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    def test_prune_lstm_linear_multiple_layer(self):
        """
        Test fusion support for LSTM(multi-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=2,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
            {"tensor_fqn": "lstm.weight_ih_l1"},
            {"tensor_fqn": "lstm.weight_hh_l1"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)

        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
        _, _ = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        _, _ = pruned_model(lstm_input)

        expected_params = dict(model.named_parameters())
        for name, param in model.named_parameters():
            assert name in expected_params
            # We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
            # Instead we check that the weights of the new LSTM are a subset of the weights of
            # the old LSTM
            assert rows_are_subset(param, expected_params[name])
            del expected_params[name]

        # assert we haven't deleted any keys
        assert len(expected_params) == 0

    def test_prune_lstm_linear_single_layer(self):
        """
        Test fusion support for LSTM (single-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=1,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()
        model.eval()

        out_expected, lstm_out_expected = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        out_pruned, lstm_out_pruned = pruned_model(lstm_input)
        r, c = lstm_out_expected.size()

        # We cannot check that y_expected == y_pruned as usual because
        # zeros vs. missing elements yield different numerical results.
        # Instead we check that the pruned model's LSTM output matches the
        # first half of the original output, since BottomHalfLSTMPruner keeps
        # the top half of each gate's hidden units.
        assert torch.isclose(
            lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
        ).all()
        # also check that the linear output has the same shape; this means we've
        # resized the linear's input columns correctly.
        assert out_expected.shape == out_pruned.shape

    def test_prune_lstm_layernorm_linear_multiple_layer(self):
        """
        Test fusion support for LSTM(multi-layer) -> LayerNorm -> Linear
        """
        model = LSTMLayerNormLinearModel(
            input_dim=8,
            output_dim=8,
            hidden_dim=8,
            num_layers=2,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
            {"tensor_fqn": "lstm.weight_ih_l1"},
            {"tensor_fqn": "lstm.weight_hh_l1"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)

        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
        _, _ = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        _, _ = pruned_model(lstm_input)

        expected_params = dict(model.named_parameters())
        for name, param in model.named_parameters():
            assert name in expected_params
            # We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
            # Instead we check that the weights of the new LSTM are a subset of the weights of
            # the old LSTM
            assert rows_are_subset(param, expected_params[name])
            del expected_params[name]

        # assert we haven't deleted any keys
        assert len(expected_params) == 0

    def test_prune_lstm_layernorm_linear_single_layer(self):
        """
        Test fusion support for LSTM (single-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=1,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()
        model.eval()

        out_expected, lstm_out_expected = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        out_pruned, lstm_out_pruned = pruned_model(lstm_input)
        r, c = lstm_out_expected.size()

        # We cannot check that y_expected == y_pruned as usual because
        # zeros vs. missing elements yield different numerical results.
        # Instead we check that the pruned model's LSTM output matches the
        # first half of the original output, since BottomHalfLSTMPruner keeps
        # the top half of each gate's hidden units.
        assert torch.isclose(
            lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
        ).all()
        # also check that the linear output has the same shape; this means we've
        # resized the linear's input columns correctly.
        assert out_expected.shape == out_pruned.shape


class TestFPGMPruner(TestCase):
    """
    Test case for the implementation of paper:
    `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
    """

    class SimpleConvFPGM(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv2d1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False)
            # Manually set the filter weights for demonstration purposes
            """
            Three filters' weights are manually set to values 3.0, 2.0, and 0.1.
            Different from a norm-based decision, which would prune the filter with value 0.1,
            FPGM will prune the one with value 2.0.
            """
            weights = torch.tensor([3.0, 2.0, 0.1])  # weights for each filter
            weights = weights[:, None, None, None]  # broadcasting
            self.conv2d1.weight.data.copy_(torch.ones(self.conv2d1.weight.shape) * weights)

            # Second Convolutional Layer
            self.conv2d2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False)
            weights = torch.tensor([6.0, 7.0, 0.4, 0.5])
            weights = weights[:, None, None, None]
            self.conv2d2.weight.data.copy_(torch.ones(self.conv2d2.weight.shape) * weights)

        def forward(self, x):
            x = self.conv2d1(x)
            x = self.conv2d2(x)
            return x

    def test_compute_distance(self, device="cpu"):
        """Test the distance computation function"""
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        pruner = FPGMPruner(0.3)
        dist_conv1 = pruner._compute_distance(model.conv2d1.weight)

        # compute the distance matrix using torch.cdist
        flattened_filters = torch.Tensor([
            [3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],
            [2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000],
            [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000]
        ])

        """
        The expected distance matrix should have the following values:
            [0.0000, 3.0000, 8.7000],
            [3.0000, 0.0000, 5.7000],
            [8.7000, 5.7000, 0.0000],
        so the per-filter distance sums should be:
            [11.7000, 8.7000, 14.4000]
        """
        expected_dist_matrix_conv1 = torch.cdist(flattened_filters, flattened_filters, p=2)
        expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1)
        assert torch.isclose(dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07).all()

    def _test_update_mask_on_single_layer(self, expected_conv1, device):
        """Test that pruning is based on the pairwise distance measurement rather than the absolute norm value"""
        # test pruning with one layer of conv2d
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        x = torch.ones((1, 1, 32, 32), device=device)
        pruner = FPGMPruner(0.3)
        config = [{"tensor_fqn": "conv2d1.weight"}]
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()
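        # With sparsity_level=0.3, one of the three filters is pruned; FPGM
        # removes the 2.0 filter (closest to the geometric median) rather than
        # the smallest-norm 0.1 filter.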
        assert pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False, \
            "do not prune the least-norm filter"

        # pruning step
        pruned_model = pruner.prune()

        pruned_y = pruned_model(x)
        # assert shapes
        expected_conv1 = expected_conv1.to(device)
        assert pruned_y.shape == (1, 4, 32, 32)
        assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
        assert pruned_model.conv2d2.weight.shape == (4, 2, 3, 3), "conv2d2 should have input channels pruned"
        # assert values
        assert torch.isclose(pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07).all()

    def _test_update_mask_on_multiple_layer(self, expected_conv1, expected_conv2, device):
        # the second setting
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        x = torch.ones((1, 1, 32, 32), device=device)
        pruner = FPGMPruner(0.3)
        config = [
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5},
        ]
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()
        # Get the masks for the two least-norm filters
        mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
        mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
        # Check that at least one of the least-norm filters is not pruned
        assert mask1.item() is not False or mask2.item() is not False, "Do not prune all least-norm filters"

        # pruning step
        pruned_model = pruner.prune()
        pruned_y = pruned_model(x)
        # assert shapes
        expected_conv1 = expected_conv1.to(device)
        expected_conv2 = expected_conv2.to(device)
        assert pruned_y.shape == (1, 2, 32, 32)
        assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
        assert pruned_model.conv2d2.weight.shape == expected_conv2.shape
        # assert values
        assert torch.isclose(pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07).all()
        assert torch.isclose(pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07).all()

    def test_update_mask(self):
        weights = torch.tensor([3.0, 0.1])
        expected_conv1 = torch.ones((2, 1, 3, 3)) * weights[:, None, None, None]

        weights = torch.tensor([7.0, 0.4])
        expected_conv2 = torch.ones((2, 2, 3, 3)) * weights[:, None, None, None]

        for device in DEVICES:
            self._test_update_mask_on_single_layer(expected_conv1, device)
            self._test_update_mask_on_multiple_layer(expected_conv1, expected_conv2, device)