pytorch

Форк
0
/
test_pruning.py 
917 строк · 37.4 Кб
1
# Owner(s): ["module: nn"]
2
import pickle
3
import unittest
4
import unittest.mock as mock
5

6
import torch
7
import torch.nn as nn
8
import torch.nn.utils.prune as prune
9
from torch.testing._internal.common_nn import NNTestCase
10
from torch.testing._internal.common_utils import (
11
    instantiate_parametrized_tests,
12
    run_tests,
13
    TemporaryFileName,
14
    TEST_NUMPY,
15
)
16

17

18
class TestPruningNN(NNTestCase):
19
    _do_cuda_memory_leak_check = True
20
    _do_cuda_non_default_stream = True
21

22
    # torch/nn/utils/prune.py
23
    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
24
    def test_validate_pruning_amount_init(self):
25
        r"""Test the first util function that validates the pruning
26
        amount requested by the user the moment the pruning method
27
        is initialized. This test checks that the expected errors are
28
        raised whenever the amount is invalid.
29
        The original function runs basic type checking + value range checks.
30
        It doesn't check the validity of the pruning amount with
31
        respect to the size of the tensor to prune. That's left to
32
        `_validate_pruning_amount`, tested below.
33
        """
34
        # neither float not int should raise TypeError
35
        with self.assertRaises(TypeError):
36
            prune._validate_pruning_amount_init(amount="I'm a string")
37

38
        # float not in [0, 1] should raise ValueError
39
        with self.assertRaises(ValueError):
40
            prune._validate_pruning_amount_init(amount=1.1)
41
        with self.assertRaises(ValueError):
42
            prune._validate_pruning_amount_init(amount=20.0)
43

44
        # negative int should raise ValueError
45
        with self.assertRaises(ValueError):
46
            prune._validate_pruning_amount_init(amount=-10)
47

48
        # all these should pass without errors because they're valid amounts
49
        prune._validate_pruning_amount_init(amount=0.34)
50
        prune._validate_pruning_amount_init(amount=1500)
51
        prune._validate_pruning_amount_init(amount=0)
52
        prune._validate_pruning_amount_init(amount=0.0)
53
        prune._validate_pruning_amount_init(amount=1)
54
        prune._validate_pruning_amount_init(amount=1.0)
55
        self.assertTrue(True)
56

57
    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
58
    def test_validate_pruning_amount(self):
59
        r"""Tests the second util function that validates the pruning
60
        amount requested by the user, this time with respect to the size
61
        of the tensor to prune. The rationale is that if the pruning amount,
62
        converted to absolute value of units to prune, is larger than
63
        the number of units in the tensor, then we expect the util function
64
        to raise a value error.
65
        """
66
        # if amount is int and amount > tensor_size, raise ValueError
67
        with self.assertRaises(ValueError):
68
            prune._validate_pruning_amount(amount=20, tensor_size=19)
69

70
        # amount is a float so this should not raise an error
71
        prune._validate_pruning_amount(amount=0.3, tensor_size=0)
72

73
        # this is okay
74
        prune._validate_pruning_amount(amount=19, tensor_size=20)
75
        prune._validate_pruning_amount(amount=0, tensor_size=0)
76
        prune._validate_pruning_amount(amount=1, tensor_size=1)
77
        self.assertTrue(True)
78

79
    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
80
    def test_compute_nparams_to_prune(self):
81
        r"""Test that requested pruning `amount` gets translated into the
82
        correct absolute number of units to prune.
83
        """
84
        self.assertEqual(prune._compute_nparams_toprune(amount=0, tensor_size=15), 0)
85
        self.assertEqual(prune._compute_nparams_toprune(amount=10, tensor_size=15), 10)
86
        # if 1 is int, means 1 unit
87
        self.assertEqual(prune._compute_nparams_toprune(amount=1, tensor_size=15), 1)
88
        # if 1. is float, means 100% of units
89
        self.assertEqual(prune._compute_nparams_toprune(amount=1.0, tensor_size=15), 15)
90
        self.assertEqual(prune._compute_nparams_toprune(amount=0.4, tensor_size=17), 7)
91

92
    def test_random_pruning_sizes(self):
93
        r"""Test that the new parameters and buffers created by the pruning
94
        method have the same size as the input tensor to prune. These, in
95
        fact, correspond to the pruned version of the tensor itself, its
96
        mask, and its original copy, so the size must match.
97
        """
98
        # fixturize test
99
        # TODO: add other modules
100
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
101
        names = ["weight", "bias"]
102

103
        for m in modules:
104
            for name in names:
105
                with self.subTest(m=m, name=name):
106
                    original_tensor = getattr(m, name)
107

108
                    prune.random_unstructured(m, name=name, amount=0.1)
