# pytorch / Fork 0 / prune.py
# 1379 lines · 56.5 KB
1
r"""Pruning methods."""
2
import numbers
3
from abc import ABC, abstractmethod
4
from collections.abc import Iterable
5
from typing import Tuple
6

7
import torch
8

9

10
class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.
    """

    # Name of the module attribute this method prunes; assigned by ``apply``.
    _tensor_name: str

    def __call__(self, module, inputs):
        r"""Multiply the mask into original tensor and store the result.

        Multiplies the mask (stored in ``module[name + '_mask']``)
        into the original tensor (stored in ``module[name + '_orig']``)
        and stores the result into ``module[name]`` by using :meth:`apply_mask`.
        :meth:`apply` registers the method as a forward pre-hook, so this runs
        before every ``forward()``.

        Args:
            module (nn.Module): module containing the tensor to prune
            inputs: not used.
        """
        setattr(module, self._tensor_name, self.apply_mask(module))

    @abstractmethod
    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` according to the specific pruning method
        recipe.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune.
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``
        """
        pass

    def apply_mask(self, module):
        r"""Simply handles the multiplication between the parameter being pruned and the generated mask.

        Fetches the mask and the original tensor from the module
        and returns the pruned version of the tensor.

        Args:
            module (nn.Module): module containing the tensor to prune

        Returns:
            pruned_tensor (torch.Tensor): pruned version of the input tensor
        """
        # to carry out the multiplication, the mask needs to have been computed,
        # so the pruning method must know what tensor it's operating on
        assert self._tensor_name is not None, f"Module {module} has to be pruned"  # this gets set in apply()
        mask = getattr(module, self._tensor_name + "_mask")
        orig = getattr(module, self._tensor_name + "_orig")
        # cast the mask so the product keeps the dtype of the original tensor
        pruned_tensor = mask.to(dtype=orig.dtype) * orig
        return pruned_tensor

    @classmethod
    def apply(cls, module, name, *args, importance_scores=None, **kwargs):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        If ``module[name]`` is already pruned, the new method is chained after
        the existing one(s) through a :class:`PruningContainer`. If anything
        fails, the module is restored to its state before the call, including
        any previously attached pruning hook, and the exception is re-raised.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            args: arguments passed on to a subclass of
                :class:`BasePruningMethod`
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the
                corresponding elements in the parameter being pruned.
                If unspecified or None, the parameter will be used in its place.
            kwargs: keyword arguments passed on to a subclass of a
                :class:`BasePruningMethod`

        Returns:
            The pruning method registered as forward pre-hook (a
            :class:`PruningContainer` if ``module[name]`` was already pruned).
        """
        # Look for a pruning hook already attached to ``module[name]``. There
        # should be at most one; further methods are chained in a container.
        old_method = None
        old_hook_ids = []
        for hook_id, hook in module._forward_pre_hooks.items():
            if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
                old_method = hook
                old_hook_ids.append(hook_id)
        assert len(old_hook_ids) <= 1, (
            f"Avoid adding multiple pruning hooks to the same tensor {name} "
            f"of module {module}. Use a PruningContainer."
        )

        # Build the new pruning method. The old hook is only detached once the
        # new reparametrization has succeeded, so an error here (e.g. an
        # argument rejected by ``cls.__init__``) leaves the module untouched.
        method = cls(*args, **kwargs)
        # Have the pruning method remember what tensor it's been applied to
        method._tensor_name = name

        # Pruning acts on ``module[name]`` in the state it is found in prior
        # to this call; the mask is computed from the importance scores.
        orig = getattr(module, name)
        if importance_scores is not None:
            assert importance_scores.shape == orig.shape, (
                f"importance_scores should have the same shape as parameter "
                f"{name} of {module}"
            )
        else:
            importance_scores = orig

        first_time = old_method is None
        saved_methods = None  # contents of an existing container, for rollback
        if first_time:
            # Move ``module[name]`` to ``module[name + '_orig']``; the attribute
            # ``module[name]`` is re-created below as the masked tensor.
            default_mask = torch.ones_like(orig)
            module.register_parameter(name + "_orig", orig)
            del module._parameters[name]
        else:
            # The reparametrization already exists: start from the current
            # mask and chain the new method after the previous one(s).
            prev_mask = getattr(module, name + "_mask")
            default_mask = prev_mask.detach().clone(
                memory_format=torch.contiguous_format
            )
            if isinstance(old_method, PruningContainer):
                saved_methods = old_method._pruning_methods
                old_method.add_pruning_method(method)
                method = old_method
            else:
                container = PruningContainer(old_method)
                container.add_pruning_method(method)
                method = container

        try:
            # get the final mask, computed according to the specific method
            mask = method.compute_mask(importance_scores, default_mask=default_mask)
            # reparameterize by saving mask to `module[name + '_mask']`...
            module.register_buffer(name + "_mask", mask)
            # ... and the new pruned tensor to `module[name]`
            setattr(module, name, method.apply_mask(module))
            # recompute the pruned tensor before every forward()
            module.register_forward_pre_hook(method)
        except Exception:
            if first_time:
                # Undo the reparametrization: drop the plain attribute and the
                # mask buffer if they were created, then move
                # ``name + '_orig'`` back to ``name``.
                if hasattr(module, name):
                    delattr(module, name)
                module._buffers.pop(name + "_mask", None)
                module.register_parameter(
                    name, module._parameters.pop(name + "_orig")
                )
            else:
                # Restore the previous container contents, mask and pruned
                # tensor; the old hook was never detached.
                if saved_methods is not None:
                    method._pruning_methods = saved_methods
                module.register_buffer(name + "_mask", prev_mask)
                setattr(module, name, orig)
            raise

        # Success: the newly registered hook supersedes the old one.
        for hook_id in old_hook_ids:
            del module._forward_pre_hooks[hook_id]
        return method

    def prune(self, t, default_mask=None, importance_scores=None):
        r"""Compute and returns a pruned version of input tensor ``t``.

        According to the pruning rule specified in :meth:`compute_mask`.

        Args:
            t (torch.Tensor): tensor to prune (of same dimensions as
                ``default_mask``).
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as ``t``) used to compute mask for pruning ``t``.
                The values in this tensor indicate the importance of the
                corresponding elements in the ``t`` that is being pruned.
                If unspecified or None, the tensor ``t`` will be used in its place.
            default_mask (torch.Tensor, optional): mask from previous pruning
                iteration, if any. To be considered when determining what
                portion of the tensor that pruning should act on. If None,
                default to a mask of ones.

        Returns:
            pruned version of tensor ``t``.
        """
        if importance_scores is not None:
            assert (
                importance_scores.shape == t.shape
            ), "importance_scores should have the same shape as tensor t"
        else:
            importance_scores = t
        default_mask = default_mask if default_mask is not None else torch.ones_like(t)
        return t * self.compute_mask(importance_scores, default_mask=default_mask)

    def remove(self, module):
        r"""Remove the pruning reparameterization from a module.

        The pruned parameter named ``name`` remains permanently pruned,
        and the parameter named ``name+'_orig'`` is removed from the parameter list.
        Similarly, the buffer named ``name+'_mask'`` is removed from the buffers.

        Note:
            Pruning itself is NOT undone or reversed!
        """
        # before removing pruning from a tensor, it has to have been applied
        assert self._tensor_name is not None, (
            f"Module {module} has to be pruned before pruning can be removed"
        )  # this gets set in apply()

        # to update module[name] to latest trained weights
        weight = self.apply_mask(module)  # masked weights

        # delete the plain attribute and reuse the `_orig` Parameter object,
        # now holding the masked values, as the permanent parameter
        if hasattr(module, self._tensor_name):
            delattr(module, self._tensor_name)
        orig = module._parameters[self._tensor_name + "_orig"]
        orig.data = weight.data
        del module._parameters[self._tensor_name + "_orig"]
        del module._buffers[self._tensor_name + "_mask"]
        setattr(module, self._tensor_name, orig)
258

259

260
class PruningContainer(BasePruningMethod):
    """Container holding a sequence of pruning methods for iterative pruning.

    Keeps track of the order in which pruning methods are applied and handles
    combining successive pruning calls.

    Accepts as argument an instance of a BasePruningMethod, several instances
    passed positionally, or a single iterable of them.
    """

    def __init__(self, *args):
        self._pruning_methods: Tuple[BasePruningMethod, ...] = tuple()
        # ``args`` is always a tuple, so a single iterable argument (e.g. a
        # list of methods) must be unpacked explicitly. A lone pruning method
        # -- including another PruningContainer, which is itself iterable --
        # is kept as one element.
        if (
            len(args) == 1
            and not isinstance(args[0], BasePruningMethod)
            and isinstance(args[0], Iterable)
        ):
            args = tuple(args[0])
        # All methods of a container act on the same tensor: adopt the name of
        # the first one so add_pruning_method can enforce it for the others.
        if args:
            self._tensor_name = args[0]._tensor_name
        for method in args:
            self.add_pruning_method(method)

    def add_pruning_method(self, method):
        r"""Add a child pruning ``method`` to the container.

        Args:
            method (subclass of BasePruningMethod): child pruning method
                to be added to the container.

        Raises:
            TypeError: if ``method`` is not a BasePruningMethod.
            ValueError: if ``method`` acts on a different tensor than the
                container.
        """
        # check that we're adding a pruning method to the container
        if not isinstance(method, BasePruningMethod) and method is not None:
            raise TypeError(
                f"{type(method)} is not a BasePruningMethod subclass"
            )
        elif method is not None and self._tensor_name != method._tensor_name:
            raise ValueError(
                "Can only add pruning methods acting on "
                f"the parameter named '{self._tensor_name}' to PruningContainer {self}."
                + f" Found '{method._tensor_name}'"
            )
        # if all checks passed, add to _pruning_methods tuple
        self._pruning_methods += (method,)  # type: ignore[operator]

    def __len__(self):
        return len(self._pruning_methods)

    def __iter__(self):
        return iter(self._pruning_methods)

    def __getitem__(self, idx):
        return self._pruning_methods[idx]

    def compute_mask(self, t, default_mask):
        r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``.

        The new partial mask should be computed on the entries or channels
        that were not zeroed out by the ``default_mask``.
        Which portions of the tensor ``t`` the new mask will be calculated from
        depends on the ``PRUNING_TYPE`` (handled by the type handler):

        * for 'unstructured', the mask will be computed from the raveled
          list of nonmasked entries;

        * for 'structured', the mask will be computed from the nonmasked
          channels in the tensor;

        * for 'global', the mask will be computed across all entries.

        ``default_mask`` itself is not modified.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
                (of same dimensions as ``default_mask``).
            default_mask (torch.Tensor): mask from previous pruning iteration.

        Returns:
            mask (torch.Tensor): new mask that combines the effects
            of the ``default_mask`` and the new mask from the current
            pruning ``method`` (of same dimensions as ``default_mask`` and
            ``t``).
        """

        def _combine_masks(method, t, mask):
            r"""Combine the masks from all pruning methods and returns a new mask.

            Args:
                method (a BasePruningMethod subclass): pruning method
                    currently being applied.
                t (torch.Tensor): tensor representing the parameter to prune
                    (of same dimensions as mask).
                mask (torch.Tensor): mask from previous pruning iteration

            Returns:
                new_mask (torch.Tensor): new mask that combines the effects
                    of the old mask and the new mask from the current
                    pruning method (of same dimensions as mask and t).
            """
            # Start from a *copy* of the existing mask: ``.to`` alone returns
            # ``mask`` itself when the dtype already matches, and the in-place
            # slice assignment below would then overwrite the caller's mask.
            new_mask = mask.to(dtype=t.dtype, copy=True)

            # compute a slice of t onto which the new pruning method will operate
            if method.PRUNING_TYPE == "unstructured":
                # prune entries of t where the mask is 1
                slc = mask == 1

            # for struct pruning, exclude channels that have already been
            # entirely pruned
            elif method.PRUNING_TYPE == "structured":
                if not hasattr(method, "dim"):
                    raise AttributeError(
                        "Pruning methods of PRUNING_TYPE "
                        '"structured" need to have the attribute `dim` defined.'
                    )

                # find the channels to keep by removing the ones that have been
                # zeroed out already (i.e. where sum(entries) == 0)
                n_dims = t.dim()  # "is this a 2D tensor? 3D? ..."
                dim = method.dim
                # convert negative indexing
                if dim < 0:
                    dim = n_dims + dim
                # if dim is still negative after subtracting it from n_dims
                if dim < 0:
                    raise IndexError(
                        f"Index is out of bounds for tensor with dimensions {n_dims}"
                    )
                # find channels along dim = dim that aren't already all zeroed out
                keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0
                # create index selecting the surviving channels along `dim`
                index = [slice(None)] * n_dims
                index[dim] = keep_channel
                slc = tuple(index)

            elif method.PRUNING_TYPE == "global":
                # operate on the whole tensor
                slc = (slice(None),) * t.dim()

            else:
                raise ValueError(
                    f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}"
                )

            # compute the new mask on the unpruned slice of the tensor t
            partial_mask = method.compute_mask(t[slc], default_mask=mask[slc])
            new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)

            return new_mask

        # only the most recently added method still has to be applied; the
        # effect of the earlier ones is already contained in default_mask
        method = self._pruning_methods[-1]
        mask = _combine_masks(method, t, default_mask)
        return mask
