1
# Owner(s): ["module: unknown"]
8
from torch.ao.pruning._experimental.pruner import (
11
BaseStructuredSparsifier,
12
FakeStructuredSparsity,
15
from torch.nn.utils import parametrize
17
from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
18
from torch.testing._internal.common_pruning import (
22
LinearActivationFunctional,
29
Conv2dPoolFlattenFunctional,
31
LSTMLayerNormLinearModel,
37
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
42
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
46
class SimplePruner(BaseStructuredSparsifier):
47
def update_mask(self, module, tensor_name, **kwargs):
48
getattr(module.parametrizations, tensor_name)[0].mask[1] = False
51
class ImplementedPruner(BaseStructuredSparsifier):
52
def update_mask(self, module, tensor_name, **kwargs):
53
"""Prunes 1/3 of the weight output channels, so resulting module has 33.3% pruning"""
54
num_rows = len(module.parametrizations[tensor_name][0].mask)
55
prune = random.sample(list(range(num_rows)), num_rows // 3)
56
module.parametrizations[tensor_name][0].mask[prune] = False
59
class BottomHalfLSTMPruner(BaseStructuredSparsifier):
61
Pruner that will remove the bottom half of the rows.
62
This is primarily meant for testing purposes
65
def update_mask(self, module, tensor_name, **kwargs):
66
for p in getattr(module.parametrizations, tensor_name):
67
if isinstance(p, FakeStructuredSparsity):
69
masks = torch.split(mask, len(mask) // 4)
72
small[num // 2 :] = False
73
new_mask = torch.cat(masks)
74
mask.data = new_mask.data
76
class TestSaliencyPruner(TestCase):
77
def test_saliency_pruner_update_mask(self):
78
"""Test that we prune out the row with the lowest saliency (first row)"""
79
model = SimpleLinear()
81
model.linear1.weight = nn.Parameter(
82
torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])
84
pruning_config = [{"tensor_fqn": "linear1.weight", "sparsity_level": 0.5}]
85
pruner = SaliencyPruner({})
87
pruner.prepare(model, pruning_config)
88
pruner.enable_mask_update = True
90
pruned_model = pruner.prune()
92
expected = torch.Tensor([[3, 3, 3, 3], [4, 4, 4, 4]])
93
pruned = pruned_model.linear1.weight
95
assert expected.shape == pruned.shape
96
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
98
def test_lstm_saliency_pruner_update_mask(self):
99
model = LSTMLinearModel(
106
manual_weights = torch.Tensor([[1, 1],
115
with torch.no_grad():
116
model.lstm.weight_ih_l0 = nn.Parameter(manual_weights)
117
model.lstm.weight_hh_l0 = nn.Parameter(torch.Tensor(manual_weights))
118
model.lstm.bias_ih_l0 = nn.Parameter(manual_weights[:, 0])
119
model.lstm.bias_hh_l0 = nn.Parameter(manual_weights[:, 0])
122
{"tensor_fqn": "lstm.weight_ih_l0"},
123
{"tensor_fqn": "lstm.weight_hh_l0"},
125
lstm_input = torch.ones((1, 2))
126
fx_pruner = LSTMSaliencyPruner({"sparsity_level": 0.5})
127
fx_pruner.prepare(model, config)
128
fx_pruner.enable_mask_update = True
132
pruned_model = fx_pruner.prune()
135
# make sure both models run
137
pruned_model(lstm_input)
139
# make sure lowest saliency rows are pruned
140
expected = torch.Tensor([[2, 2],
144
pruned = model.lstm.weight_ih_l0
145
assert expected.shape == pruned.shape
146
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
148
expected = torch.Tensor([[2],
152
pruned = model.lstm.weight_hh_l0
153
assert expected.shape == pruned.shape
154
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
156
expected = torch.Tensor([2, 2, -2, -2])
157
for pruned in [model.lstm.bias_ih_l0, model.lstm.bias_hh_l0]:
158
assert expected.shape == pruned.shape
159
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
163
class TestBaseStructuredSparsifier(TestCase):
164
def _check_pruner_prepared(self, model, pruner, device):
165
for config in pruner.groups:
166
module = config["module"]
167
assert module.weight.device.type == device.type
169
assert config["tensor_fqn"] in pruner.state
170
# Check parametrization exists and is correct
171
assert parametrize.is_parametrized(module)
172
assert hasattr(module, "parametrizations")
173
# Assume that this is the 1st/only parametrization
174
assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity
176
def _check_pruner_valid_before_step(self, model, pruner, device):
177
for config in pruner.groups:
179
if type(config["module"]) is tuple:
180
modules.extend(config["module"])
182
module = config["module"]
183
modules.append(module)
184
for module in modules:
185
assert module.weight.device.type == device.type
186
assert module.parametrizations.weight[0].mask.dtype == torch.bool
188
def _check_pruner_valid_after_step(self, model, pruner, mask, device):
189
for config in pruner.groups:
191
if type(config["module"]) is tuple:
192
modules.extend(config["module"])
194
module = config["module"]
195
modules.append(module)
196
for module in modules:
197
assert module.weight.device.type == device.type
198
total = module.parametrizations.weight[0].mask.numel()
200
module.parametrizations.weight[0].mask.count_nonzero()
204
def _test_constructor_on_device(self, model, device):
205
self.assertRaisesRegex(
207
"BaseStructuredSparsifier.*update_mask",
208
BaseStructuredSparsifier,
210
model1 = copy.deepcopy(model).to(device)
211
pruner = SimplePruner(None)
212
pruner.prepare(model1, None)
213
pruner.enable_mask_update = True
214
for g in pruner.groups:
216
assert module.weight.device.type == device.type
217
assert len(pruner.groups) == 5
219
# Can instantiate the model with configs
220
model2 = copy.deepcopy(model).to(device)
221
pruner = SimplePruner({"test": 3})
222
pruner.prepare(model2, [{"tensor_fqn": "seq.0.weight"}])
223
assert len(pruner.groups) == 1
224
assert pruner.groups[0]["module_fqn"] == "seq.0"
225
assert "test" in pruner.groups[0]
226
assert pruner.groups[0]["test"] == 3
228
def test_constructor(self):
229
model = SimpleLinear()
230
for device in DEVICES:
231
self._test_constructor_on_device(model, torch.device(device))
233
def _test_prepare_linear_on_device(self, model, device):
234
model = copy.deepcopy(model).to(device)
235
x = torch.ones(128, 7, device=device)
236
pruner = SimplePruner(None)
237
pruner.prepare(model, None)
238
self._check_pruner_prepared(model, pruner, device)
239
assert model(x).shape == (128, 10)
241
def test_prepare_linear(self):
246
LinearActivationFunctional(),
247
] # without and with bias
248
for device in DEVICES:
250
self._test_prepare_linear_on_device(model, torch.device(device))
252
def _test_prepare_conv2d_on_device(self, model, expected_shape, config, device):
253
x = torch.ones((1, 1, 28, 28), device=device)
254
pruner = SimplePruner(None)
255
pruner.prepare(model, config)
256
self._check_pruner_prepared(model, pruner, device)
257
assert model(x).shape == expected_shape
259
def test_prepare_conv2d(self):
274
configs = [None, None, None, None, None]
275
for device in DEVICES:
276
for model, shape, config in zip(models, shapes, configs):
277
model = model.to(device)
278
self._test_prepare_conv2d_on_device(
279
model, shape, config, torch.device(device)
282
def _test_step_linear_on_device(self, model, device):
283
model = model.to(device)
284
x = torch.ones(7, 7, device=device)
285
pruner = SimplePruner(None)
286
pruner.prepare(model, None)
287
pruner.enable_mask_update = True
288
self._check_pruner_valid_before_step(model, pruner, device)
290
self._check_pruner_valid_after_step(model, pruner, 1, device)
292
def test_step_linear(self):
297
LinearActivationFunctional(),
299
for device in DEVICES:
301
self._test_step_linear_on_device(model, torch.device(device))
303
def _test_step_conv2d_on_device(self, model, expected_shape, config, device):
304
model = model.to(device)
305
x = torch.ones((1, 1, 28, 28), device=device)
306
pruner = SimplePruner(None)
307
pruner.prepare(model, config)
308
pruner.enable_mask_update = True
309
self._check_pruner_valid_before_step(model, pruner, device)
311
self._check_pruner_valid_after_step(model, pruner, 1, device)
312
assert model(x).shape == expected_shape
314
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
315
def test_step_conv2d(self):
330
configs = [None, None, None, None, None]
331
for device in DEVICES:
332
for model, shape, config in zip(models, shapes, configs):
333
self._test_step_conv2d_on_device(
334
model, shape, config, torch.device(device)
337
def _check_pruner_pruned(self, model, pruner, device):
338
for config in pruner.groups:
339
module = config["module"]
340
assert not hasattr(module, "parametrizations")
341
assert not hasattr(module, "mask")
343
def _test_linear_on_device(
344
self, model, config, expected_shape, device, also_prune_bias
346
model = model.to(device)
348
num_original_params = sum(p.numel() for p in model.parameters())
349
x = torch.ones(128, 7, device=device)
351
pruner = ImplementedPruner({"prune_bias": also_prune_bias})
352
pruner.prepare(model, config)
353
pruner.enable_mask_update = True
356
y_expected = model(x)
358
assert y_expected.shape == (128, 10)
359
self._check_pruner_prepared(model, pruner, device)
362
pruned = pruner.prune()
364
num_pruned_params = sum(p.numel() for p in pruned.parameters())
366
assert y_pruned.shape == expected_shape
367
self._check_pruner_pruned(model, pruner, device)
368
if y_pruned.shape == y_expected.shape:
369
assert torch.isclose(y_expected, y_pruned, rtol=1e-05, atol=1e-07).all()
370
assert num_pruned_params < num_original_params
372
def test_prune_linear_linear(self):
373
r"""test pruning linear-> linear modules"""
374
configs, shapes = [], []
377
{"tensor_fqn": "seq.0.weight"},
378
{"tensor_fqn": "seq.1.weight"},
379
{"tensor_fqn": "seq.2.weight"},
382
shapes.append((128, 10))
386
{"tensor_fqn": "seq.0.weight"},
387
{"tensor_fqn": "seq.1.weight"},
388
{"tensor_fqn": "seq.2.weight"},
389
{"tensor_fqn": "linear1.weight"},
392
shapes.append((128, 10))
396
{"tensor_fqn": "seq.0.weight"},
397
{"tensor_fqn": "seq.2.weight"},
400
shapes.append((128, 10))
401
for device in DEVICES:
402
for also_prune_bias in [True, False]:
403
for config, shape in zip(configs, shapes):
404
self._test_linear_on_device(
408
torch.device(device),
412
def test_prune_linear_bias_linear(self):
413
# linear(bias) -> linear(no bias)
414
configs, shapes = [], []
417
{"tensor_fqn": "seq.0.weight"},
418
{"tensor_fqn": "seq.1.weight"},
421
shapes.append((128, 10))
423
# linear(bias) -> linear(bias)
426
{"tensor_fqn": "seq.2.weight"},
427
{"tensor_fqn": "seq.3.weight"},
430
shapes.append((128, 10))
432
# linear(no bias) -> linear(bias)
435
{"tensor_fqn": "seq.0.weight"},
436
{"tensor_fqn": "seq.1.weight"},
437
{"tensor_fqn": "seq.2.weight"},
440
shapes.append((128, 10))
442
for device in DEVICES:
443
for also_prune_bias in [True, False]:
444
for config, shape in zip(configs, shapes):
445
self._test_linear_on_device(
449
torch.device(device),
453
def test_prune_linear_activation_linear(self):
455
{"tensor_fqn": "seq.0.weight"},
456
{"tensor_fqn": "seq.2.weight"},
457
{"tensor_fqn": "seq.4.weight"},
458
{"tensor_fqn": "linear1.weight"},
462
for device in DEVICES:
463
for also_prune_bias in [True, False]:
464
# test version with nn.Modules
465
self._test_linear_on_device(
469
torch.device(device),
472
# test functional version
473
self._test_linear_on_device(
474
LinearActivationFunctional(),
477
torch.device(device),
481
def _test_conv2d_on_device(
482
self, model, config, x, expected_shape, device, also_prune_bias
484
model = model.to(device)
485
num_original_params = sum(p.numel() for p in model.parameters())
488
pruner = ImplementedPruner({"prune_bias": also_prune_bias})
489
pruner.prepare(model, config)
490
pruner.enable_mask_update = True
493
y_expected = model(x)
494
assert y_expected.shape == expected_shape
496
self._check_pruner_prepared(model, pruner, device)
499
pruned = pruner.prune()
501
num_pruned_params = sum(p.numel() for p in pruned.parameters())
503
assert y_pruned.shape == expected_shape
504
self._check_pruner_pruned(model, pruner, device)
505
if y_pruned.shape == y_expected.shape:
506
# TODO This rtol is a little high, need to double check if something specific is causing this to fail
507
assert torch.isclose(
512
).all(), f"fail for {type(model)}"
513
# only time this should be equal is when all layers have padding and we can't prune
514
assert num_pruned_params <= num_original_params
516
def test_prune_conv2d_conv2d(self):
517
configs, shapes = [], []
518
# all within sequential blocks
521
{"tensor_fqn": "seq.0.weight"},
524
shapes.append((1, 52, 20, 20))
525
# prune across sequential blocks
528
{"tensor_fqn": "seq.0.weight"},
529
{"tensor_fqn": "seq.1.weight"},
530
{"tensor_fqn": "conv2d1.weight"},
533
shapes.append((1, 52, 20, 20))
535
for device in DEVICES:
536
x = torch.ones((1, 1, 28, 28), device=device)
537
for also_prune_bias in [True, False]:
538
for config, shape in zip(configs, shapes):
539
self._test_conv2d_on_device(
544
torch.device(device),
548
def test_prune_conv2d_bias_conv2d(self):
549
# Conv2d with Bias and no Activation
550
configs, shapes = [], []
551
# conv2d(bias) -> conv2d(bias)
554
{"tensor_fqn": "seq.0.weight"},
555
{"tensor_fqn": "seq.1.weight"},
558
shapes.append((1, 52, 18, 18))
560
# conv2d(no bias) -> conv2d(bias)
563
{"tensor_fqn": "seq.0.weight"},
564
{"tensor_fqn": "seq.1.weight"},
565
{"tensor_fqn": "conv2d1.weight"},
568
shapes.append((1, 52, 18, 18))
570
# conv2d(bias) -> conv2d(no bias)
573
{"tensor_fqn": "seq.0.weight"},
574
{"tensor_fqn": "seq.1.weight"},
575
{"tensor_fqn": "seq.2.weight"},
578
shapes.append((1, 52, 18, 18))
580
for device in DEVICES:
581
x = torch.ones((1, 1, 28, 28), device=device)
582
for also_prune_bias in [True, False]:
583
for config, shape in zip(configs, shapes):
584
self._test_conv2d_on_device(
589
torch.device(device),
593
def test_prune_conv2d_activation_conv2d(self):
594
# Conv2d with Activation and no Bias
595
configs, shapes = [], []
597
# conv2d(no bias) -> activation -> conv2d(no bias)
600
{"tensor_fqn": "seq.4.weight"},
603
shapes.append((1, 52, 18, 18))
605
# conv2d(bias) -> activation -> conv2d(bias)
608
{"tensor_fqn": "seq.0.weight"},
609
{"tensor_fqn": "seq.2.weight"},
612
shapes.append((1, 52, 18, 18))
614
# conv2d(bias) -> activation -> conv2d(no bias)
617
{"tensor_fqn": "seq.2.weight"},
618
{"tensor_fqn": "seq.4.weight"},
621
shapes.append((1, 52, 18, 18))
623
# conv2d(no bias) -> activation -> conv2d(bias)
626
{"tensor_fqn": "conv2d1.weight"},
629
shapes.append((1, 52, 18, 18))
631
for device in DEVICES:
632
x = torch.ones((1, 1, 28, 28), device=device)
633
for also_prune_bias in [True, False]:
634
for config, shape in zip(configs, shapes):
635
self._test_conv2d_on_device(
640
torch.device(device),
644
def test_prune_conv2d_padding_conv2d(self):
645
# Conv2d with Padded layers after Bias layers
646
configs, shapes = [], []
648
# conv(padded, bias) -> conv(padded, bias)
651
{"tensor_fqn": "seq.4.weight"},
654
shapes.append((1, 52, 24, 24))
656
# conv(no bias, no pad) -> conv(padded, bias)
659
{"tensor_fqn": "seq.2.weight"},
662
shapes.append((1, 52, 24, 24))
664
# conv(padded, bias) -> conv ( no bias ,no pad)
667
{"tensor_fqn": "seq.0.weight"},
670
shapes.append((1, 52, 24, 24))
671
# conv(pad, bias) -> conv(no pad, bias)
674
{"tensor_fqn": "seq.6.weight"},
677
shapes.append((1, 52, 24, 24))
678
# conv(no pad, bias) -> conv(pad, bias)
681
{"tensor_fqn": "seq.8.weight"},
684
shapes.append((1, 52, 24, 24))
686
for device in DEVICES:
687
x = torch.ones((1, 1, 28, 28), device=device)
688
for also_prune_bias in [True, False]:
689
for config, shape in zip(configs, shapes):
690
self._test_conv2d_on_device(
695
torch.device(device),
699
def test_prune_conv2d_pool_conv2d(self):
700
# Conv2d with Pooling layers
702
{"tensor_fqn": "seq.0.weight"},
703
{"tensor_fqn": "seq.3.weight"},
704
{"tensor_fqn": "conv2d1.weight"},
705
{"tensor_fqn": "conv2d2.weight"},
707
shape = (1, 52, 3, 3)
709
for device in DEVICES:
710
x = torch.ones((1, 1, 28, 28), device=device)
711
for also_prune_bias in [True, False]:
712
self._test_conv2d_on_device(
717
torch.device(device),
721
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
722
def test_complex_conv2d(self):
723
"""Test fusion for models that contain Conv2d & Linear modules.
724
Currently supports: Conv2d-Pool2d-Flatten-Linear, Skip-add"""
726
{"tensor_fqn": "seq.0.weight"},
727
{"tensor_fqn": "seq.3.weight"},
728
{"tensor_fqn": "conv2d1.weight"},
729
{"tensor_fqn": "conv2d2.weight"},
733
for device in DEVICES:
734
x = torch.ones((1, 1, 28, 28), device=device)
735
for also_prune_bias in [True, False]:
736
self._test_conv2d_on_device(
737
Conv2dPoolFlattenFunctional(),
741
torch.device(device),
744
self._test_conv2d_on_device(
749
torch.device(device),
753
def test_prune_lstm_linear_multiple_layer(self):
755
Test fusion support for LSTM(multi-layer) -> Linear
757
model = LSTMLinearModel(
765
{"tensor_fqn": "lstm.weight_ih_l0"},
766
{"tensor_fqn": "lstm.weight_hh_l0"},
767
{"tensor_fqn": "lstm.weight_ih_l1"},
768
{"tensor_fqn": "lstm.weight_hh_l1"},
771
lstm_input = torch.ones((1, 8))
772
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
773
fx_pruner.prepare(model, config)
775
fx_pruner.enable_mask_update = True
779
_, _ = model(lstm_input)
780
pruned_model = fx_pruner.prune()
782
_, _ = pruned_model(lstm_input)
784
expected_params = dict(model.named_parameters())
785
for name, param in model.named_parameters():
786
assert name in expected_params
787
# We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
788
# Instead we check that the weights of the new LSTM are a subset of the weights of
790
assert rows_are_subset(param, expected_params[name])
791
del expected_params[name]
793
# assert we haven't deleted any keys
794
assert len(expected_params) == 0
796
def test_prune_lstm_linear_single_layer(self):
798
Test fusion support for LSTM (single-layer) -> Linear
800
model = LSTMLinearModel(
808
{"tensor_fqn": "lstm.weight_ih_l0"},
809
{"tensor_fqn": "lstm.weight_hh_l0"},
812
lstm_input = torch.ones((1, 8))
813
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
814
fx_pruner.prepare(model, config)
815
fx_pruner.enable_mask_update = True
819
out_expected, lstm_out_expected = model(lstm_input)
820
pruned_model = fx_pruner.prune()
822
out_pruned, lstm_out_pruned = pruned_model(lstm_input)
823
r, c = lstm_out_expected.size()
825
# We cannot check that y_expected == y_pruned as usual because
826
# zeros vs. missing elements yield different numerical results.
827
# Instead that we check that the pruned elements are the first half of the results
828
# since we are using a BottomHalfLSTMPruner
829
assert torch.isclose(
830
lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
832
# also check that output of linear is the same shape, this means we've resized
833
# linear columns correctly.
834
assert out_expected.shape == out_pruned.shape
836
def test_prune_lstm_layernorm_linear_multiple_layer(self):
838
Test fusion support for LSTM(multi-layer) -> Linear
840
model = LSTMLayerNormLinearModel(
848
{"tensor_fqn": "lstm.weight_ih_l0"},
849
{"tensor_fqn": "lstm.weight_hh_l0"},
850
{"tensor_fqn": "lstm.weight_ih_l1"},
851
{"tensor_fqn": "lstm.weight_hh_l1"},
854
lstm_input = torch.ones((1, 8))
855
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
856
fx_pruner.prepare(model, config)
858
fx_pruner.enable_mask_update = True
862
_, _ = model(lstm_input)
863
pruned_model = fx_pruner.prune()
865
_, _ = pruned_model(lstm_input)
867
expected_params = dict(model.named_parameters())
868
for name, param in model.named_parameters():
869
assert name in expected_params
870
# We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
871
# Instead we check that the weights of the new LSTM are a subset of the weights of
873
assert rows_are_subset(param, expected_params[name])
874
del expected_params[name]
876
# assert we haven't deleted any keys
877
assert len(expected_params) == 0
879
def test_prune_lstm_layernorm_linear_single_layer(self):
881
Test fusion support for LSTM (single-layer) -> Linear
883
model = LSTMLinearModel(
891
{"tensor_fqn": "lstm.weight_ih_l0"},
892
{"tensor_fqn": "lstm.weight_hh_l0"},
895
lstm_input = torch.ones((1, 8))
896
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
897
fx_pruner.prepare(model, config)
898
fx_pruner.enable_mask_update = True
902
out_expected, lstm_out_expected = model(lstm_input)
903
pruned_model = fx_pruner.prune()
905
out_pruned, lstm_out_pruned = pruned_model(lstm_input)
906
r, c = lstm_out_expected.size()
908
# We cannot check that y_expected == y_pruned as usual because
909
# zeros vs. missing elements yield different numerical results.
910
# Instead that we check that the pruned elements are the first half of the results
911
# since we are using a BottomHalfLSTMPruner
912
assert torch.isclose(
913
lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
915
# also check that output of linear is the same shape, this means we've resized
916
# linear columns correctly.
917
assert out_expected.shape == out_pruned.shape
919
class TestFPGMPruner(TestCase):
921
Test case for the implementation of paper:
922
`Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
924
class SimpleConvFPGM(nn.Module):
927
self.conv2d1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False)
928
# Manually set the filter weights for demonstration purposes
930
Three filters' weight are manually set to values 3.0, 2.0, and 0.1.
931
Different from the norm-based decision that prunes filter with value 0.1,
932
FPGM will prune the one with value 2.0.
934
weights = torch.tensor([3.0, 2.0, 0.1]) # Weight weights for each filter
935
weights = weights[:, None, None, None] # broadcasting
936
self.conv2d1.weight.data.copy_(torch.ones(self.conv2d1.weight.shape) * weights)
938
# Second Convolutional Layer
939
self.conv2d2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False)
940
weights = torch.tensor([6.0, 7.0, 0.4, 0.5])
941
weights = weights[:, None, None, None]
942
self.conv2d2.weight.data.copy_(torch.ones(self.conv2d2.weight.shape) * weights)
944
def forward(self, x):
949
def test_compute_distance(self, device="cpu"):
950
"""Test the distance computation function"""
951
model = TestFPGMPruner.SimpleConvFPGM().to(device)
952
pruner = FPGMPruner(0.3)
953
dist_conv1 = pruner._compute_distance(model.conv2d1.weight)
955
# compute the distance matrix using torch.cdist
956
flattened_filters = torch.Tensor([
957
[3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],
958
[2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000],
959
[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000]
963
Expected distance matrix should have the following values:
964
[0.0000, 3.0000, 8.7000],
965
[3.0000, 0.0000, 5.7000],
966
[8.7000, 5.7000, 0.0000],
967
the distance should therefore be:
968
[11.7000, 8.7000, 14.4000]
970
expected_dist_matrix_conv1 = torch.cdist(flattened_filters, flattened_filters, p=2)
971
expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1)
972
assert torch.isclose(dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07).all()
974
def _test_update_mask_on_single_layer(self, expected_conv1, device):
975
"""Test that pruning is conducted based on the pair-wise distance measurement instead of absolute norm value"""
976
# test pruning with one layer of conv2d
977
model = TestFPGMPruner.SimpleConvFPGM().to(device)
978
x = torch.ones((1, 1, 32, 32), device=device)
979
pruner = FPGMPruner(0.3)
980
config = [{"tensor_fqn": "conv2d1.weight"}]
981
pruner.prepare(model, config)
982
pruner.enable_mask_update = True
984
assert pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False, \
985
"do not prune the least-norm filter"
988
pruned_model = pruner.prune()
990
pruned_y = pruned_model(x)
992
expected_conv1 = expected_conv1.to(device)
993
assert pruned_y.shape == (1, 4, 32, 32)
994
assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
995
assert pruned_model.conv2d2.weight.shape == (4, 2, 3, 3), "conv2d2 should have input channel pruned"
997
assert torch.isclose(pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07).all()
999
def _test_update_mask_on_multiple_layer(self, expected_conv1, expected_conv2, device):
1000
# the second setting
1001
model = TestFPGMPruner.SimpleConvFPGM().to(device)
1002
x = torch.ones((1, 1, 32, 32), device=device)
1003
pruner = FPGMPruner(0.3)
1005
{"tensor_fqn": "conv2d1.weight"},
1006
{"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5}
1008
pruner.prepare(model, config)
1009
pruner.enable_mask_update = True
1011
# Get the masks for the two least-norm filters
1012
mask1 = pruner.groups[0]['module'].parametrizations.weight[0].mask[-1]
1013
mask2 = pruner.groups[0]['module'].parametrizations.weight[0].mask[-2]
1014
# Check if either of the least-norm filters is not pruned
1015
assert mask1.item() is not False or mask2.item() is not False, "Do not prune all least-norm filters"
1018
pruned_model = pruner.prune()
1019
pruned_y = pruned_model(x)
1021
expected_conv1 = expected_conv1.to(device)
1022
expected_conv2 = expected_conv2.to(device)
1023
assert pruned_y.shape == (1, 2, 32, 32)
1024
assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
1025
assert pruned_model.conv2d2.weight.shape == expected_conv2.shape
1027
assert torch.isclose(pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07).all()
1028
assert torch.isclose(pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07).all()
1030
def test_update_mask(self):
1031
weights = torch.tensor([3.0, 0.1])
1032
expected_conv1 = torch.ones((2, 1, 3, 3)) * weights[:, None, None, None]
1034
weights = torch.tensor([7.0, 0.4])
1035
expected_conv2 = torch.ones((2, 2, 3, 3)) * weights[:, None, None, None]
1037
for device in DEVICES:
1038
self._test_update_mask_on_single_layer(expected_conv1, device)
1039
self._test_update_mask_on_multiple_layer(expected_conv1, expected_conv2, device)