109
                    # mask has the same size as tensor being pruned
110
                    self.assertEqual(
111
                        original_tensor.size(), getattr(m, name + "_mask").size()
112
                    )
113
                    # 'orig' tensor has the same size as the original tensor
114
                    self.assertEqual(
115
                        original_tensor.size(), getattr(m, name + "_orig").size()
116
                    )
117
                    # new tensor has the same size as the original tensor
118
                    self.assertEqual(original_tensor.size(), getattr(m, name).size())
119

120
    def test_random_pruning_orig(self):
121
        r"""Test that original tensor is correctly stored in 'orig'
122
        after pruning is applied. Important to make sure we don't
123
        lose info about the original unpruned parameter.
124
        """
125
        # fixturize test
126
        # TODO: add other modules
127
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
128
        names = ["weight", "bias"]
129

130
        for m in modules:
131
            for name in names:
132
                with self.subTest(m=m, name=name):
133
                    # tensor prior to pruning
134
                    original_tensor = getattr(m, name)
135
                    prune.random_unstructured(m, name=name, amount=0.1)
136
                    self.assertEqual(original_tensor, getattr(m, name + "_orig"))
137

138
    def test_random_pruning_new_weight(self):
139
        r"""Test that module.name now contains a pruned version of
140
        the original tensor obtained from multiplying it by the mask.
141
        """
142
        # fixturize test
143
        # TODO: add other modules
144
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
145
        names = ["weight", "bias"]
146

147
        for m in modules:
148
            for name in names:
149
                with self.subTest(m=m, name=name):
150
                    # tensor prior to pruning
151
                    original_tensor = getattr(m, name)
152
                    prune.random_unstructured(m, name=name, amount=0.1)
153
                    # weight = weight_orig * weight_mask
154
                    self.assertEqual(
155
                        getattr(m, name),
156
                        getattr(m, name + "_orig")
157
                        * getattr(m, name + "_mask").to(dtype=original_tensor.dtype),
158
                    )
159

160
    def test_identity_pruning(self):
161
        r"""Test that a mask of 1s does not change forward or backward."""
162
        input_ = torch.ones(1, 5)
163
        m = nn.Linear(5, 2)
164
        y_prepruning = m(input_)  # output prior to pruning
165

166
        # compute grad pre-pruning and check it's equal to all ones
167
        y_prepruning.sum().backward()
168
        old_grad_weight = m.weight.grad.clone()  # don't grab pointer!
169
        self.assertEqual(old_grad_weight, torch.ones_like(m.weight))
170
        old_grad_bias = m.bias.grad.clone()
171
        self.assertEqual(old_grad_bias, torch.ones_like(m.bias))
172

173
        # remove grads
174
        m.zero_grad()
175

176
        # force the mask to be made of all 1s
177
        prune.identity(m, name="weight")
178

179
        # with mask of 1s, output should be identical to no mask
180
        y_postpruning = m(input_)
181
        self.assertEqual(y_prepruning, y_postpruning)
182

183
        # with mask of 1s, grad should be identical to no mask
184
        y_postpruning.sum().backward()
185
        self.assertEqual(old_grad_weight, m.weight_orig.grad)
186
        self.assertEqual(old_grad_bias, m.bias.grad)
187

188
        # calling forward twice in a row shouldn't change output
189
        y1 = m(input_)
190
        y2 = m(input_)
191
        self.assertEqual(y1, y2)
192

193
    def test_random_pruning_0perc(self):
194
        r"""Test that a mask of 1s does not change forward or backward."""
195
        input_ = torch.ones(1, 5)
196
        m = nn.Linear(5, 2)
197
        y_prepruning = m(input_)  # output prior to pruning
198

199
        # compute grad pre-pruning and check it's equal to all ones
200
        y_prepruning.sum().backward()
201
        old_grad_weight = m.weight.grad.clone()  # don't grab pointer!
202
        self.assertEqual(old_grad_weight, torch.ones_like(m.weight))
203
        old_grad_bias = m.bias.grad.clone()
204
        self.assertEqual(old_grad_bias, torch.ones_like(m.bias))
205

206
        # remove grads
207
        m.zero_grad()
208

209
        # force the mask to be made of all 1s
210
        with mock.patch(
211
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
212
        ) as compute_mask:
213
            compute_mask.return_value = torch.ones_like(m.weight)
214
            prune.random_unstructured(
215
                m, name="weight", amount=0.9
216
            )  # amount won't count
217

218
        # with mask of 1s, output should be identical to no mask
219
        y_postpruning = m(input_)
220
        self.assertEqual(y_prepruning, y_postpruning)
221

222
        # with mask of 1s, grad should be identical to no mask
223
        y_postpruning.sum().backward()