408

409

410
class Identity(BasePruningMethod):
    r"""Pruning method that leaves every unit in place.

    It still goes through the regular pruning reparametrization, so the
    module ends up with ``name + '_orig'`` and ``name + '_mask'`` as with any
    other method; the mask it contributes is simply the incoming
    ``default_mask`` (all ones for a tensor that was never pruned).
    """

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        r"""Return ``default_mask`` as is (the very same tensor object).

        Args:
            t (torch.Tensor): ignored.
            default_mask (torch.Tensor): mask from previous pruning iterations.

        Returns:
            mask (torch.Tensor): ``default_mask``, unchanged.
        """
        return default_mask

    @classmethod
    def apply(cls, module, name):
        r"""Attach the identity pruning reparametrization to ``module[name]``.

        Registers the forward pre-hook that recomputes the pruned tensor before
        every forward pass and reparametrizes the tensor in terms of the
        original tensor and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
        """
        # Identity's constructor takes no arguments, so none are forwarded.
        return super().apply(module, name)
433

434

435
class RandomUnstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        r"""Zero out, on top of ``default_mask``, randomly chosen entries of ``t``.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune (only its shape and element count are used).
            default_mask (torch.Tensor): mask from previous pruning iterations,
                same dims as ``t``; it is copied, not modified.

        Returns:
            mask (torch.Tensor): copy of ``default_mask`` with the selected
            entries set to 0.
        """
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # nothing to prune -> keep the copied mask
            prob = torch.rand_like(t)
            # ``reshape`` rather than ``view``: rand_like keeps the strides of
            # ``t``, so for a non-contiguous ``t`` (e.g. a transposed tensor)
            # view(-1) would raise. reshape(-1) enumerates elements in logical
            # order, matching the contiguous ``mask.view(-1)`` below.
            topk = torch.topk(prob.reshape(-1), k=nparams_toprune)
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
        """
        return super().apply(module, name, amount=amount)
492

493

494
class L1Unstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        r"""Zero out, on top of ``default_mask``, the entries of ``t`` with the smallest absolute value.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune.
            default_mask (torch.Tensor): mask from previous pruning iterations,
                same dims as ``t``; it is copied, not modified.

        Returns:
            mask (torch.Tensor): copy of ``default_mask`` with the selected
            entries set to 0.
        """
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # nothing to prune -> keep the copied mask
            # ``reshape`` rather than ``view``: abs() keeps the strides of
            # ``t``, so for a non-contiguous ``t`` (e.g. a transposed tensor)
            # view(-1) would raise. reshape(-1) enumerates elements in logical
            # order, matching the contiguous ``mask.view(-1)`` below.
            # largest=False -> the k entries with the smallest magnitude
            topk = torch.topk(
                torch.abs(t).reshape(-1), k=nparams_toprune, largest=False
            )
            # topk will have .indices and .values
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module, name, amount=amount, importance_scores=importance_scores
        )
558

559

560
class RandomStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor at random.

    Args:
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of channels to prune. If ``int``, it represents the
            absolute number of channels to prune.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a random mask to
        apply on top of the ``default_mask`` by randomly zeroing out channels
        along the specified dim of the tensor.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)

        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the amount of channels to prune is not > than the number of
        # channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of channels to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of channels to prune is
        # larger than the number of channels in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Build a binary mask with the same shape as ``t``: start from all 0s
        # and fill in 1s on the channels along ``dim`` that survive, i.e. all
        # but the ``nchannels_toprune`` channels with the lowest random draw.
        def make_mask(t, dim, nchannels, nchannels_toprune):
            # generate a random number in [0, 1) to associate to each channel
            prob = torch.rand(nchannels)
            # 0 out the channels that got assigned the
            # k = nchannels_toprune lowest values in prob
            threshold = torch.kthvalue(prob, k=nchannels_toprune).values
            channel_mask = prob > threshold

            mask = torch.zeros_like(t)
            index = [slice(None)] * len(t.shape)
            index[dim] = channel_mask
            mask[tuple(index)] = 1
            return mask

        if nparams_toprune == 0:  # k=0 not supported by torch.kthvalue
            mask = default_mask
        else:
            # apply the new structured mask on top of prior (potentially
            # unstructured) mask
            mask = make_mask(t, self.dim, tensor_size, nparams_toprune)
            mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, dim=-1):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of channels to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of channels to prune. If ``int``, it represents the
                absolute number of channels to prune.
            dim (int, optional): index of the dim along which we define
                channels to prune. Default: -1.
        """
        return super().apply(module, name, amount=amount, dim=dim)
