3
from abc import ABC, abstractmethod
4
from collections.abc import Iterable
5
from typing import Tuple
10
class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.
    """

    # Name of the tensor this method instance prunes; assigned in `apply`.
    _tensor_name: str

    def __call__(self, module, inputs):
        r"""Multiply the mask into original tensor and store the result.

        Multiplies the mask (stored in ``module[name + '_mask']``)
        into the original tensor (stored in ``module[name + '_orig']``)
        and stores the result into ``module[name]`` by using :meth:`apply_mask`.

        Args:
            module (nn.Module): module containing the tensor to prune
            inputs: not used.
        """
        setattr(module, self._tensor_name, self.apply_mask(module))

    @abstractmethod
    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` according to the specific pruning method.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune.
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``
        """

    def apply_mask(self, module):
        r"""Simply handles the multiplication between the parameter being pruned and the generated mask.

        Fetches the mask and the original tensor from the module
        and returns the pruned version of the tensor.

        Args:
            module (nn.Module): module containing the tensor to prune

        Returns:
            pruned_tensor (torch.Tensor): pruned version of the input tensor
        """
        # to carry out the multiplication, the mask needs to have been computed,
        # so the pruning method must know what tensor it's operating on
        assert (
            self._tensor_name is not None
        ), f"Module {module} has to be pruned"  # this gets set in apply()
        mask = getattr(module, self._tensor_name + "_mask")
        orig = getattr(module, self._tensor_name + "_orig")
        pruned_tensor = mask.to(dtype=orig.dtype) * orig
        return pruned_tensor

    @classmethod
    def apply(cls, module, name, *args, importance_scores=None, **kwargs):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            args: arguments passed on to a subclass of
                :class:`BasePruningMethod`
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the
                corresponding elements in the parameter being pruned.
                If unspecified or None, the parameter will be used in its place.
            kwargs: keyword arguments passed on to a subclass of a
                :class:`BasePruningMethod`

        Returns:
            The pruning method registered as forward pre-hook on ``module``:
            either the new ``cls`` instance or the :class:`PruningContainer`
            that now holds it.
        """

        def _get_composite_method(cls, module, name, *args, **kwargs):
            # Check if a pruning method has already been applied to
            # `module[name]`. If so, store that in `old_method`.
            old_method = None
            found = 0
            # there should technically be only 1 hook with
            # hook._tensor_name == name; assert this using `found`
            hooks_to_remove = []
            for k, hook in module._forward_pre_hooks.items():
                # if it exists, take the existing hook, schedule it for
                # removal, then go through the normal path
                if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
                    old_method = hook
                    hooks_to_remove.append(k)
                    found += 1
            # Adjacent literals (instead of a backslash continuation) keep the
            # source indentation out of the message.
            assert found <= 1, (
                "Avoid adding multiple pruning hooks to the "
                f"same tensor {name} of module {module}. Use a PruningContainer."
            )

            # Deleted after the loop: the dict must not change while iterated.
            for k in hooks_to_remove:
                del module._forward_pre_hooks[k]

            # Apply the new pruning method, either from scratch or on top of
            # the previous one.
            method = cls(*args, **kwargs)  # new pruning
            # Have the pruning method remember what tensor it's been applied to
            method._tensor_name = name

            # combine `method` with `old_method`, if `old_method` exists
            if old_method is not None:  # meaning that there was a hook
                # if the hook is already a pruning container, just add the
                # new pruning method to the container
                if isinstance(old_method, PruningContainer):
                    old_method.add_pruning_method(method)
                    method = old_method  # rename old_method --> method

                # if the hook is simply a single pruning method, create a
                # container, add the old pruning method and the new one
                elif isinstance(old_method, BasePruningMethod):
                    container = PruningContainer(old_method)
                    container.add_pruning_method(method)
                    method = container  # rename container --> method
            return method

        method = _get_composite_method(cls, module, name, *args, **kwargs)
        # at this point we have no forward_pre_hooks but we could have an
        # active reparametrization of the tensor if another pruning method
        # had been applied (in which case `method` would be a PruningContainer
        # and not a simple pruning method).

        # Pruning is to be applied to the module's tensor named `name`,
        # starting from the state it is found in prior to this iteration of
        # pruning. The pruning mask is calculated based on importance scores.

        orig = getattr(module, name)
        if importance_scores is not None:
            assert (
                importance_scores.shape == orig.shape
            ), f"importance_scores should have the same shape as parameter {name} of {module}"
        else:
            importance_scores = orig

        # If this is the first time pruning is applied, take care of moving
        # the original tensor to a new parameter called name + '_orig' and
        # deleting the original parameter
        if not isinstance(method, PruningContainer):
            # copy `module[name]` to `module[name + '_orig']`
            module.register_parameter(name + "_orig", orig)
            # temporarily delete `module[name]`
            del module._parameters[name]
            default_mask = torch.ones_like(orig)  # temp
        # If this is not the first time pruning is applied, all of the above
        # has been done before in a previous pruning iteration, so we're good
        # to go
        else:
            default_mask = (
                getattr(module, name + "_mask")
                .detach()
                .clone(memory_format=torch.contiguous_format)
            )

        # Use try/except because if anything goes wrong with the mask
        # computation etc., you'd want to roll back.
        try:
            # get the final mask, computed according to the specific method
            mask = method.compute_mask(importance_scores, default_mask=default_mask)
            # reparameterize by saving mask to `module[name + '_mask']`...
            module.register_buffer(name + "_mask", mask)
            # ... and the new pruned tensor to `module[name]`
            setattr(module, name, method.apply_mask(module))
            # associate the pruning method to the module via a hook to
            # compute the function before every forward() (compile by run)
            module.register_forward_pre_hook(method)

        except Exception:
            # Undo the `_orig` move done above (first-time pruning only),
            # then let the original error propagate with its traceback.
            if not isinstance(method, PruningContainer):
                orig = getattr(module, name + "_orig")
                module.register_parameter(name, orig)
                del module._parameters[name + "_orig"]
            raise

        return method

    def prune(self, t, default_mask=None, importance_scores=None):
        r"""Compute and return a pruned version of input tensor ``t``.

        According to the pruning rule specified in :meth:`compute_mask`.

        Args:
            t (torch.Tensor): tensor to prune (of same dimensions as
                ``default_mask``).
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as ``t``) used to compute mask for pruning ``t``.
                The values in this tensor indicate the importance of the
                corresponding elements in the ``t`` that is being pruned.
                If unspecified or None, the tensor ``t`` will be used in its place.
            default_mask (torch.Tensor, optional): mask from previous pruning
                iteration, if any. To be considered when determining what
                portion of the tensor that pruning should act on. If None,
                default to a mask of ones.

        Returns:
            pruned version of tensor ``t``.
        """
        if importance_scores is not None:
            assert (
                importance_scores.shape == t.shape
            ), "importance_scores should have the same shape as tensor t"
        else:
            importance_scores = t
        default_mask = default_mask if default_mask is not None else torch.ones_like(t)
        return t * self.compute_mask(importance_scores, default_mask=default_mask)

    def remove(self, module):
        r"""Remove the pruning reparameterization from a module.

        The pruned parameter named ``name`` remains permanently pruned,
        and the parameter named ``name+'_orig'`` is removed from the parameter list.
        Similarly, the buffer named ``name+'_mask'`` is removed from the buffers.

        Note:
            Pruning itself is NOT undone or reversed!

        Args:
            module (nn.Module): module the pruning was applied to
        """
        # before removing pruning from a tensor, it has to have been applied
        assert (
            self._tensor_name is not None
        ), f"Module {module} has to be pruned before pruning can be removed"  # this gets set in apply()

        # to update module[name] to latest trained weights
        weight = self.apply_mask(module)  # masked weights

        # delete and reset
        if hasattr(module, self._tensor_name):
            delattr(module, self._tensor_name)
        orig = module._parameters[self._tensor_name + "_orig"]
        # Reuse the `_orig` parameter object, now holding the masked values.
        orig.data = weight.data
        del module._parameters[self._tensor_name + "_orig"]
        del module._buffers[self._tensor_name + "_mask"]
        setattr(module, self._tensor_name, orig)
260
class PruningContainer(BasePruningMethod):
    """Container holding a sequence of pruning methods for iterative pruning.

    Keeps track of the order in which pruning methods are applied and handles
    combining successive pruning calls.

    Accepts as argument an instance of a BasePruningMethod or an iterable of
    them.
    """

    def __init__(self, *args):
        """Build a container from zero or more pruning methods.

        Args:
            *args: either several pruning methods, a single pruning method,
                or a single iterable of pruning methods. With no arguments
                the container starts empty and ``_tensor_name`` is left unset.
        """
        self._pruning_methods: Tuple[BasePruningMethod, ...] = tuple()
        methods = args
        # `*args` is always a tuple, so an iterable of methods arrives as one
        # positional argument and must be unpacked here. A pruning method is
        # itself iterable when it is a container, hence the explicit exclusion.
        if (
            len(methods) == 1
            and isinstance(methods[0], Iterable)
            and not isinstance(methods[0], BasePruningMethod)
        ):
            methods = tuple(methods[0])
        # The container acts on the same tensor as its first child; this must
        # be known before `add_pruning_method` compares tensor names.
        if methods and isinstance(methods[0], BasePruningMethod):
            self._tensor_name = methods[0]._tensor_name
        for method in methods:
            self.add_pruning_method(method)

    def add_pruning_method(self, method):
        r"""Add a child pruning ``method`` to the container.

        Args:
            method (subclass of BasePruningMethod): child pruning method
                to be added to the container.

        Raises:
            TypeError: if ``method`` is neither ``None`` nor a
                :class:`BasePruningMethod` instance.
            ValueError: if ``method`` acts on a tensor name different from
                the container's.
        """
        # check that we're adding a pruning method to the container
        # NOTE(review): `None` passes both checks and is appended as-is.
        if not isinstance(method, BasePruningMethod) and method is not None:
            raise TypeError(f"{type(method)} is not a BasePruningMethod subclass")
        elif method is not None and self._tensor_name != method._tensor_name:
            raise ValueError(
                "Can only add pruning methods acting on "
                f"the parameter named '{self._tensor_name}' to PruningContainer {self}."
                + f" Found '{method._tensor_name}'"
            )
        # if all checks passed, add to _pruning_methods tuple
        self._pruning_methods += (method,)  # type: ignore[operator]

    def __len__(self):
        """Return the number of pruning methods held by the container."""
        return len(self._pruning_methods)

    def __iter__(self):
        """Iterate over the held pruning methods in insertion order."""
        return iter(self._pruning_methods)

    def __getitem__(self, idx):
        """Return the pruning method(s) at position ``idx``."""
        return self._pruning_methods[idx]

    def compute_mask(self, t, default_mask):
        r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``.

        The new partial mask should be computed on the entries or channels
        that were not zeroed out by the ``default_mask``.
        Which portions of the tensor ``t`` the new mask will be calculated from
        depends on the ``PRUNING_TYPE`` (handled by the type handler):

        * for 'unstructured', the mask will be computed from the raveled
          list of nonmasked entries;

        * for 'structured', the mask will be computed from the nonmasked
          channels in the tensor;

        * for 'global', the mask will be computed across all entries.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
                (of same dimensions as ``default_mask``).
            default_mask (torch.Tensor): mask from previous pruning iteration.

        Returns:
            mask (torch.Tensor): new mask that combines the effects
            of the ``default_mask`` and the new mask from the current
            pruning ``method`` (of same dimensions as ``default_mask`` and
            ``t``).
        """

        def _combine_masks(method, t, mask):
            r"""Combine the masks from all pruning methods and return a new mask.

            Args:
                method (a BasePruningMethod subclass): pruning method
                    currently being applied.
                t (torch.Tensor): tensor representing the parameter to prune
                    (of same dimensions as mask).
                mask (torch.Tensor): mask from previous pruning iteration

            Returns:
                new_mask (torch.Tensor): new mask that combines the effects
                    of the old mask and the new mask from the current
                    pruning method (of same dimensions as mask and t).
            """
            new_mask = mask  # start off from existing mask
            # NOTE(review): `Tensor.to` returns the same tensor when the dtype
            # already matches, so the write below may update `mask` in place.
            new_mask = new_mask.to(dtype=t.dtype)

            # compute a slice of t onto which the new pruning method will operate
            if method.PRUNING_TYPE == "unstructured":
                # prune entries of t where the mask is 1
                slc = mask == 1

            # for struct pruning, exclude channels that have already been
            # entirely pruned
            elif method.PRUNING_TYPE == "structured":
                if not hasattr(method, "dim"):
                    raise AttributeError(
                        "Pruning methods of PRUNING_TYPE "
                        '"structured" need to have the attribute `dim` defined.'
                    )
                dim = method.dim

                # find the channels to keep by removing the ones that have been
                # zeroed out already (i.e. where sum(entries) == 0)
                n_dims = t.dim()  # "is this a 2D tensor? 3D? ..."
                # convert negative indexing
                if dim < 0:
                    dim = n_dims + dim
                # if dim is still negative after adding n_dims to it
                if dim < 0:
                    raise IndexError(
                        f"Index is out of bounds for tensor with dimensions {n_dims}"
                    )
                # find channels along dim = dim that aren't already entirely
                # zeroed out
                keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0
                # create slice to identify what to prune
                index = [slice(None)] * n_dims
                index[dim] = keep_channel
                slc = tuple(index)

            elif method.PRUNING_TYPE == "global":
                n_dims = len(t.shape)  # "is this a 2D tensor? 3D? ..."
                slc = (slice(None),) * n_dims

            else:
                raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}")

            # compute the new mask on the unpruned slice of the tensor t
            partial_mask = method.compute_mask(t[slc], default_mask=mask[slc])
            new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)

            return new_mask

        # Only the most recently added method is applied; the effect of the
        # earlier ones is already encoded in `default_mask`.
        method = self._pruning_methods[-1]
        mask = _combine_masks(method, t, default_mask)
        return mask