224
        self.assertEqual(old_grad_weight, m.weight_orig.grad)
225
        self.assertEqual(old_grad_bias, m.bias.grad)
226

227
        # calling forward twice in a row shouldn't change output
228
        y1 = m(input_)
229
        y2 = m(input_)
230
        self.assertEqual(y1, y2)
231

232
    def test_random_pruning(self):
233
        input_ = torch.ones(1, 5)
234
        m = nn.Linear(5, 2)
235

236
        # define custom mask to assign with mock
237
        mask = torch.ones_like(m.weight)
238
        mask[1, 0] = 0
239
        mask[0, 3] = 0
240

241
        # check grad is zero for masked weights
242
        with mock.patch(
243
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
244
        ) as compute_mask:
245
            compute_mask.return_value = mask
246
            prune.random_unstructured(m, name="weight", amount=0.9)
247

248
        y_postpruning = m(input_)
249
        y_postpruning.sum().backward()
250
        # weight_orig is the parameter, so it's the tensor that will accumulate the grad
251
        self.assertEqual(m.weight_orig.grad, mask)  # all 1s, except for masked units
252
        self.assertEqual(m.bias.grad, torch.ones_like(m.bias))
253

254
        # make sure that weight_orig update doesn't modify [1, 0] and [0, 3]
255
        old_weight_orig = m.weight_orig.clone()
256
        # update weights
257
        learning_rate = 1.0
258
        for p in m.parameters():
259
            p.data.sub_(p.grad.data * learning_rate)
260
        # since these are pruned, they should not be updated
261
        self.assertEqual(old_weight_orig[1, 0], m.weight_orig[1, 0])
262
        self.assertEqual(old_weight_orig[0, 3], m.weight_orig[0, 3])
263

264
    def test_random_pruning_forward(self):
265
        r"""check forward with mask (by hand)."""
266
        input_ = torch.ones(1, 5)
267
        m = nn.Linear(5, 2)
268

269
        # define custom mask to assign with mock
270
        mask = torch.zeros_like(m.weight)
271
        mask[1, 0] = 1
272
        mask[0, 3] = 1
273

274
        with mock.patch(
275
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
276
        ) as compute_mask:
277
            compute_mask.return_value = mask
278
            prune.random_unstructured(m, name="weight", amount=0.9)
279

280
        yhat = m(input_)
281
        self.assertEqual(yhat[0, 0], m.weight_orig[0, 3] + m.bias[0])
282
        self.assertEqual(yhat[0, 1], m.weight_orig[1, 0] + m.bias[1])
283

284
    def test_remove_pruning_forward(self):
285
        r"""Remove pruning and check forward is unchanged from previous
286
        pruned state.
287
        """
288
        input_ = torch.ones(1, 5)
289
        m = nn.Linear(5, 2)
290

291
        # define custom mask to assign with mock
292
        mask = torch.ones_like(m.weight)
293
        mask[1, 0] = 0
294
        mask[0, 3] = 0
295

296
        # check grad is zero for masked weights
297
        with mock.patch(
298
            "torch.nn.utils.prune.RandomUnstructured.compute_mask"
299
        ) as compute_mask:
300
            compute_mask.return_value = mask
301
            prune.random_unstructured(m, name="weight", amount=0.9)
302

303
        y_postpruning = m(input_)
304

305
        prune.remove(m, "weight")
306

307
        y_postremoval = m(input_)
308
        self.assertEqual(y_postpruning, y_postremoval)
309

310
    def test_pruning_id_consistency(self):
311
        r"""Test that pruning doesn't change the id of the parameters, which
312
        would otherwise introduce issues with pre-existing optimizers that
313
        point to old parameters.
314
        """
315
        m = nn.Linear(5, 2, bias=False)
316

317
        tensor_id = id(next(iter(m.parameters())))
318

319
        prune.random_unstructured(m, name="weight", amount=0.9)
320
        self.assertEqual(tensor_id, id(next(iter(m.parameters()))))
321

322
        prune.remove(m, "weight")
323
        self.assertEqual(tensor_id, id(next(iter(m.parameters()))))
324

325
    def test_random_pruning_pickle(self):
326
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
327
        names = ["weight", "bias"]
328

329
        for m in modules:
330
            for name in names:
331
                with self.subTest(m=m, name=name):
332
                    prune.random_unstructured(m, name=name, amount=0.1)
333
                    m_new = pickle.loads(pickle.dumps(m))
334
                    self.assertIsInstance(m_new, type(m))
335

336
    def test_multiple_pruning_calls(self):
337
        # if you call pruning twice, the hook becomes a PruningContainer
338
        m = nn.Conv3d(2, 2, 2)