663

664

665
class LnStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm.

    Args:
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of channels along ``dim`` to prune. If ``int``, it
            represents the absolute number of channels to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, n, dim=-1):
        # Check type and range of ``amount``; whether it fits the actual
        # number of channels can only be checked once the tensor is known.
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.n = n
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` by zeroing out the channels along the
        specified dim with the lowest L\ ``n``-norm.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied.  Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``.
                When no channel is to be pruned, ``default_mask`` itself is
                returned.

        Raises:
            ValueError: if ``t`` has fewer than 2 dims, or if an integer
                ``amount`` exceeds the number of channels along ``self.dim``.
            IndexError: if ``self.dim`` is not a valid dim of ``t``.
        """
        # Structured pruning needs more than 1 dimension for the concept of
        # "channels" to make sense.
        _validate_structured_pruning(t)
        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Number of channels along the dim to prune.
        tensor_size = t.shape[self.dim]
        # Number of channels to prune: amount if int, else
        # round(amount * tensor_size).
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        nparams_tokeep = tensor_size - nparams_toprune
        # Raise if the number of channels to prune is larger than the number
        # of channels in the tensor.
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # L_n norm of each channel, reduced over every other dim. Computed
        # before the early exit so an invalid ``n`` is still reported when
        # nothing ends up being pruned.
        norm = _compute_norm(t, self.n, self.dim)

        if nparams_toprune == 0:
            # Nothing new to prune: respect the prior mask as-is.
            return default_mask

        # Keep the ``nparams_tokeep`` channels with the largest norm.
        topk = torch.topk(norm, k=nparams_tokeep, largest=True)

        # Binary mask of the same shape as ``t``: 1s on the kept channels
        # along ``self.dim``, 0s everywhere else.
        mask = torch.zeros_like(t)
        # e.g. for a 3-D ``t`` with dim=2 and indices=[0, 2, 3] this builds
        # (slice(None), slice(None), indices), i.e. mask[:, :, [0, 2, 3]] = 1.
        # A tuple (not a list) makes this an unambiguous multi-dim index.
        slc = [slice(None)] * t.dim()
        slc[self.dim] = topk.indices
        mask[tuple(slc)] = 1

        # Apply the new structured mask on top of the prior (potentially
        # unstructured) mask.
        mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, n, dim, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of channels to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of channels along ``dim`` to prune. If ``int``, it
                represents the absolute number of channels to prune.
            n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
                entries for argument ``p`` in :func:`torch.norm`.
            dim (int): index of the dim along which we define channels to
                prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module,
            name,
            amount=amount,
            n=n,
            dim=dim,
            importance_scores=importance_scores,
        )
792

793

794
class CustomFromMask(BasePruningMethod):
    r"""Prune a tensor by applying a user-supplied, pre-computed mask.

    The supplied mask is multiplied into the mask left by previous pruning
    iterations, so units that were already pruned stay pruned.

    Args:
        mask (torch.Tensor): mask to apply. Must have the same shape as the
            tensor being pruned.
    """

    PRUNING_TYPE = "global"

    def __init__(self, mask):
        self.mask = mask

    def compute_mask(self, t, default_mask):
        r"""Combine the stored mask with ``default_mask``.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
                (not used; the mask is fixed at construction time).
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations. Must have the same shape as the stored mask.

        Returns:
            mask (torch.Tensor): elementwise product of ``default_mask`` and
            the stored mask, in ``default_mask``'s dtype.

        Raises:
            AssertionError: if the stored mask and ``default_mask`` have
                different shapes.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # without it a mismatched mask could silently broadcast. The
        # AssertionError type is kept for backward compatibility.
        if default_mask.shape != self.mask.shape:
            raise AssertionError(
                f"mask of shape {tuple(self.mask.shape)} does not match the "
                f"shape {tuple(default_mask.shape)} of the tensor being pruned"
            )
        mask = default_mask * self.mask.to(dtype=default_mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, mask):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            mask (torch.Tensor): pre-computed mask to apply, of the same
                shape as the parameter ``name``.
        """
        return super().apply(module, name, mask=mask)
820

821

822
def identity(module, name):
    r"""Apply pruning reparametrization without pruning any units.

    The parameter called ``name`` in ``module`` is reparametrized in the
    same way a pruning method would do it, but no unit is actually removed.
    ``module`` is modified in place (and is also returned) as follows:

    1) a named buffer called ``name+'_mask'`` is added, holding the
       binary mask applied to the parameter ``name`` by the pruning method;
    2) the parameter ``name`` is replaced by its pruned version, while the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Note:
        The mask is a tensor of ones.

    Args:
        module (nn.Module): module containing the tensor to prune.
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.identity(nn.Linear(2, 3), 'bias')
        >>> print(m.bias_mask)
        tensor([1., 1., 1.])
    """
    # The method class performs the in-place reparametrization; hand back
    # the same module object for call chaining.
    Identity.apply(module, name)
    return module
855

856

857
def random_unstructured(module, name, amount):
    r"""Prune tensor by removing random (currently unpruned) units.

    The tensor behind the parameter called ``name`` in ``module`` loses the
    requested ``amount`` of (currently unpruned) units, chosen at random.
    ``module`` is modified in place (and is also returned) as follows:

    1) a named buffer called ``name+'_mask'`` is added, holding the
       binary mask applied to the parameter ``name`` by the pruning method;
    2) the parameter ``name`` is replaced by its pruned version, while the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)

    """
    # Reparametrization happens in place; the module is returned for chaining.
    RandomUnstructured.apply(module, name, amount)
    return module
892

893

894
def l1_unstructured(module, name, amount, importance_scores=None):
    r"""Prune tensor by removing units with the lowest L1-norm.

    The tensor behind the parameter called ``name`` in ``module`` loses the
    requested ``amount`` of (currently unpruned) units, namely those with
    the lowest L1-norm. ``module`` is modified in place (and is also
    returned) as follows:

    1) a named buffer called ``name+'_mask'`` is added, holding the
       binary mask applied to the parameter ``name`` by the pruning method;
    2) the parameter ``name`` is replaced by its pruned version, while the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)
        >>> m.state_dict().keys()
        odict_keys(['bias', 'weight_orig', 'weight_mask'])
    """
    # Reparametrization happens in place; the module is returned for chaining.
    L1Unstructured.apply(module, name, amount=amount, importance_scores=importance_scores)
    return module
936

937

938
def random_structured(module, name, amount, dim):
    r"""Prune tensor by removing random channels along the specified dimension.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) channels
    along the specified ``dim`` selected at random.
    Modifies module in place (and also returns the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of channels along ``dim`` to prune. If ``int``, it
            represents the absolute number of channels to prune.
        dim (int): index of the dim along which we define channels to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_structured(
        ...     nn.Linear(5, 3), 'weight', amount=3, dim=1
        ... )
        >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
        >>> print(columns_pruned)
        3
    """
    RandomStructured.apply(module, name, amount, dim)
    return module
977

978

979
def ln_structured(module, name, amount, n, dim, importance_scores=None):
    r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) channels
    along the specified ``dim`` with the lowest L\ ``n``-norm.
    Modifies module in place (and also returns the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of channels along ``dim`` to prune. If ``int``, it
            represents the absolute number of channels to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int): index of the dim along which we define channels to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.ln_structured(
        ...     nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
        ... )
    """
    LnStructured.apply(
        module, name, amount, n, dim, importance_scores=importance_scores
    )
    return module
1024

1025

1026
def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):
    r"""
    Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``.

    Modifies modules in place by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        parameters (Iterable of (module, name) tuples): parameters of
            the model to prune in a global fashion, i.e. by aggregating all
            weights prior to deciding which ones to prune. module must be of
            type :class:`nn.Module`, and name must be a string. Any iterable
            is accepted, including a one-shot generator.
        pruning_method (function): a valid pruning function from this module,
            or a custom one implemented by the user that satisfies the
            implementation guidelines and has ``PRUNING_TYPE='unstructured'``.
        importance_scores (dict): a dictionary mapping (module, name) tuples to
            the corresponding parameter's importance scores tensor. The tensor
            should be the same shape as the parameter, and is used for computing
            mask for pruning.
            If unspecified or None, the parameter will be used in place of its
            importance scores.
        kwargs: other keyword arguments such as:
            amount (int or float): quantity of parameters to prune across the
            specified parameters.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Raises:
        TypeError: if ``PRUNING_TYPE != 'unstructured'``, if ``parameters``
            is not iterable, or if ``importance_scores`` is not a dict.

    Note:
        Since global structured pruning doesn't make much sense unless the
        norm is normalized by the size of the parameter, we now limit the
        scope of global pruning to unstructured methods.

    Examples:
        >>> from torch.nn.utils import prune
        >>> from collections import OrderedDict
        >>> net = nn.Sequential(OrderedDict([
        ...     ('first', nn.Linear(10, 4)),
        ...     ('second', nn.Linear(4, 1)),
        ... ]))
        >>> parameters_to_prune = (
        ...     (net.first, 'weight'),
        ...     (net.second, 'weight'),
        ... )
        >>> prune.global_unstructured(
        ...     parameters_to_prune,
        ...     pruning_method=prune.L1Unstructured,
        ...     amount=10,
        ... )
        >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
        tensor(10)

    """
    # ensure parameters is an iterable of (module, name) tuples
    if not isinstance(parameters, Iterable):
        raise TypeError("global_unstructured(): parameters is not an Iterable")
    # Materialize once: ``parameters`` is traversed three times below
    # (scores, default masks, mask assignment); a generator would be
    # exhausted after the first pass.
    parameters = list(parameters)

    importance_scores = importance_scores if importance_scores is not None else {}
    if not isinstance(importance_scores, dict):
        raise TypeError("global_unstructured(): importance_scores must be of type dict")

    # flatten importance scores to consider them all at once in global pruning
    relevant_importance_scores = torch.nn.utils.parameters_to_vector(
        [
            importance_scores.get((module, name), getattr(module, name))
            for (module, name) in parameters
        ]
    )
    # similarly, flatten the masks (if they exist), or use a flattened vector
    # of 1s of the same dimensions as t
    default_mask = torch.nn.utils.parameters_to_vector(
        [
            getattr(module, name + "_mask", torch.ones_like(getattr(module, name)))
            for (module, name) in parameters
        ]
    )

    # use the canonical pruning methods to compute the new mask, even if the
    # parameter is now a flattened out version of `parameters`
    container = PruningContainer()
    container._tensor_name = "temp"  # to make it match that of `method`
    method = pruning_method(**kwargs)
    method._tensor_name = "temp"  # to make it match that of `container`
    if method.PRUNING_TYPE != "unstructured":
        raise TypeError(
            'Only "unstructured" PRUNING_TYPE supported for '
            f"the `pruning_method`. Found method {pruning_method} of type {method.PRUNING_TYPE}"
        )

    container.add_pruning_method(method)

    # use the `compute_mask` method from `PruningContainer` to combine the
    # mask computed by the new method with the pre-existing mask
    final_mask = container.compute_mask(relevant_importance_scores, default_mask)

    # Pointer for slicing the mask to match the shape of each parameter
    pointer = 0
    for module, name in parameters:
        param = getattr(module, name)
        # The length of the parameter
        num_param = param.numel()
        # Slice the mask, reshape it
        param_mask = final_mask[pointer : pointer + num_param].view_as(param)
        # Assign the correct pre-computed mask to each parameter and add it
        # to the forward_pre_hooks like any other pruning method
        custom_from_mask(module, name, mask=param_mask)

        # Increment the pointer to continue slicing the final_mask
        pointer += num_param
1144

1145

1146
def custom_from_mask(module, name, mask):
    r"""Prune the parameter called ``name`` in ``module`` with the pre-computed ``mask``.

    ``module`` is modified in place (and is also returned) as follows:

    1) a named buffer called ``name+'_mask'`` is added, holding the
       binary mask applied to the parameter ``name`` by the pruning method;
    2) the parameter ``name`` is replaced by its pruned version, while the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        mask (Tensor): binary mask to be applied to the parameter.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.custom_from_mask(
        ...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
        ... )
        >>> print(m.bias_mask)
        tensor([0., 1., 0.])

    """
    # Reparametrization happens in place; the module is returned for chaining.
    CustomFromMask.apply(module, name, mask)
    return module
1177

1178

1179
def remove(module, name):
    r"""Remove the pruning reparameterization from a module and the pruning method from the forward hook.

    After this call the parameter named ``name`` stays permanently pruned.
    The parameter ``name+'_orig'`` is dropped from the parameter list and the
    buffer ``name+'_mask'`` is dropped from the buffers.

    Note:
        Pruning itself is NOT undone or reversed!

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): the same module, with the pruning hook removed.

    Raises:
        ValueError: if no pruning hook is registered for ``name``.

    Examples:
        >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2)
        >>> m = remove(m, name='weight')
    """
    pre_hooks = module._forward_pre_hooks
    for hook_id, hook in pre_hooks.items():
        if not isinstance(hook, BasePruningMethod):
            continue
        if hook._tensor_name != name:
            continue
        hook.remove(module)
        # Safe despite iterating the dict: we return right after deleting.
        del pre_hooks[hook_id]
        return module

    raise ValueError(
        f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed"
    )
1207

1208

1209
def is_pruned(module):
    r"""Check if a module is pruned by looking for pruning pre-hooks.

    ``module`` and every one of its submodules is scanned. The module counts
    as pruned as soon as any of them holds a forward pre-hook that is an
    instance of :class:`BasePruningMethod`.

    Args:
        module (nn.Module): object that is either pruned or unpruned

    Returns:
        binary answer to whether ``module`` is pruned.

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = nn.Linear(5, 7)
        >>> print(prune.is_pruned(m))
        False
        >>> prune.random_unstructured(m, name='weight', amount=0.2)
        >>> print(prune.is_pruned(m))
        True
    """
    # ``any`` short-circuits on the first pruning hook found.
    return any(
        isinstance(hook, BasePruningMethod)
        for submodule in module.modules()
        for hook in submodule._forward_pre_hooks.values()
    )
1236

1237

1238
def _validate_pruning_amount_init(amount):
1239
    r"""Validate helper to check the range of amount at init.
1240

1241
    Args:
1242
        amount (int or float): quantity of parameters to prune.
1243
            If float, should be between 0.0 and 1.0 and represent the
1244
            fraction of parameters to prune. If int, it represents the
1245
            absolute number of parameters to prune.
1246

1247
    Raises:
1248
        ValueError: if amount is a float not in [0, 1], or if it's a negative
1249
            integer.
1250
        TypeError: if amount is neither a float nor an integer.
1251

1252
    Note:
1253
        This does not take into account the number of parameters in the
1254
        tensor to be pruned, which is known only at prune.
1255
    """
1256
    if not isinstance(amount, numbers.Real):
1257
        raise TypeError(
1258
            f"Invalid type for amount: {amount}. Must be int or float."
1259
        )
1260

1261
    if (isinstance(amount, numbers.Integral) and amount < 0) or (
1262
        not isinstance(amount, numbers.Integral)  # so it's a float
1263
        and (float(amount) > 1.0 or float(amount) < 0.0)
1264
    ):
1265
        raise ValueError(
1266
            f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer"
1267
        )
1268

1269

1270
def _validate_pruning_amount(amount, tensor_size):
1271
    r"""Validate that the pruning amount is meaningful wrt to the size of the data.
1272

1273
    Validation helper to check that the amount of parameters to prune
1274
    is meaningful wrt to the size of the data (`tensor_size`).
1275

1276
    Args:
1277
        amount (int or float): quantity of parameters to prune.
1278
            If float, should be between 0.0 and 1.0 and represent the
1279
            fraction of parameters to prune. If int, it represents the
1280
            absolute number of parameters to prune.
1281
        tensor_size (int): absolute number of parameters in the tensor
1282
            to prune.
1283
    """
1284
    # TODO: consider removing this check and allowing users to specify
1285
    # a number of units to prune that is greater than the number of units
1286
    # left to prune. In this case, the tensor will just be fully pruned.
1287

1288
    if isinstance(amount, numbers.Integral) and amount > tensor_size:
1289
        raise ValueError(
1290
            f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}"
1291
        )
1292

1293

1294
def _validate_structured_pruning(t):
1295
    r"""Validate that the tensor to be pruned is at least 2-Dimensional.
1296

1297
    Validation helper to check that the tensor to be pruned is multi-
1298
    dimensional, such that the concept of "channels" is well-defined.
1299

1300
    Args:
1301
        t (torch.Tensor): tensor representing the parameter to prune
1302

1303
    Raises:
1304
        ValueError: if the tensor `t` is not at least 2D.
1305
    """
1306
    shape = t.shape
1307
    if len(shape) <= 1:
1308
        raise ValueError(
1309
            "Structured pruning can only be applied to "
1310
            "multidimensional tensors. Found tensor of shape "
1311
            f"{shape} with {len(shape)} dims"
1312
        )
1313

1314

1315
def _compute_nparams_toprune(amount, tensor_size):
1316
    r"""Convert the pruning amount from a percentage to absolute value.
1317

1318
    Since amount can be expressed either in absolute value or as a
1319
    percentage of the number of units/channels in a tensor, this utility
1320
    function converts the percentage to absolute value to standardize
1321
    the handling of pruning.
1322

1323
    Args:
1324
        amount (int or float): quantity of parameters to prune.
1325
            If float, should be between 0.0 and 1.0 and represent the
1326
            fraction of parameters to prune. If int, it represents the
1327
            absolute number of parameters to prune.
1328
        tensor_size (int): absolute number of parameters in the tensor
1329
            to prune.
1330

1331
    Returns:
1332
        int: the number of units to prune in the tensor
1333
    """
1334
    # incorrect type already checked in _validate_pruning_amount_init
1335
    if isinstance(amount, numbers.Integral):
1336
        return amount
1337
    else:
1338
        return round(amount * tensor_size)
1339

1340

1341
def _validate_pruning_dim(t, dim):
1342
    r"""Validate that the pruning dimension is within the bounds of the tensor dimension.
1343

1344
    Args:
1345
        t (torch.Tensor): tensor representing the parameter to prune
1346
        dim (int): index of the dim along which we define channels to prune
1347
    """
1348
    if dim >= t.dim():
1349
        raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}")
1350

1351

1352
def _compute_norm(t, n, dim):
1353
    r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension.
1354

1355
    The L_n-norm will be computed across all entries in tensor `t` along all dimension
1356
    except for the one identified by dim.
1357
    Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim),
1358
    then norm will have Size [4], and each entry will represent the
1359
    `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels.
1360

1361
    Args:
1362
        t (torch.Tensor): tensor representing the parameter to prune
1363
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
1364
            entries for argument p in torch.norm
1365
        dim (int): dim identifying the channels to prune
1366

1367
    Returns:
1368
        norm (torch.Tensor): L_n norm computed across all dimensions except
1369
            for `dim`. By construction, `norm.shape = t.shape[-1]`.
1370
    """
1371
    # dims = all axes, except for the one identified by `dim`
1372
    dims = list(range(t.dim()))
1373
    # convert negative indexing
1374
    if dim < 0:
1375
        dim = dims[dim]
1376
    dims.remove(dim)
1377

1378
    norm = torch.norm(t, p=n, dim=dims)
1379
    return norm
1380

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.