410
class Identity(BasePruningMethod):
    r"""Utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones."""

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        """Return ``default_mask`` unchanged, so no additional unit is pruned."""
        return default_mask

    @classmethod
    def apply(cls, module, name):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
        """
        return super().apply(module, name)
435
class RandomUnstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with randomly chosen entries zeroed.

        Args:
            t (torch.Tensor): tensor whose number of elements determines how
                many units get pruned.
            default_mask (torch.Tensor): mask from previous pruning iterations.

        Returns:
            mask (torch.Tensor): contiguous clone of ``default_mask`` with the
            selected entries set to 0.
        """
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # skip the selection when nothing is pruned
            # Draw one random score per entry; the k highest-scoring entries
            # are the ones that get pruned.
            prob = torch.rand_like(t)
            topk = torch.topk(prob.view(-1), k=nparams_toprune)
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
        """
        return super().apply(module, name, amount=amount)
494
class L1Unstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with the smallest-magnitude entries of ``t`` zeroed.

        Args:
            t (torch.Tensor): importance scores; entries with the lowest
                absolute value are pruned.
            default_mask (torch.Tensor): mask from previous pruning iterations.

        Returns:
            mask (torch.Tensor): contiguous clone of ``default_mask`` with the
            selected entries set to 0.
        """
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # skip the selection when nothing is pruned
            # largest=True --> top k; largest=False --> bottom k
            # Prune the smallest k
            topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False)
            # topk will have .indices and .values
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module, name, amount=amount, importance_scores=importance_scores
        )