339
        prune.l1_unstructured(m, name="weight", amount=0.1)
340
        weight_mask0 = m.weight_mask  # save it for later sanity check
341

342
        # prune again
343
        prune.ln_structured(m, name="weight", amount=0.3, n=2, dim=0)
344
        hook = next(iter(m._forward_pre_hooks.values()))
345
        self.assertIsInstance(hook, torch.nn.utils.prune.PruningContainer)
346
        # check that container._tensor_name is correctly set no matter how
347
        # many pruning methods are in the container
348
        self.assertEqual(hook._tensor_name, "weight")
349

350
        # check that the pruning container has the right length
351
        # equal to the number of pruning iters
352
        self.assertEqual(len(hook), 2)  # m.weight has been pruned twice
353

354
        # check that the entries of the pruning container are of the expected
355
        # type and in the expected order
356
        self.assertIsInstance(hook[0], torch.nn.utils.prune.L1Unstructured)
357
        self.assertIsInstance(hook[1], torch.nn.utils.prune.LnStructured)
358

359
        # check that all entries that are 0 in the 1st mask are 0 in the
360
        # 2nd mask too
361
        self.assertTrue(torch.all(m.weight_mask[weight_mask0 == 0] == 0))
362

363
        # prune again
364
        prune.ln_structured(m, name="weight", amount=0.1, n=float("inf"), dim=1)
365
        # check that container._tensor_name is correctly set no matter how
366
        # many pruning methods are in the container
367
        hook = next(iter(m._forward_pre_hooks.values()))
368
        self.assertEqual(hook._tensor_name, "weight")
369

370
    def test_pruning_container(self):
371
        # create an empty container
372
        container = prune.PruningContainer()
373
        container._tensor_name = "test"
374
        self.assertEqual(len(container), 0)
375

376
        p = prune.L1Unstructured(amount=2)
377
        p._tensor_name = "test"
378

379
        # test adding a pruning method to a container
380
        container.add_pruning_method(p)
381

382
        # test error raised if tensor name is different
383
        q = prune.L1Unstructured(amount=2)
384
        q._tensor_name = "another_test"
385
        with self.assertRaises(ValueError):
386
            container.add_pruning_method(q)
387

388
        # test that adding a non-pruning method object to a pruning container
389
        # raises a TypeError
390
        with self.assertRaises(TypeError):
391
            container.add_pruning_method(10)
392
        with self.assertRaises(TypeError):
393
            container.add_pruning_method("ugh")
394

395
    def test_pruning_container_compute_mask(self):
396
        r"""Test `compute_mask` of pruning container with a known `t` and
397
        `default_mask`. Indirectly checks that Ln structured pruning is
398
        acting on the right axis.
399
        """
400
        # create an empty container
401
        container = prune.PruningContainer()
402
        container._tensor_name = "test"
403

404
        # 1) test unstructured pruning
405
        # create a new pruning method
406
        p = prune.L1Unstructured(amount=2)
407
        p._tensor_name = "test"
408
        # add the pruning method to the container
409
        container.add_pruning_method(p)
410

411
        # create tensor to be pruned
412
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
413
        # create prior mask by hand
414
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
415
        # since we are pruning the two lowest magnitude units, the outcome of
416
        # the calculation should be this:
417
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=torch.float32)
418
        computed_mask = container.compute_mask(t, default_mask)
419
        self.assertEqual(expected_mask, computed_mask)
420

421
        # 2) test structured pruning
422
        q = prune.LnStructured(amount=1, n=2, dim=0)
423
        q._tensor_name = "test"
424
        container.add_pruning_method(q)
425
        # since we are pruning the lowest magnitude one of the two rows, the
426
        # outcome of the calculation should be this:
427
        expected_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=torch.float32)
428
        computed_mask = container.compute_mask(t, default_mask)
429
        self.assertEqual(expected_mask, computed_mask)
430

431
        # 2) test structured pruning, along another axis
432
        r = prune.LnStructured(amount=1, n=2, dim=1)
433
        r._tensor_name = "test"
434
        container.add_pruning_method(r)
435
        # since we are pruning the lowest magnitude of the four columns, the
436
        # outcome of the calculation should be this:
437
        expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=torch.float32)
438
        computed_mask = container.compute_mask(t, default_mask)
439
        self.assertEqual(expected_mask, computed_mask)
440

441
    def test_l1_unstructured_pruning(self):
442
        r"""Test that l1 unstructured pruning actually removes the lowest
443
        entries by l1 norm (by hand). It also checks that applying l1
444
        unstructured pruning more than once respects the previous mask.
445
        """
446
        m = nn.Linear(4, 2)
447
        # modify its weight matrix by hand
448
        m.weight = torch.nn.Parameter(
449
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32)
450
        )
