import pickle
import unittest
import unittest.mock as mock

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TemporaryFileName,
    TEST_NUMPY,
)


class TestPruningNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_validate_pruning_amount_init(self):
        r"""Test the first util function that validates the pruning
        amount requested by the user the moment the pruning method
        is initialized. This test checks that the expected errors are
        raised whenever the amount is invalid.
        The original function runs basic type checking + value range checks.
        It doesn't check the validity of the pruning amount with
        respect to the size of the tensor to prune. That's left to
        `_validate_pruning_amount`, tested below.
        """
        # an amount that is neither float nor int should raise TypeError
        with self.assertRaises(TypeError):
            prune._validate_pruning_amount_init(amount="I'm a string")

        # a float amount outside [0, 1] should raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=1.1)
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=20.0)

        # a negative int amount should raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount_init(amount=-10)

        # all of these should pass without errors because they're valid amounts
        prune._validate_pruning_amount_init(amount=0.34)
        prune._validate_pruning_amount_init(amount=1500)
        prune._validate_pruning_amount_init(amount=0)
        prune._validate_pruning_amount_init(amount=0.0)
        prune._validate_pruning_amount_init(amount=1)
        prune._validate_pruning_amount_init(amount=1.0)

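    # Note on `amount` semantics, as exercised above: a float is interpreted
    # as a fraction of units to prune and must lie in [0, 1]; an int is an
    # absolute number of units and must be non-negative. As a standalone
    # sketch (illustrative only, not executed by this suite):
    #
    #   import torch.nn.utils.prune as prune
    #   prune._validate_pruning_amount_init(amount=0.5)  # ok: 50% of units
    #   prune._validate_pruning_amount_init(amount=3)    # ok: 3 units
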
    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_validate_pruning_amount(self):
        r"""Tests the second util function that validates the pruning
        amount requested by the user, this time with respect to the size
        of the tensor to prune. The rationale is that if the pruning amount,
        converted to an absolute number of units to prune, is larger than
        the number of units in the tensor, then we expect the util function
        to raise a value error.
        """
        # if amount is an int and amount > tensor_size, raise ValueError
        with self.assertRaises(ValueError):
            prune._validate_pruning_amount(amount=20, tensor_size=19)

        # amount is a float, so this should not raise an error
        prune._validate_pruning_amount(amount=0.3, tensor_size=0)

        # these are all fine
        prune._validate_pruning_amount(amount=19, tensor_size=20)
        prune._validate_pruning_amount(amount=0, tensor_size=0)
        prune._validate_pruning_amount(amount=1, tensor_size=1)

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_compute_nparams_to_prune(self):
        r"""Test that the requested pruning `amount` gets translated into the
        correct absolute number of units to prune.
        """
        self.assertEqual(prune._compute_nparams_toprune(amount=0, tensor_size=15), 0)
        self.assertEqual(prune._compute_nparams_toprune(amount=10, tensor_size=15), 10)
        # if amount is an int, it is taken as an absolute number of units
        self.assertEqual(prune._compute_nparams_toprune(amount=1, tensor_size=15), 1)
        # if amount is a float, it is taken as a fraction of the units
        self.assertEqual(prune._compute_nparams_toprune(amount=1.0, tensor_size=15), 15)
        self.assertEqual(prune._compute_nparams_toprune(amount=0.4, tensor_size=17), 7)

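    # The conversion above follows the convention that an integral `amount`
    # is already an absolute count, while a float `amount` is a fraction
    # rounded to the nearest integer: round(0.4 * 17) = round(6.8) = 7,
    # which is why the last assertion expects 7.
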
    def test_random_pruning_sizes(self):
        r"""Test that the new parameters and buffers created by the pruning
        method have the same size as the input tensor to prune. These, in
        fact, correspond to the pruned version of the tensor itself, its
        mask, and its original copy, so the sizes must match.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    original_tensor = getattr(m, name)

                    prune.random_unstructured(m, name=name, amount=0.1)
                    # mask has the same size as the tensor being pruned
                    self.assertEqual(
                        original_tensor.size(), getattr(m, name + "_mask").size()
                    )
                    # 'orig' tensor has the same size as the original tensor
                    self.assertEqual(
                        original_tensor.size(), getattr(m, name + "_orig").size()
                    )
                    # new tensor has the same size as the original tensor
                    self.assertEqual(original_tensor.size(), getattr(m, name).size())

    def test_random_pruning_orig(self):
        r"""Test that the original tensor is correctly stored in 'orig'
        after pruning is applied. Important to make sure we don't
        lose info about the original unpruned parameter.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # tensor prior to pruning
                    original_tensor = getattr(m, name)
                    prune.random_unstructured(m, name=name, amount=0.1)
                    self.assertEqual(original_tensor, getattr(m, name + "_orig"))

    def test_random_pruning_new_weight(self):
        r"""Test that module.name now contains a pruned version of
        the original tensor, obtained by multiplying it by the mask.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # tensor prior to pruning
                    original_tensor = getattr(m, name)
                    prune.random_unstructured(m, name=name, amount=0.1)
                    # weight = weight_orig * weight_mask
                    self.assertEqual(
                        getattr(m, name),
                        getattr(m, name + "_orig")
                        * getattr(m, name + "_mask").to(dtype=original_tensor.dtype),
                    )

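    # The identity checked above is the core of the pruning reparametrization:
    # after `prune.random_unstructured(m, name, amount=0.1)`, `m.weight` is no
    # longer a Parameter but a tensor recomputed on every forward pass as
    # `m.weight_orig * m.weight_mask`, where `weight_orig` is the trainable
    # Parameter and `weight_mask` is a buffer of 0s and 1s.
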
    def test_identity_pruning(self):
        r"""Test that a mask of 1s does not change forward or backward."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)
        y_prepruning = m(input_)

        # compute grad pre-pruning and check it's equal to all ones
        y_prepruning.sum().backward()
        old_grad_weight = m.weight.grad.clone()  # don't grab the pointer!
        self.assertEqual(old_grad_weight, torch.ones_like(m.weight))
        old_grad_bias = m.bias.grad.clone()
        self.assertEqual(old_grad_bias, torch.ones_like(m.bias))

        # remove grads
        m.zero_grad()

        # force the mask to be made of all 1s
        prune.identity(m, name="weight")

        # with a mask of 1s, output should be identical to no mask
        y_postpruning = m(input_)
        self.assertEqual(y_prepruning, y_postpruning)

        # with a mask of 1s, grad should be identical to no mask
        y_postpruning.sum().backward()
        self.assertEqual(old_grad_weight, m.weight_orig.grad)
        self.assertEqual(old_grad_bias, m.bias.grad)

        # calling forward twice in a row shouldn't change the output
        y1 = m(input_)
        y2 = m(input_)
        self.assertEqual(y1, y2)

    def test_random_pruning_0perc(self):
        r"""Test that a mask of 1s does not change forward or backward."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)
        y_prepruning = m(input_)

        # compute grad pre-pruning and check it's equal to all ones
        y_prepruning.sum().backward()
        old_grad_weight = m.weight.grad.clone()  # don't grab the pointer!
        self.assertEqual(old_grad_weight, torch.ones_like(m.weight))
        old_grad_bias = m.bias.grad.clone()
        self.assertEqual(old_grad_bias, torch.ones_like(m.bias))

        # remove grads
        m.zero_grad()

        # force the mask to be made of all 1s
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = torch.ones_like(m.weight)
            prune.random_unstructured(
                m, name="weight", amount=0.9
            )  # amount won't count

        # with a mask of 1s, output should be identical to no mask
        y_postpruning = m(input_)
        self.assertEqual(y_prepruning, y_postpruning)

        # with a mask of 1s, grad should be identical to no mask
        y_postpruning.sum().backward()
        self.assertEqual(old_grad_weight, m.weight_orig.grad)
        self.assertEqual(old_grad_bias, m.bias.grad)

        # calling forward twice in a row shouldn't change the output
        y1 = m(input_)
        y2 = m(input_)
        self.assertEqual(y1, y2)

    def test_random_pruning(self):
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define a custom mask to assign with mock
        mask = torch.ones_like(m.weight)
        mask[1, 0] = 0
        mask[0, 3] = 0

        # check grad is zero for masked weights
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        y_postpruning = m(input_)
        y_postpruning.sum().backward()
        # weight_orig is the parameter, so it's the tensor that accumulates the grad
        self.assertEqual(m.weight_orig.grad, mask)  # all 1s, except for masked units
        self.assertEqual(m.bias.grad, torch.ones_like(m.bias))

        # make sure that the weight update doesn't modify [1, 0] and [0, 3]
        old_weight_orig = m.weight_orig.clone()
        # update weights with a plain SGD step
        learning_rate = 1.0
        for p in m.parameters():
            p.data.sub_(p.grad.data * learning_rate)
        # since these entries are pruned, they should not be updated
        self.assertEqual(old_weight_orig[1, 0], m.weight_orig[1, 0])
        self.assertEqual(old_weight_orig[0, 3], m.weight_orig[0, 3])

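    # Why the masked entries receive zero gradient: since the forward pass
    # uses weight = weight_orig * mask, the chain rule gives
    # d(loss)/d(weight_orig) = d(loss)/d(weight) * mask, so positions where
    # the mask is 0 accumulate zero gradient, and a plain SGD step
    # (p -= lr * p.grad) leaves them unchanged, as asserted above.
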
    def test_random_pruning_forward(self):
        r"""Check forward with mask (by hand)."""
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define a custom mask to assign with mock
        mask = torch.zeros_like(m.weight)
        mask[1, 0] = 1
        mask[0, 3] = 1

        # check forward with mask
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        yhat = m(input_)
        self.assertEqual(yhat[0, 0], m.weight_orig[0, 3] + m.bias[0])
        self.assertEqual(yhat[0, 1], m.weight_orig[1, 0] + m.bias[1])

    def test_remove_pruning_forward(self):
        r"""Remove pruning and check that the forward pass is unchanged from
        the previous pruned state.
        """
        input_ = torch.ones(1, 5)
        m = nn.Linear(5, 2)

        # define a custom mask to assign with mock
        mask = torch.ones_like(m.weight)
        mask[1, 0] = 0
        mask[0, 3] = 0

        # check forward with mask
        with mock.patch(
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
        ) as compute_mask:
            compute_mask.return_value = mask
            prune.random_unstructured(m, name="weight", amount=0.9)

        y_postpruning = m(input_)

        prune.remove(m, "weight")

        y_postremoval = m(input_)
        self.assertEqual(y_postpruning, y_postremoval)

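    # `prune.remove` makes the pruning permanent: it folds the mask into the
    # parameter, deletes `weight_orig` and `weight_mask`, and re-registers
    # `weight` as a plain Parameter, so the forward pass is unchanged. A
    # minimal sketch of the workflow (illustrative only):
    #
    #   prune.random_unstructured(m, name="weight", amount=0.5)  # reparametrize
    #   prune.remove(m, "weight")                                # finalize
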
    def test_pruning_id_consistency(self):
        r"""Test that pruning doesn't change the id of the parameters, which
        would otherwise introduce issues with pre-existing optimizers that
        point to old parameters.
        """
        m = nn.Linear(5, 2, bias=False)

        tensor_id = id(next(iter(m.parameters())))

        prune.random_unstructured(m, name="weight", amount=0.9)
        self.assertEqual(tensor_id, id(next(iter(m.parameters()))))

        prune.remove(m, "weight")
        self.assertEqual(tensor_id, id(next(iter(m.parameters()))))

    def test_random_pruning_pickle(self):
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    prune.random_unstructured(m, name=name, amount=0.1)
                    m_new = pickle.loads(pickle.dumps(m))
                    self.assertIsInstance(m_new, type(m))

    def test_multiple_pruning_calls(self):
        # if pruning is applied more than once, the hook becomes a PruningContainer
        m = nn.Conv3d(2, 2, 2)
        prune.l1_unstructured(m, name="weight", amount=0.1)
        weight_mask0 = m.weight_mask  # save it for a later sanity check

        # prune again
        prune.ln_structured(m, name="weight", amount=0.3, n=2, dim=0)
        hook = next(iter(m._forward_pre_hooks.values()))
        self.assertIsInstance(hook, torch.nn.utils.prune.PruningContainer)
        # check that container._tensor_name is correctly set no matter how
        # many pruning methods are in the container
        self.assertEqual(hook._tensor_name, "weight")

        # check that the pruning container has length equal to the number of
        # pruning iterations
        self.assertEqual(len(hook), 2)  # m.weight has been pruned twice

        # check that the entries of the pruning container are of the expected
        # type and in the expected order
        self.assertIsInstance(hook[0], torch.nn.utils.prune.L1Unstructured)
        self.assertIsInstance(hook[1], torch.nn.utils.prune.LnStructured)

        # check that all entries that are 0 in the 1st mask are 0 in the
        # 2nd mask too
        self.assertTrue(torch.all(m.weight_mask[weight_mask0 == 0] == 0))

        # prune again
        prune.ln_structured(m, name="weight", amount=0.1, n=float("inf"), dim=1)
        # check that container._tensor_name is correctly set no matter how
        # many pruning methods are in the container
        hook = next(iter(m._forward_pre_hooks.values()))
        self.assertEqual(hook._tensor_name, "weight")

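    # Repeated pruning calls on the same parameter are combined by a
    # `PruningContainer`: each new method computes its mask on top of the
    # mask already in place, so units pruned earlier stay pruned (checked
    # above via `m.weight_mask[weight_mask0 == 0]`).
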
    def test_pruning_container(self):
        container = prune.PruningContainer()
        container._tensor_name = "test"
        self.assertEqual(len(container), 0)

        p = prune.L1Unstructured(amount=2)
        p._tensor_name = "test"
        container.add_pruning_method(p)

        q = prune.L1Unstructured(amount=2)
        q._tensor_name = "another_test"
        with self.assertRaises(ValueError):
            container.add_pruning_method(q)

        with self.assertRaises(TypeError):
            container.add_pruning_method(10)
        with self.assertRaises(TypeError):
            container.add_pruning_method("ugh")

    def test_pruning_container_compute_mask(self):
        r"""Test `compute_mask` of pruning container with a known `t` and
        `default_mask`. Indirectly checks that Ln structured pruning is
        acting on the right axis.
        """
        # create an empty container
        container = prune.PruningContainer()
        container._tensor_name = "test"

        # 1) test unstructured pruning
        p = prune.L1Unstructured(amount=2)
        p._tensor_name = "test"
        # add the pruning method to the container
        container.add_pruning_method(p)

        # create a tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create a prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest-magnitude entries (1 and 2),
        # the outcome of the calculation should be:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)

        # 2) test structured pruning along axis 0
        q = prune.LnStructured(amount=1, n=2, dim=0)
        q._tensor_name = "test"
        container.add_pruning_method(q)
        # the lowest-norm row of the two gets pruned on top of the default mask
        expected_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)

        # 3) test structured pruning along axis 1
        r = prune.LnStructured(amount=1, n=2, dim=1)
        r._tensor_name = "test"
        container.add_pruning_method(r)
        # the lowest-norm column gets pruned on top of the default mask
        expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=torch.float32)
        computed_mask = container.compute_mask(t, default_mask)
        self.assertEqual(expected_mask, computed_mask)

    def test_l1_unstructured_pruning(self):
        r"""Test that l1 unstructured pruning actually removes the lowest
        entries by l1 norm (by hand). It also checks that applying l1
        unstructured pruning more than once respects the previous mask.
        """
        m = nn.Linear(4, 2)
        # modify its weight matrix by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32)
        )

        prune.l1_unstructured(m, "weight", amount=2)
        expected_weight = torch.tensor(
            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

        # check that pruning again removes the next two smallest entries
        prune.l1_unstructured(m, "weight", amount=2)
        expected_weight = torch.tensor(
            [[0, 0, 3, 4], [-4, -3, 0, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

    def test_l1_unstructured_pruning_with_importance_scores(self):
        r"""Test that l1 unstructured pruning actually removes the lowest
        entries of the importance scores, and not of the parameter by l1 norm
        (by hand). It also checks that applying l1 unstructured pruning more
        than once respects the previous mask.
        """
        m = nn.Linear(4, 2)
        # modify its weight matrix by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32)
        )
        importance_scores = torch.tensor(
            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32
        )

        prune.l1_unstructured(
            m, "weight", amount=2, importance_scores=importance_scores
        )
        expected_weight = torch.tensor(
            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

        # check that pruning again removes the entries with the next two
        # smallest importance scores
        prune.l1_unstructured(
            m, "weight", amount=2, importance_scores=importance_scores
        )
        expected_weight = torch.tensor(
            [[1, 0, 0, 4], [-4, 0, 0, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_weight, m.weight)

    def test_unstructured_pruning_same_magnitude(self):
        r"""Since it may happen that the tensor to prune has entries with the
        same exact magnitude, it is important to check that pruning happens
        consistently based on the bottom % of weights, and not by threshold,
        which would instead kill off *all* units with magnitude = threshold.
        """
        AMOUNT = 0.2
        p = prune.L1Unstructured(amount=AMOUNT)
        # create a random tensor with entries in {-2, 0, 2}
        t = 2 * torch.randint(low=-1, high=2, size=(10, 7))
        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement())

        computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t))
        nparams_pruned = torch.sum(computed_mask == 0)
        self.assertEqual(nparams_toprune, nparams_pruned)

    def test_random_structured_pruning_amount(self):
        AMOUNT = 0.6
        AXIS = 2
        p = prune.RandomStructured(amount=AMOUNT, dim=AXIS)
        t = 2 * torch.randint(low=-1, high=2, size=(5, 4, 2)).to(dtype=torch.float32)
        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS])

        computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t))
        # check that one slice along AXIS is fully pruned and the other is
        # left untouched
        remaining_axes = [_ for _ in range(len(t.shape)) if _ != AXIS]
        per_column_sums = sorted(torch.sum(computed_mask == 0, axis=remaining_axes))
        assert per_column_sums == [0, 20]

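    # Worked numbers for the assertion above (with AMOUNT = 0.6, AXIS = 2 as
    # defined in this test): t.shape[2] == 2 and round(0.6 * 2) == 1, so
    # exactly one of the two slices along axis 2 is pruned. Each slice holds
    # 5 * 4 = 20 entries, hence the sorted per-slice zero counts [0, 20].
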
    def test_ln_structured_pruning(self):
        r"""Check Ln structured pruning by hand."""
        m = nn.Conv2d(3, 1, 2)
        m.weight.data = torch.tensor(
            [
                [
                    [[1.0, 2.0], [1.0, 2.5]],
                    [[0.5, 1.0], [0.1, 0.1]],
                    [[-3.0, -5.0], [0.1, -1.0]],
                ]
            ]
        )
        # expected effect of pruning 1 of the 3 channels by L2-norm
        expected_mask_axis1 = torch.ones_like(m.weight)
        expected_mask_axis1[:, 1] = 0.0

        prune.ln_structured(m, "weight", amount=1, n=2, dim=1)
        self.assertEqual(expected_mask_axis1, m.weight_mask)

        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm
        expected_mask_axis3 = expected_mask_axis1
        expected_mask_axis3[:, :, :, 0] = 0.0

        prune.ln_structured(m, "weight", amount=1, n=1, dim=-1)
        self.assertEqual(expected_mask_axis3, m.weight_mask)

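    # Worked L2 norms behind the first mask above, per input channel of the
    # 1x3x2x2 weight: ||ch0|| = sqrt(1 + 4 + 1 + 6.25) = 3.5,
    # ||ch1|| = sqrt(0.25 + 1 + 0.01 + 0.01) ~= 1.13,
    # ||ch2|| = sqrt(9 + 25 + 0.01 + 1) ~= 5.92. Channel 1 has the smallest
    # norm, so pruning amount=1 along dim=1 zeroes it out.
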
    def test_ln_structured_pruning_importance_scores(self):
        r"""Check Ln structured pruning by hand, with importance scores."""
        m = nn.Conv2d(3, 1, 2)
        m.weight.data = torch.tensor(
            [
                [
                    [[1.0, 2.0], [1.0, 2.5]],
                    [[0.5, 1.0], [0.1, 0.1]],
                    [[-3.0, -5.0], [0.1, -1.0]],
                ]
            ]
        )
        importance_scores = torch.tensor(
            [
                [
                    [[10.0, 1.0], [10.0, 1.0]],
                    [[30.0, 3.0], [30.0, 3.0]],
                    [[-20.0, -2.0], [-20.0, -2.0]],
                ]
            ]
        )
        # expected effect of pruning 1 of the 3 channels by L2-norm of the
        # importance scores: channel 0 has the lowest-magnitude scores
        expected_mask_axis1 = torch.ones_like(m.weight)
        expected_mask_axis1[:, 0] = 0.0

        prune.ln_structured(
            m, "weight", amount=1, n=2, dim=1, importance_scores=importance_scores
        )
        self.assertEqual(expected_mask_axis1, m.weight_mask)

        # expected effect of pruning 1 of the 2 columns along axis -1 by
        # L1-norm of the importance scores
        expected_mask_axis3 = expected_mask_axis1
        expected_mask_axis3[:, :, :, 1] = 0.0

        prune.ln_structured(
            m, "weight", amount=1, n=1, dim=-1, importance_scores=importance_scores
        )
        self.assertEqual(expected_mask_axis3, m.weight_mask)

    def test_remove_pruning(self):
        r"""`prune.remove` removes the hook and the reparametrization
        and makes the pruning final in the original parameter.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # first prune
                    prune.random_unstructured(m, name, amount=0.5)
                    self.assertIn(name + "_orig", dict(m.named_parameters()))
                    self.assertIn(name + "_mask", dict(m.named_buffers()))
                    self.assertNotIn(name, dict(m.named_parameters()))
                    self.assertTrue(hasattr(m, name))
                    pruned_t = getattr(m, name)

                    # then remove pruning
                    prune.remove(m, name)
                    self.assertIn(name, dict(m.named_parameters()))
                    self.assertNotIn(name + "_orig", dict(m.named_parameters()))
                    self.assertNotIn(name + "_mask", dict(m.named_buffers()))
                    final_t = getattr(m, name)

                    self.assertEqual(pruned_t, final_t)

    def test_remove_pruning_exception(self):
        r"""Removing pruning from an unpruned tensor throws a ValueError."""
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    # check that the module isn't pruned
                    self.assertFalse(prune.is_pruned(m))
                    # since it isn't pruned, pruning can't be removed from it
                    with self.assertRaises(ValueError):
                        prune.remove(m, name)

    def test_global_pruning(self):
        r"""Test that global l1 unstructured pruning over 2 parameters removes
        the `amount=4` smallest global weights across the 2 parameters.
        """
        m = nn.Linear(4, 2)
        n = nn.Linear(3, 1)
        # modify the weight matrices by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32)
        )
        n.weight = torch.nn.Parameter(
            torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32)
        )

        params_to_prune = (
            (m, "weight"),
            (n, "weight"),
        )

        # prune the 4 smallest weights globally by L1 magnitude
        prune.global_unstructured(
            params_to_prune, pruning_method=prune.L1Unstructured, amount=4
        )

        expected_mweight = torch.tensor(
            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_mweight, m.weight)

        expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype)
        self.assertEqual(expected_nweight, n.weight)

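    # Worked numbers for the global ranking above: the 11 weights across `m`
    # and `n` are ranked together by absolute value, and the `amount=4`
    # smallest are 0 and 0.1 (in `n`) plus 1 and -1 (the corners of `m`),
    # which is exactly what the two expected tensors zero out.
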
    def test_global_pruning_importance_scores(self):
        r"""Test that global l1 unstructured pruning over 2 parameters removes
        the `amount=4` units with the smallest global importance scores across
        the 2 parameters.
        """
        m = nn.Linear(4, 2)
        n = nn.Linear(3, 1)
        # modify the weight matrices by hand
        m.weight = torch.nn.Parameter(
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32)
        )
        m_importance_scores = torch.tensor(
            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32
        )
        n.weight = torch.nn.Parameter(
            torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32)
        )
        n_importance_scores = torch.tensor([[0, 10.0, -0.2]]).to(dtype=torch.float32)

        params_to_prune = (
            (m, "weight"),
            (n, "weight"),
        )
        importance_scores = {
            (m, "weight"): m_importance_scores,
            (n, "weight"): n_importance_scores,
        }

        # prune the 4 units with the smallest global importance scores
        prune.global_unstructured(
            params_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=4,
            importance_scores=importance_scores,
        )

        expected_m_weight = torch.tensor(
            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype
        )
        self.assertEqual(expected_m_weight, m.weight)

        expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype)
        self.assertEqual(expected_n_weight, n.weight)

    def test_custom_from_mask_pruning(self):
        r"""Test that the CustomFromMask is capable of receiving
        as input at instantiation time a custom mask, and combining it with
        the previous default mask to generate the correct final mask.
        """
        # new mask
        mask = torch.tensor([[0, 1, 1, 0], [0, 0, 1, 1]])
        # old mask
        default_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]])

        # some tensor (values don't matter, only the shape does)
        t = torch.rand_like(mask.to(dtype=torch.float32))

        p = prune.CustomFromMask(mask=mask)

        computed_mask = p.compute_mask(t, default_mask)
        expected_mask = torch.tensor(
            [[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype
        )

        self.assertEqual(computed_mask, expected_mask)

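    # The combination rule verified here: the final mask is the elementwise
    # product (logical AND) of the custom mask and the pre-existing default
    # mask, e.g. [[0, 1, 1, 0], [0, 0, 1, 1]] * [[0, 0, 0, 0], [1, 1, 1, 1]]
    # gives [[0, 0, 0, 0], [0, 0, 1, 1]].
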
    def test_pruning_rollback(self):
        r"""Test that if something fails when we try to compute the mask,
        then the model isn't left in some intermediate half-pruned state.
        The try/except statement in `apply` should handle rolling back
        to the previous state before pruning began.
        """
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
        names = ["weight", "bias"]

        for m in modules:
            for name in names:
                with self.subTest(m=m, name=name):
                    with mock.patch(
                        "torch.nn.utils.prune.L1Unstructured.compute_mask"
                    ) as compute_mask:
                        compute_mask.side_effect = Exception("HA!")
                        with self.assertRaises(Exception):
                            prune.l1_unstructured(m, name=name, amount=0.9)

                        self.assertTrue(name in dict(m.named_parameters()))
                        self.assertFalse(name + "_mask" in dict(m.named_buffers()))
                        self.assertFalse(name + "_orig" in dict(m.named_parameters()))

    def test_pruning_serialization_model(self):
        # create a model
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        # check that everything looks normal before pruning
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # prune one of its parameters
        prune.l1_unstructured(module=model[0], name="weight", amount=0.9)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", model.state_dict())
        self.assertIn("0.weight_mask", model.state_dict())
        self.assertNotIn("0.weight", model.state_dict())
        self.assertTrue(hasattr(model[0], "weight"))

        pruned_weight = model[0].weight

        with TemporaryFileName() as fname:
            torch.save(model, fname)
            new_model = torch.load(fname, weights_only=False)

        # check that the original weight and the new mask are still present
        self.assertIn("0.weight_orig", new_model.state_dict())
        self.assertIn("0.weight_mask", new_model.state_dict())
        self.assertNotIn("0.weight", new_model.state_dict())
        self.assertTrue(hasattr(new_model[0], "weight"))

        self.assertEqual(pruned_weight, new_model[0].weight)

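    # Serialization takeaway from the test above: while the reparametrization
    # is active, the state_dict carries `weight_orig` and `weight_mask`
    # instead of `weight`, and saving/loading the whole module preserves the
    # pruning hooks, so the reloaded model recomputes the same pruned weight.
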
    def test_pruning_serialization_state_dict(self):
        # create a model
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        # check that everything looks normal before pruning
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # prune one of its parameters
        prune.l1_unstructured(module=model[0], name="weight", amount=0.9)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", model.state_dict())
        self.assertIn("0.weight_mask", model.state_dict())
        self.assertNotIn("0.weight", model.state_dict())
        self.assertTrue(hasattr(model[0], "weight"))

        pruned_weight = model[0].weight

        # make the pruning permanent and restore the parameter names as in
        # the base architecture
        prune.remove(module=model[0], name="weight")

        # check that the reparametrization tensors are no longer present
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # save the state dict of the model and load it into a new model
        new_model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            new_model.load_state_dict(torch.load(fname))

        # check that the reparametrization tensors aren't present either
        self.assertNotIn("0.weight_orig", new_model.state_dict())
        self.assertNotIn("0.weight_mask", new_model.state_dict())
        self.assertIn("0.weight", new_model.state_dict())

        self.assertEqual(pruned_weight, new_model[0].weight)

    def test_prune(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create a tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create a prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest-magnitude units, the outcome
        # should be:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])
        pruned_tensor = p.prune(t, default_mask)
        self.assertEqual(t * expected_mask, pruned_tensor)

    def test_prune_importance_scores(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create a tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        importance_scores = torch.tensor([[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]]).to(
            dtype=torch.float32
        )
        # create a prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two units with the lowest importance scores
        # (1 and 1.5), the outcome should be:
        expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]])
        pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores)
        self.assertEqual(t * expected_mask, pruned_tensor)

    def test_prune_importance_scores_mimic_default(self):
        # create a new pruning method
        p = prune.L1Unstructured(amount=2)
        # create a tensor to be pruned
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
        # create a prior mask by hand
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
        # since we are pruning the two lowest-magnitude units, the outcome
        # should be:
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])
        pruned_tensor_without_importance_scores = p.prune(t, default_mask)
        pruned_tensor_with_importance_scores = p.prune(
            t, default_mask, importance_scores=t
        )
        # passing the tensor itself as its importance scores should mimic the
        # default behavior
        self.assertEqual(
            pruned_tensor_without_importance_scores,
            pruned_tensor_with_importance_scores,
        )
        self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores)

    def test_rnn_pruning(self):
        l = torch.nn.LSTM(32, 32)
        # This Module has 4 parameters called:
        # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'

        # Pruning one of them causes one of the flat weights to become a
        # plain tensor instead of a Parameter
        prune.l1_unstructured(l, "weight_ih_l0", 0.5)
        assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 3

        # Removing the pruning reparametrization restores the Parameter
        prune.remove(l, "weight_ih_l0")
        assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 4

        # Make sure that, upon removal of the reparametrization, the
        # `._parameters` and `.named_parameters` contain the right params.
        # Specifically, the original weight ('weight_ih_l0') should be placed
        # back in the parameters, while the reparametrization component
        # ('weight_ih_l0_orig') should be removed.
        assert "weight_ih_l0" in l._parameters
        assert l._parameters["weight_ih_l0"] is not None
        assert "weight_ih_l0_orig" not in l._parameters
        assert "weight_ih_l0" in dict(l.named_parameters())
        assert dict(l.named_parameters())["weight_ih_l0"] is not None
        assert "weight_ih_l0_orig" not in dict(l.named_parameters())


instantiate_parametrized_tests(TestPruningNN)


if __name__ == "__main__":
    run_tests()