560
class RandomStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a random mask to
        apply on top of the ``default_mask`` by randomly zeroing out channels
        along the specified dim of the tensor.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)

        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the amount of channels to prune is not > than the number of
        # channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Compute binary mask by initializing it to all 0s and then filling in
        # 1s for the channels kept by `channel_mask`, along `dim`.
        # mask has the same shape as tensor t
        def make_mask(t, dim, nchannels, nchannels_toprune):
            # generate a random number in [0, 1] to associate to each channel
            prob = torch.rand(nchannels)
            # generate mask for each channel by 0ing out the channels that
            # got assigned the k = nchannels_toprune lowest values in prob
            threshold = torch.kthvalue(prob, k=nchannels_toprune).values
            channel_mask = prob > threshold

            mask = torch.zeros_like(t)
            index = [slice(None)] * len(t.shape)
            index[dim] = channel_mask
            # index with a tuple: one entry per dimension of `mask`
            mask[tuple(index)] = 1
            return mask

        if nparams_toprune == 0:  # k=0 not supported by torch.kthvalue
            mask = default_mask
        else:
            # apply the new structured mask on top of prior (potentially
            # unstructured) mask
            mask = make_mask(t, self.dim, tensor_size, nparams_toprune)
            mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, dim=-1):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            dim (int, optional): index of the dim along which we define
                channels to prune. Default: -1.
        """
        return super().apply(module, name, amount=amount, dim=dim)
665
class LnStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm.

    Args:
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, n, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.n = n
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` by zeroing out the channels along the
        specified dim with the lowest L\ ``n``-norm.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)
        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the amount of channels to prune is not > than the number of
        # channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        nparams_tokeep = tensor_size - nparams_toprune
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Structured pruning prunes entire channels so we need to know the
        # L_n norm along each channel to then find the topk based on this
        # metric
        norm = _compute_norm(t, self.n, self.dim)
        # largest=True --> top k; largest=False --> bottom k
        # Keep the largest k channels along dim=self.dim
        topk = torch.topk(norm, k=nparams_tokeep, largest=True)
        # topk will have .indices and .values

        # Compute binary mask by initializing it to all 0s and then filling in
        # 1s wherever topk.indices indicates, along self.dim.
        # mask has the same shape as tensor t
        def make_mask(t, dim, indices):
            # init mask to 0
            mask = torch.zeros_like(t)
            # e.g.: index = [:, :, :], if len(t.shape) = 3
            index = [slice(None)] * len(t.shape)
            # replace the slice at position=dim with indices
            # e.g.: index = [:, :, [0, 2, 3]] if dim=2 & indices=[0,2,3]
            index[dim] = indices
            # use the index to slice mask and replace all its entries with 1s
            # e.g.: mask[:, :, [0, 2, 3]] = 1
            mask[tuple(index)] = 1
            return mask

        if nparams_toprune == 0:  # nothing to prune: keep the incoming mask
            mask = default_mask
        else:
            mask = make_mask(t, self.dim, topk.indices)
            mask *= default_mask.to(dtype=mask.dtype)

        return mask

    @classmethod
    def apply(cls, module, name, amount, n, dim, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
                entries for argument ``p`` in :func:`torch.norm`.
            dim (int): index of the dim along which we define channels to
                prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module,
            name,
            amount=amount,
            n=n,
            dim=dim,
            importance_scores=importance_scores,
        )
794
class CustomFromMask(BasePruningMethod):
    """Pruning method that applies a caller-supplied mask.

    Args:
        mask (torch.Tensor): mask to combine with the default mask; must have
            the same shape as the tensor being pruned.
    """

    PRUNING_TYPE = "global"

    def __init__(self, mask):
        self.mask = mask

    def compute_mask(self, t, default_mask):
        """Return the elementwise product of ``default_mask`` and the stored mask.

        Args:
            t (torch.Tensor): tensor being pruned (not used here).
            default_mask (torch.Tensor): mask from previous pruning iterations.

        Returns:
            mask (torch.Tensor): combined mask, in the dtype of ``default_mask``.
        """
        assert default_mask.shape == self.mask.shape
        mask = default_mask * self.mask.to(dtype=default_mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, mask):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            mask (torch.Tensor): mask to apply to the parameter.
        """
        return super().apply(module, name, mask=mask)