451

452
        prune.l1_unstructured(m, "weight", amount=2)
453
        expected_weight = torch.tensor(
454
            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype
455
        )
456
        self.assertEqual(expected_weight, m.weight)
457

458
        # check that pruning again removes the next two smallest entries
459
        prune.l1_unstructured(m, "weight", amount=2)
460
        expected_weight = torch.tensor(
461
            [[0, 0, 3, 4], [-4, -3, 0, 0]], dtype=m.weight.dtype
462
        )
463
        self.assertEqual(expected_weight, m.weight)
464

465
    def test_l1_unstructured_pruning_with_importance_scores(self):
466
        r"""Test that l1 unstructured pruning actually removes the lowest
467
        entries of importance scores and not the parameter by l1 norm (by hand).
468
        It also checks that applying l1 unstructured pruning more than once
469
        respects the previous mask.
470
        """
471
        m = nn.Linear(4, 2)
472
        # modify its weight matrix by hand
473
        m.weight = torch.nn.Parameter(
474
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32)
475
        )
476
        importance_scores = torch.tensor(
477
            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32
478
        )
479

480
        prune.l1_unstructured(
481
            m, "weight", amount=2, importance_scores=importance_scores
482
        )
483
        expected_weight = torch.tensor(
484
            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype
485
        )
486
        self.assertEqual(expected_weight, m.weight)
487

488
        # check that pruning again removes two entries of m.weight that are colocated with
489
        # the next two smallest absolute values of importance scores.
490
        prune.l1_unstructured(
491
            m, "weight", amount=2, importance_scores=importance_scores
492
        )
493
        expected_weight = torch.tensor(
494
            [[1, 0, 0, 4], [-4, 0, 0, -1]], dtype=m.weight.dtype
495
        )
496
        self.assertEqual(expected_weight, m.weight)
497

498
    def test_unstructured_pruning_same_magnitude(self):
499
        r"""Since it may happen that the tensor to prune has entries with the
500
        same exact magnitude, it is important to check that pruning happens
501
        consistenly based on the bottom % of weights, and not by threshold,
502
        which would instead kill off *all* units with magnitude = threshold.
503
        """
504
        AMOUNT = 0.2
505
        p = prune.L1Unstructured(amount=AMOUNT)
506
        # create a random tensors with entries in {-2, 0, 2}
507
        t = 2 * torch.randint(low=-1, high=2, size=(10, 7))
508
        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement())
509

510
        computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t))
511
        nparams_pruned = torch.sum(computed_mask == 0)
512
        self.assertEqual(nparams_toprune, nparams_pruned)
513

514
    def test_random_structured_pruning_amount(self):
515
        AMOUNT = 0.6
516
        AXIS = 2
517
        p = prune.RandomStructured(amount=AMOUNT, dim=AXIS)
518
        t = 2 * torch.randint(low=-1, high=2, size=(5, 4, 2)).to(dtype=torch.float32)
519
        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS])
520

521
        computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t))
522
        # check that 1 column is fully prune, the others are left untouched
523
        remaining_axes = [_ for _ in range(len(t.shape)) if _ != AXIS]
524
        per_column_sums = sorted(torch.sum(computed_mask == 0, axis=remaining_axes))
525
        assert per_column_sums == [0, 20]
526

527
    def test_ln_structured_pruning(self):
528
        r"""Check Ln structured pruning by hand."""
529
        m = nn.Conv2d(3, 1, 2)
530
        m.weight.data = torch.tensor(
531
            [
532
                [
533
                    [[1.0, 2.0], [1.0, 2.5]],
534
                    [[0.5, 1.0], [0.1, 0.1]],
535
                    [[-3.0, -5.0], [0.1, -1.0]],
536
                ]
537
            ]
538
        )
539
        # expected effect of pruning 1 of the 3 channels by L2-norm
540
        expected_mask_axis1 = torch.ones_like(m.weight)
541
        expected_mask_axis1[:, 1] = 0.0
542

543
        prune.ln_structured(m, "weight", amount=1, n=2, dim=1)
544
        self.assertEqual(expected_mask_axis1, m.weight_mask)
545

546
        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm
547
        expected_mask_axis3 = expected_mask_axis1
548
        expected_mask_axis3[:, :, :, 0] = 0.0
549

550
        prune.ln_structured(m, "weight", amount=1, n=1, dim=-1)
551
        self.assertEqual(expected_mask_axis3, m.weight_mask)
552

553
    def test_ln_structured_pruning_importance_scores(self):
554
        r"""Check Ln structured pruning by hand."""
555
        m = nn.Conv2d(3, 1, 2)
556
        m.weight.data = torch.tensor(
557
            [
558
                [
559
                    [[1.0, 2.0], [1.0, 2.5]],
560
                    [[0.5, 1.0], [0.1, 0.1]],
561
                    [[-3.0, -5.0], [0.1, -1.0]],
562
                ]
563
            ]
564
        )
565
        importance_scores = torch.tensor(
566
            [
567
                [
568
                    [[10.0, 1.0], [10.0, 1.0]],
569
                    [[30.0, 3.0], [30.0, 3.0]],
570
                    [[-20.0, -2.0], [-20.0, -2.0]],
571
                ]
572
            ]
573
        )
574
        # expected effect of pruning 1 of the 3 channels by L2-norm
575
        expected_mask_axis1 = torch.ones_like(m.weight)
576
        expected_mask_axis1[:, 0] = 0.0
577

578
        prune.ln_structured(
579
            m, "weight", amount=1, n=2, dim=1, importance_scores=importance_scores
580
        )
581
        self.assertEqual(expected_mask_axis1, m.weight_mask)
582

583
        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm
584
        expected_mask_axis3 = expected_mask_axis1
585
        expected_mask_axis3[:, :, :, 1] = 0.0
586

587
        prune.ln_structured(
588
            m, "weight", amount=1, n=1, dim=-1, importance_scores=importance_scores
589
        )
590
        self.assertEqual(expected_mask_axis3, m.weight_mask)
591

592
    def test_remove_pruning(self):
593
        r"""`prune.remove` removes the hook and the reparametrization
594
        and makes the pruning final in the original parameter.
595
        """
596
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
597
        names = ["weight", "bias"]
598

599
        for m in modules:
600
            for name in names:
601
                with self.subTest(m=m, name=name):
602
                    # first prune
603
                    prune.random_unstructured(m, name, amount=0.5)
604
                    self.assertIn(name + "_orig", dict(m.named_parameters()))
605
                    self.assertIn(name + "_mask", dict(m.named_buffers()))
606
                    self.assertNotIn(name, dict(m.named_parameters()))
607
                    self.assertTrue(hasattr(m, name))
608
                    pruned_t = getattr(m, name)
609

610
                    # then remove pruning
611
                    prune.remove(m, name)
612
                    self.assertIn(name, dict(m.named_parameters()))
613
                    self.assertNotIn(name + "_orig", dict(m.named_parameters()))
614
                    self.assertNotIn(name + "_mask", dict(m.named_buffers()))
615
                    final_t = getattr(m, name)
616

617
                    self.assertEqual(pruned_t, final_t)
618

619
    def test_remove_pruning_exception(self):
620
        r"""Removing from an unpruned tensor throws an assertion error"""
621
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
622
        names = ["weight", "bias"]
623

624
        for m in modules:
625
            for name in names:
626
                with self.subTest(m=m, name=name):
627
                    # check that the module isn't pruned
628
                    self.assertFalse(prune.is_pruned(m))
629
                    # since it isn't pruned, pruning can't be removed from it
630
                    with self.assertRaises(ValueError):
631
                        prune.remove(m, name)
632

633
    def test_global_pruning(self):
634
        r"""Test that global l1 unstructured pruning over 2 parameters removes
635
        the `amount=4` smallest global weights across the 2 parameters.
636
        """
637
        m = nn.Linear(4, 2)
638
        n = nn.Linear(3, 1)
639
        # modify the weight matrices by hand
640
        m.weight = torch.nn.Parameter(
641
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32)
642
        )
643
        n.weight = torch.nn.Parameter(
644
            torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32)
645
        )
646

647
        params_to_prune = (
648
            (m, "weight"),
649
            (n, "weight"),
650
        )
651

652
        # prune the 4 smallest weights globally by L1 magnitude
653
        prune.global_unstructured(
654
            params_to_prune, pruning_method=prune.L1Unstructured, amount=4
655
        )
656

657
        expected_mweight = torch.tensor(
658
            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype
659
        )
660
        self.assertEqual(expected_mweight, m.weight)
661

662
        expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype)
663
        self.assertEqual(expected_nweight, n.weight)
664

665
    def test_global_pruning_importance_scores(self):
666
        r"""Test that global l1 unstructured pruning over 2 parameters removes
667
        the `amount=4` smallest global weights across the 2 parameters.
668
        """
669
        m = nn.Linear(4, 2)
670
        n = nn.Linear(3, 1)
671
        # modify the weight matrices by hand
672
        m.weight = torch.nn.Parameter(
673
            torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=torch.float32)
674
        )
675
        m_importance_scores = torch.tensor(
676
            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32
677
        )
678
        n.weight = torch.nn.Parameter(
679
            torch.tensor([[0, 0.1, -2]]).to(dtype=torch.float32)
680
        )