822
def identity(module, name):
    r"""Apply pruning reparametrization without pruning any units.

    Applies pruning reparametrization to the tensor corresponding to the
    parameter called ``name`` in ``module`` without actually pruning any
    units. Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Note:
        The mask is a tensor of ones.

    Args:
        module (nn.Module): module containing the tensor to prune.
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.identity(nn.Linear(2, 3), 'bias')
        >>> print(m.bias_mask)
        tensor([1., 1., 1.])
    """
    Identity.apply(module, name)
    return module
857
def random_unstructured(module, name, amount):
    r"""Prune tensor by removing random (currently unpruned) units.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) units
    selected at random.
    Modifies module in place (and also return the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)
    """
    RandomUnstructured.apply(module, name, amount)
    return module
894
def l1_unstructured(module, name, amount, importance_scores=None):
    r"""Prune tensor by removing units with the lowest L1-norm.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) units with the
    lowest L1-norm.
    Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)
        >>> m.state_dict().keys()
        odict_keys(['bias', 'weight_orig', 'weight_mask'])
    """
    L1Unstructured.apply(
        module, name, amount=amount, importance_scores=importance_scores
    )
    return module
938
def random_structured(module, name, amount, dim):
    r"""Prune tensor by removing random channels along the specified dimension.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) channels
    along the specified ``dim`` selected at random.
    Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int): index of the dim along which we define channels to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_structured(
        ...     nn.Linear(5, 3), 'weight', amount=3, dim=1
        ... )
        >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
        >>> print(columns_pruned)
        3
    """
    # Functional wrapper: the pruning-method class does the work, the module
    # is returned so calls can be chained or assigned.
    RandomStructured.apply(module, name, amount, dim)
    return module