681
        n_importance_scores = torch.tensor([[0, 10.0, -0.2]]).to(dtype=torch.float32)
682

683
        params_to_prune = (
684
            (m, "weight"),
685
            (n, "weight"),
686
        )
687
        importance_scores = {
688
            (m, "weight"): m_importance_scores,
689
            (n, "weight"): n_importance_scores,
690
        }
691

692
        # prune the 4 smallest weights globally by L1 magnitude
693
        prune.global_unstructured(
694
            params_to_prune,
695
            pruning_method=prune.L1Unstructured,
696
            amount=4,
697
            importance_scores=importance_scores,
698
        )
699

700
        expected_m_weight = torch.tensor(
701
            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype
702
        )
703
        self.assertEqual(expected_m_weight, m.weight)
704

705
        expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype)
706
        self.assertEqual(expected_n_weight, n.weight)
707

708
    def test_custom_from_mask_pruning(self):
709
        r"""Test that the CustomFromMask is capable of receiving
710
        as input at instantiation time a custom mask, and combining it with
711
        the previous default mask to generate the correct final mask.
712
        """
713
        # new mask
714
        mask = torch.tensor([[0, 1, 1, 0], [0, 0, 1, 1]])
715
        # old mask
716
        default_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]])
717

718
        # some tensor (not actually used)
719
        t = torch.rand_like(mask.to(dtype=torch.float32))
720

721
        p = prune.CustomFromMask(mask=mask)
722

723
        computed_mask = p.compute_mask(t, default_mask)
724
        expected_mask = torch.tensor(
725
            [[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype
726
        )
727

728
        self.assertEqual(computed_mask, expected_mask)
729

730
    def test_pruning_rollback(self):
731
        r"""Test that if something fails when the we try to compute the mask,
732
        then the model isn't left in some intermediate half-pruned state.
733
        The try/except statement in `apply` should handle rolling back
734
        to the previous state before pruning began.
735
        """
736
        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]
737
        names = ["weight", "bias"]
738

739
        for m in modules:
740
            for name in names:
741
                with self.subTest(m=m, name=name):
742
                    with mock.patch(
743
                        "torch.nn.utils.prune.L1Unstructured.compute_mask"
744
                    ) as compute_mask:
745
                        compute_mask.side_effect = Exception("HA!")
746
                        with self.assertRaises(Exception):
747
                            prune.l1_unstructured(m, name=name, amount=0.9)
748

749
                        self.assertTrue(name in dict(m.named_parameters()))
750
                        self.assertFalse(name + "_mask" in dict(m.named_buffers()))
751
                        self.assertFalse(name + "_orig" in dict(m.named_parameters()))
752

753
    def test_pruning_serialization_model(self):
        r"""Test that saving a whole pruned model with ``torch.save`` and
        reloading it with ``torch.load`` keeps the pruning reparametrization
        (``weight_orig`` and ``weight_mask`` in the state dict) and reproduces
        the same pruned ``weight``.
        """

        def check_state_keys(net, pruned):
            # Assert which first-layer weight keys ``net.state_dict()`` exposes,
            # depending on whether that layer is currently pruned.
            state = net.state_dict()
            if pruned:
                self.assertIn("0.weight_orig", state)
                self.assertIn("0.weight_mask", state)
                self.assertNotIn("0.weight", state)
            else:
                self.assertNotIn("0.weight_orig", state)
                self.assertNotIn("0.weight_mask", state)
                self.assertIn("0.weight", state)

        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        # Before pruning only the plain parameter is stored.
        check_state_keys(model, pruned=False)

        # Prune the first layer's weight with L1 unstructured pruning.
        prune.l1_unstructured(module=model[0], name="weight", amount=0.9)

        # After pruning, the original weight and the mask replace ``weight``
        # in the state dict, while ``weight`` stays reachable as an attribute.
        check_state_keys(model, pruned=True)
        self.assertTrue(hasattr(model[0], "weight"))
        pruned_weight = model[0].weight

        with TemporaryFileName() as fname:
            torch.save(model, fname)
            # weights_only=False as this is legacy code that saves the model
            new_model = torch.load(fname, weights_only=False)

        # The reloaded model must still carry the reparametrization and yield
        # an identical pruned weight.
        check_state_keys(new_model, pruned=True)
        self.assertTrue(hasattr(new_model[0], "weight"))
        self.assertEqual(pruned_weight, new_model[0].weight)

    def test_pruning_serialization_state_dict(self):
        r"""Test that once ``prune.remove`` makes pruning permanent, the model's
        state dict holds only the plain ``weight`` entry, and that this state
        dict can be saved and loaded into a fresh model of the same
        architecture, reproducing the pruned weight.
        """
        # create a model
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        # check that everything looks normal before pruning
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # prune one of its parameters
        prune.l1_unstructured(module=model[0], name="weight", amount=0.9)

        # check that the original weight and the new mask are present
        self.assertIn("0.weight_orig", model.state_dict())
        self.assertIn("0.weight_mask", model.state_dict())
        self.assertNotIn("0.weight", model.state_dict())
        self.assertTrue(hasattr(model[0], "weight"))

        pruned_weight = model[0].weight

        # make pruning permanent and restore parameter names as in base
        # architecture
        prune.remove(module=model[0], name="weight")

        # check that the original weight and the new mask are no longer present
        self.assertNotIn("0.weight_orig", model.state_dict())
        self.assertNotIn("0.weight_mask", model.state_dict())
        self.assertIn("0.weight", model.state_dict())

        # save the state dict of model and reload it into new_model
        new_model = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            # The file holds only a state dict of tensors, so the restricted
            # weights_only unpickler is sufficient; passing it explicitly
            # avoids depending on torch.load's version-dependent default.
            new_model.load_state_dict(torch.load(fname, weights_only=True))

        # check that the original weight and the new mask are not present in
        # new_model either.
        self.assertNotIn("0.weight_orig", new_model.state_dict())
        self.assertNotIn("0.weight_mask", new_model.state_dict())
        self.assertIn("0.weight", new_model.state_dict())

        self.assertEqual(pruned_weight, new_model[0].weight)

    def test_prune(self):
840
        # create a new pruning method
841
        p = prune.L1Unstructured(amount=2)
842
        # create tensor to be pruned
843
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
844
        # create prior mask by hand
845
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
846
        # since we are pruning the two lowest magnitude units, the outcome of
847
        # the calculation should be this:
848
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])
849
        pruned_tensor = p.prune(t, default_mask)