979
def ln_structured(module, name, amount, n, dim, importance_scores=None):
    r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) channels
    along the specified ``dim`` with the lowest L\ ``n``-norm.
    Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int): index of the dim along which we define channels to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.ln_structured(
        ...     nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
        ... )
    """
    # Functional wrapper around the pruning-method class; returns the module
    # that was modified in place.
    LnStructured.apply(
        module, name, amount, n, dim, importance_scores=importance_scores
    )
    return module
1026
def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):
    r"""Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``.

    Modifies modules in place by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        parameters (Iterable of (module, name) tuples): parameters of
            the model to prune in a global fashion, i.e. by aggregating all
            weights prior to deciding which ones to prune. module must be of
            type :class:`nn.Module`, and name must be a string. Any iterable
            is accepted, including one-shot generators: it is consumed once
            and materialized internally.
        pruning_method (function): a valid pruning function from this module,
            or a custom one implemented by the user that satisfies the
            implementation guidelines and has ``PRUNING_TYPE='unstructured'``.
        importance_scores (dict): a dictionary mapping (module, name) tuples to
            the corresponding parameter's importance scores tensor. The tensor
            should be the same shape as the parameter, and is used for computing
            mask for pruning.
            If unspecified or None, the parameter will be used in place of its
            importance scores.
        kwargs: other keyword arguments such as:
            amount (int or float): quantity of parameters to prune across the
            specified parameters.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Raises:
        TypeError: if ``parameters`` is not an Iterable, if
            ``importance_scores`` is not a dict, or if
            ``PRUNING_TYPE != 'unstructured'``

    Note:
        Since global structured pruning doesn't make much sense unless the
        norm is normalized by the size of the parameter, we now limit the
        scope of global pruning to unstructured methods.

    Examples:
        >>> from torch.nn.utils import prune
        >>> from collections import OrderedDict
        >>> net = nn.Sequential(OrderedDict([
        ...     ('first', nn.Linear(10, 4)),
        ...     ('second', nn.Linear(4, 1)),
        ... ]))
        >>> parameters_to_prune = (
        ...     (net.first, 'weight'),
        ...     (net.second, 'weight'),
        ... )
        >>> prune.global_unstructured(
        ...     parameters_to_prune,
        ...     pruning_method=prune.L1Unstructured,
        ...     amount=10,
        ... )
        >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
        tensor(10)
    """
    # ensure parameters is a list or generator of tuples
    if not isinstance(parameters, Iterable):
        raise TypeError("global_unstructured(): parameters is not an Iterable")
    # `parameters` is walked three times below (scores, masks, write-back).
    # Materialize it once so that a generator is not silently exhausted after
    # the first pass.
    parameters = list(parameters)

    importance_scores = importance_scores if importance_scores is not None else {}
    if not isinstance(importance_scores, dict):
        raise TypeError("global_unstructured(): importance_scores must be of type dict")

    # flatten importance scores to consider them all at once in global pruning
    relevant_importance_scores = torch.nn.utils.parameters_to_vector(
        [
            importance_scores.get((module, name), getattr(module, name))
            for (module, name) in parameters
        ]
    )
    # similarly, flatten the masks (if they exist), or use a flattened vector
    # of 1s of the same dimensions as t
    default_mask = torch.nn.utils.parameters_to_vector(
        [
            getattr(module, name + "_mask", torch.ones_like(getattr(module, name)))
            for (module, name) in parameters
        ]
    )

    # use the canonical pruning methods to compute the new mask, even if the
    # parameter is now a flattened out version of `parameters`
    container = PruningContainer()
    container._tensor_name = "temp"  # to make it match that of `method`
    method = pruning_method(**kwargs)
    method._tensor_name = "temp"  # to make it match that of `container`
    if method.PRUNING_TYPE != "unstructured":
        raise TypeError(
            'Only "unstructured" PRUNING_TYPE supported for '
            f"the `pruning_method`. Found method {pruning_method} of type {method.PRUNING_TYPE}"
        )

    container.add_pruning_method(method)

    # use the `compute_mask` method from `PruningContainer` to combine the
    # mask computed by the new method with the pre-existing mask
    final_mask = container.compute_mask(relevant_importance_scores, default_mask)

    # Pointer for slicing the mask to match the shape of each parameter
    pointer = 0
    for module, name in parameters:
        param = getattr(module, name)
        # The length of the parameter
        num_param = param.numel()
        # Slice the mask, reshape it
        param_mask = final_mask[pointer : pointer + num_param].view_as(param)
        # Assign the correct pre-computed mask to each parameter and add it
        # to the forward_pre_hooks like any other pruning method
        custom_from_mask(module, name, mask=param_mask)

        # Increment the pointer to continue slicing the final_mask
        pointer += num_param
1146
def custom_from_mask(module, name, mask):
    r"""Prune tensor corresponding to parameter called ``name`` in ``module`` by applying the pre-computed mask in ``mask``.

    Modifies module in place (and also return the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        mask (Tensor): binary mask to be applied to the parameter.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.custom_from_mask(
        ...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
        ... )
        >>> print(m.bias_mask)
        tensor([0., 1., 0.])

    """
    # Functional wrapper around the pruning-method class; returns the module
    # that was modified in place.
    CustomFromMask.apply(module, name, mask)
    return module
1179
def remove(module, name):
    r"""Remove the pruning reparameterization from a module and the pruning method from the forward hook.

    The pruned parameter named ``name`` remains permanently pruned, and the parameter
    named ``name+'_orig'`` is removed from the parameter list. Similarly,
    the buffer named ``name+'_mask'`` is removed from the buffers.

    Note:
        Pruning itself is NOT undone or reversed!

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): the input module, with the reparameterization
        for ``name`` removed.

    Raises:
        ValueError: if no pruning hook acting on ``name`` is registered on
            ``module``.

    Examples:
        >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2)
        >>> m = remove(m, name='weight')
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
            # Let the pruning method undo its own reparameterization first,
            # then unregister it.
            hook.remove(module)
            # Deleting while iterating is safe only because we return
            # immediately and never advance the iterator again.
            del module._forward_pre_hooks[k]
            return module

    raise ValueError(
        f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed"
    )