850
        self.assertEqual(t * expected_mask, pruned_tensor)
851

852
    def test_prune_importance_scores(self):
853
        # create a new pruning method
854
        p = prune.L1Unstructured(amount=2)
855
        # create tensor to be pruned
856
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
857
        importance_scores = torch.tensor([[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]]).to(
858
            dtype=torch.float32
859
        )
860
        # create prior mask by hand
861
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
862
        # since we are pruning the two lowest magnitude units, the outcome of
863
        # the calculation should be this:
864
        expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]])
865
        pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores)
866
        self.assertEqual(t * expected_mask, pruned_tensor)
867

868
    def test_prune_importance_scores_mimic_default(self):
869
        # create a new pruning method
870
        p = prune.L1Unstructured(amount=2)
871
        # create tensor to be pruned
872
        t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32)
873
        # create prior mask by hand
874
        default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])
875
        # since we are pruning the two lowest magnitude units, the outcome of
876
        # the calculation should be this:
877
        expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])
878
        pruned_tensor_without_importance_scores = p.prune(t, default_mask)
879
        pruned_tensor_with_importance_scores = p.prune(
880
            t, default_mask, importance_scores=t
881
        )
882
        self.assertEqual(
883
            pruned_tensor_without_importance_scores,
884
            pruned_tensor_with_importance_scores,
885
        )
886
        self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores)
887

888
    def test_rnn_pruning(self):
889
        l = torch.nn.LSTM(32, 32)
890
        # This Module has 4 parameters called:
891
        # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'
892

893
        # Pruning one of them causes one of the weights to become a tensor
894
        prune.l1_unstructured(l, "weight_ih_l0", 0.5)
895
        assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 3
896

897
        # Removing the pruning reparametrization restores the Parameter
898
        prune.remove(l, "weight_ih_l0")
899
        assert sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights) == 4
900

901
        # Make sure that, upon removal of the reparametrization, the
902
        # `._parameters` and `.named_parameters` contain the right params.
903
        # Specifically, the original weight ('weight_ih_l0') should be placed
904
        # back in the parameters, while the reparametrization component
905
        # ('weight_ih_l0_orig') should be removed.
906
        assert "weight_ih_l0" in l._parameters
907
        assert l._parameters["weight_ih_l0"] is not None
908
        assert "weight_ih_l0_orig" not in l._parameters
909
        assert "weight_ih_l0" in dict(l.named_parameters())
910
        assert dict(l.named_parameters())["weight_ih_l0"] is not None
911
        assert "weight_ih_l0_orig" not in dict(l.named_parameters())
912

913

914
# Post-process the test class with the helper imported from
# torch.testing._internal.common_utils (presumably expands any parametrized
# test templates into concrete test methods — confirm against that module).
instantiate_parametrized_tests(TestPruningNN)

if __name__ == "__main__":
    # Run the tests through PyTorch's shared test entry point.
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.