1209
def is_pruned(module):
    r"""Check if a module is pruned by looking for pruning pre-hooks.

    Check whether ``module`` is pruned by looking for
    ``forward_pre_hooks`` in its modules that inherit from the
    :class:`BasePruningMethod`.

    Args:
        module (nn.Module): object that is either pruned or unpruned

    Returns:
        binary answer to whether ``module`` is pruned.

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = nn.Linear(5, 7)
        >>> print(prune.is_pruned(m))
        False
        >>> prune.random_unstructured(m, name='weight', amount=0.2)
        >>> print(prune.is_pruned(m))
        True
    """
    # `named_modules()` yields `module` itself as well as every descendant, so
    # a pruning hook anywhere in the tree counts. `any` stops at the first hit.
    return any(
        isinstance(hook, BasePruningMethod)
        for _, submodule in module.named_modules()
        for hook in submodule._forward_pre_hooks.values()
    )
1238
def _validate_pruning_amount_init(amount):
1239
r"""Validate helper to check the range of amount at init.
1242
amount (int or float): quantity of parameters to prune.
1243
If float, should be between 0.0 and 1.0 and represent the
1244
fraction of parameters to prune. If int, it represents the
1245
absolute number of parameters to prune.
1248
ValueError: if amount is a float not in [0, 1], or if it's a negative
1250
TypeError: if amount is neither a float nor an integer.
1253
This does not take into account the number of parameters in the
1254
tensor to be pruned, which is known only at prune.
1256
if not isinstance(amount, numbers.Real):
1258
f"Invalid type for amount: {amount}. Must be int or float."
1261
if (isinstance(amount, numbers.Integral) and amount < 0) or (
1262
not isinstance(amount, numbers.Integral) # so it's a float
1263
and (float(amount) > 1.0 or float(amount) < 0.0)
1266
f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer"
1270
def _validate_pruning_amount(amount, tensor_size):
1271
r"""Validate that the pruning amount is meaningful wrt to the size of the data.
1273
Validation helper to check that the amount of parameters to prune
1274
is meaningful wrt to the size of the data (`tensor_size`).
1277
amount (int or float): quantity of parameters to prune.
1278
If float, should be between 0.0 and 1.0 and represent the
1279
fraction of parameters to prune. If int, it represents the
1280
absolute number of parameters to prune.
1281
tensor_size (int): absolute number of parameters in the tensor
1284
# TODO: consider removing this check and allowing users to specify
1285
# a number of units to prune that is greater than the number of units
1286
# left to prune. In this case, the tensor will just be fully pruned.
1288
if isinstance(amount, numbers.Integral) and amount > tensor_size:
1290
f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}"
1294
def _validate_structured_pruning(t):
1295
r"""Validate that the tensor to be pruned is at least 2-Dimensional.
1297
Validation helper to check that the tensor to be pruned is multi-
1298
dimensional, such that the concept of "channels" is well-defined.
1301
t (torch.Tensor): tensor representing the parameter to prune
1304
ValueError: if the tensor `t` is not at least 2D.
1309
"Structured pruning can only be applied to "
1310
"multidimensional tensors. Found tensor of shape "
1311
f"{shape} with {len(shape)} dims"
1315
def _compute_nparams_toprune(amount, tensor_size):
1316
r"""Convert the pruning amount from a percentage to absolute value.
1318
Since amount can be expressed either in absolute value or as a
1319
percentage of the number of units/channels in a tensor, this utility
1320
function converts the percentage to absolute value to standardize
1321
the handling of pruning.
1324
amount (int or float): quantity of parameters to prune.
1325
If float, should be between 0.0 and 1.0 and represent the
1326
fraction of parameters to prune. If int, it represents the
1327
absolute number of parameters to prune.
1328
tensor_size (int): absolute number of parameters in the tensor
1332
int: the number of units to prune in the tensor
1334
# incorrect type already checked in _validate_pruning_amount_init
1335
if isinstance(amount, numbers.Integral):
1338
return round(amount * tensor_size)
1341
def _validate_pruning_dim(t, dim):
1342
r"""Validate that the pruning dimension is within the bounds of the tensor dimension.
1345
t (torch.Tensor): tensor representing the parameter to prune
1346
dim (int): index of the dim along which we define channels to prune
1349
raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}")
1352
def _compute_norm(t, n, dim):
1353
r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension.
1355
The L_n-norm will be computed across all entries in tensor `t` along all dimension
1356
except for the one identified by dim.
1357
Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim),
1358
then norm will have Size [4], and each entry will represent the
1359
`L_n`-norm computed using the 3x2=6 entries for each of the 4 channels.
1362
t (torch.Tensor): tensor representing the parameter to prune
1363
n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
1364
entries for argument p in torch.norm
1365
dim (int): dim identifying the channels to prune
1368
norm (torch.Tensor): L_n norm computed across all dimensions except
1369
for `dim`. By construction, `norm.shape = t.shape[-1]`.
1371
# dims = all axes, except for the one identified by `dim`
1372
dims = list(range(t.dim()))
1373
# convert negative indexing
1378
norm = torch.norm(t, p=n, dim=dims)