# @package optimizer
# Module caffe2.python.regularizer


from caffe2.python import core, utils
import numpy as np


class RegularizationBy:
    AFTER_OPTIMIZER = "after_optimizer"
    ON_LOSS = "on_loss"


class Regularizer:
    """
    Adds regularization to train_net for the given parameter. The
    regularization factor is supplied at initialization.
    The param should be a BlobReference.
    """

    def __init__(self):
        self.kEpsilon = 1e-9

    def __call__(self, net, param_init_net, param, grad=None, by=None):
        assert isinstance(param, core.BlobReference)
        by_enum = utils.EnumClassKeyVals(RegularizationBy)
        assert by in by_enum.values(), (
            "Regularizer of type {} is called with invalid by={}, "
            "not in {}".format(self.__class__, by, by_enum.values())
        )
        run_func = "_run_" + by
        assert hasattr(
            self, run_func
        ), "Regularizer of type {} does not implement function {}".format(
            self.__class__, run_func
        )
        return getattr(self, run_func)(net, param_init_net, param, grad)
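
    # A minimal usage sketch (illustrative only; the net names and the L2Norm
    # choice below are hypothetical): a regularizer instance is called with the
    # train net, the param init net, the parameter blob and a RegularizationBy
    # value, and the call is dispatched to _run_on_loss or _run_after_optimizer.
    #
    #   reg = L2Norm(reg_lambda=1e-4)
    #   reg_loss_blob = reg(
    #       train_net, param_init_net, param,
    #       grad=None, by=RegularizationBy.ON_LOSS,
    #   )  # internally invokes reg._run_on_loss(...)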

    def _run_on_loss(self, net, param_init_net, param, grad=None):
        return None

    def _run_after_optimizer(self, net, param_init_net, param, grad):
        return None

    def _feature_grouping(self, param, net):
        # Possible alternative grouping method via summing over absolute values
        # Compute l2norm over feature weights
        # pow( sum_i { pow(theta_i, 2) } ,  0.5)
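        # A NumPy sketch of the grouping below (illustrative only), assuming
        # param is a 2-D weight matrix W:
        #   grouped = np.sqrt(np.sum(W * W, axis=0))  # one l2 norm per column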
        param_mul = net.Mul([param, param], [net.NextScopedBlob("param_mul")])
        param_reduced = net.ReduceFrontSum(
            [param_mul], [net.NextScopedBlob("param_reduced")]
        )
        grouped_feature_weight_vec = net.Pow(
            [param_reduced],
            [net.NextScopedBlob("grouped_feature_weight_vec")],
            exponent=0.5,
        )

        return grouped_feature_weight_vec

    def _ensure_clipped(
        self,
        net,
        param,
        grad=None,
        min=None,
        max=None,
        open_range=False,
        left_open=False,
        right_open=False,
    ):
        min = (
            min + self.kEpsilon
            if min is not None and (open_range or left_open)
            else min
        )
        max = (
            max - self.kEpsilon
            if max is not None and (open_range or right_open)
            else max
        )
        input_blobs = (
            [param, grad.indices, grad.values]
            if isinstance(grad, core.GradientSlice)
            else [param]
        )
        net.EnsureClipped(input_blobs, [param], min=min, max=max)


class L1Norm(Regularizer):
    def __init__(self, reg_lambda):
        super().__init__()
        assert reg_lambda >= 0, "factor ahead of regularization should be 0 or positive"

        self.reg_lambda = reg_lambda

    def _run_on_loss(self, net, param_init_net, param, grad=None):
        output_blob = net.NextScopedBlob(param + "_l1_regularization")
        net.LpNorm([param], [output_blob], p=1)
        net.Scale([output_blob], [output_blob], scale=self.reg_lambda)
        return output_blob

class LpNorm(Regularizer):
    def __init__(self, reg_lambda, p_value=0.5):
        """
        reg_lambda: parameter to scale regularization by

        p_value:    determines what type of Lp norm to calculate. If p > 0,
                    we will calculate the Lp norm with the formula:
                    pow( sum_i { pow(theta_i, p) } ,  1/p)
        """
        super().__init__()
        assert reg_lambda > 0, "factor ahead of regularization should be greater than 0"
        assert p_value > 0, "p_value factor should be greater than 0"
        self.p_value = p_value
        self.reg_lambda = reg_lambda


    def _run_on_loss(self, net, param_init_net, param, grad=None):
        # TODO: the second dim (num of input nodes) of param is after feature preproc,
        # and does not correspond to the original num of dense features.
        # In the future, will want to create a util to reduce the input dim of param to
        # match the num of dense features.

        output_blob = net.NextScopedBlob(param + "_dense_feature_regularization")
        grouped_feature_weight_vec = self._feature_grouping(param, net)

        # Compute the Lp norm:
        # pow( sum_i { pow(theta_i, p) } ,  1/p)
        lp_vec_raised = net.Pow(
            [grouped_feature_weight_vec],
            [net.NextScopedBlob("lp_vec_raised")],
            exponent=self.p_value,
        )
        lp_vec_summed = net.ReduceFrontSum(
            [lp_vec_raised], [net.NextScopedBlob("lp_vec_summed")]
        )
        lp_norm = net.Pow(
            [lp_vec_summed],
            [net.NextScopedBlob("lp_vec")],
            exponent=(1 / self.p_value),
        )
        net.Scale([lp_norm], [output_blob], scale=self.reg_lambda)
        return output_blob


class L0ApproxNorm(Regularizer):
    def __init__(self, reg_lambda, alpha=0.01, budget=0):
        """
        reg_lambda: parameter to scale regularization by

        alpha:      hyperparameter to tune that is only used in the calculation
                    of the approximate L0 norm

        budget:     desired number of features. If the number of features is greater
                    than the budget amount, then the least important features will
                    be penalized. If there are fewer features than the desired
                    budget, no penalization will be applied. Optional parameter; if
                    0, then no budget is used
        """
        super().__init__()
        assert reg_lambda > 0, "factor ahead of regularization should be greater than 0"
        assert alpha > 0, "alpha factor must be greater than 0"
        assert budget >= 0, "budget factor must be greater than or equal to 0"
        self.reg_lambda = reg_lambda
        self.alpha = alpha
        self.budget = float(budget)  # budget must be float for future calculations

    def _run_on_loss(self, net, param_init_net, param, grad=None):
        # TODO: the second dim (num of input nodes) of param is after feature preproc,
        # and does not correspond to the original num of dense features.
        # In the future, will want to create a util to reduce the input dim of param to
        # match the num of dense features.

        output_blob = net.NextScopedBlob(param + "_dense_feature_regularization")
        grouped_feature_weight_vec = self._feature_grouping(param, net)

        # compute approximate L0 norm
        # sum_i ( min( abs(theta_i), alpha) ) / alpha
        l0_abs = net.Abs([grouped_feature_weight_vec], [net.NextScopedBlob("l0_abs")])
        l0_min = net.Clip([l0_abs], [net.NextScopedBlob("l0_min")], max=self.alpha)
        l0_summed = net.ReduceFrontSum([l0_min], [net.NextScopedBlob("l0_summed")])
        l0_norm = net.Scale(
            [l0_summed], [net.NextScopedBlob("l0_norm")], scale=(1 / self.alpha)
        )

        # incorporate budget factor
        # regularization = reg_lambda * max(0, l0_norm - budget)
        if self.budget:
            budget_blob = net.ConstantFill([], "budget", shape=[1], value=self.budget)
            l0_sub_budget = net.Sub(
                [l0_norm, budget_blob], [net.NextScopedBlob("l0_budget")]
            )
            relu_l0_sub_budget = net.Relu(
                [l0_sub_budget], [net.NextScopedBlob("relu_l0_sub_budget")]
            )
            net.Scale([relu_l0_sub_budget], [output_blob], scale=self.reg_lambda)
        else:
            net.Scale([l0_norm], [output_blob], scale=self.reg_lambda)
        return output_blob

class L1NormTrimmed(Regularizer):
    """
    The Trimmed Lasso: Sparsity and Robustness. https://arxiv.org/abs/1708.04527
    """
    def __init__(self, reg_lambda, k):
        super().__init__()
        assert reg_lambda >= 0, "factor ahead of regularization should be 0 or positive"
        assert isinstance(k, int), "k should be an integer (expected # of weights kept after selection)"
        assert k >= 1, "k should be at least 1"

        self.reg_lambda = reg_lambda
        self.k = k

    def _run_on_loss(self, net, param_init_net, param, grad=None):
        output_blob = net.NextScopedBlob(param + "_l1_trimmed_regularization")
        abs = net.Abs([param], [net.NextScopedBlob("abs")])
        sum_abs = net.SumElements([abs], [net.NextScopedBlob("sum_abs")], average=False)
        topk, _, _ = net.TopK([abs], [net.NextScopedBlob("topk"), net.NextScopedBlob("id"), net.NextScopedBlob("flat_id")], k=self.k)
        topk_sum = net.SumElements([topk], [net.NextScopedBlob("topk_sum")], average=False)
        net.Sub([sum_abs, topk_sum], [output_blob])
        net.Scale([output_blob], [output_blob], scale=self.reg_lambda)
        return output_blob


class L2Norm(Regularizer):
    def __init__(self, reg_lambda):
        super().__init__()
        assert reg_lambda >= 0, "factor ahead of regularization should be 0 or positive"

        self.reg_lambda = reg_lambda

    def _run_on_loss(self, net, param_init_net, param, grad=None):
        output_blob = net.NextScopedBlob(param + "_l2_regularization")
        net.LpNorm([param], [output_blob], p=2)
        net.Scale([output_blob], [output_blob], scale=self.reg_lambda)
        return output_blob


class ElasticNet(Regularizer):
    def __init__(self, l1, l2):
        super().__init__()
        self.l1 = l1
        self.l2 = l2

    def _run_on_loss(self, net, param_init_net, param, grad=None):
        output_blob = net.NextScopedBlob(param + "_elastic_net_regularization")
        l2_blob = net.NextScopedBlob(param + "_l2_blob")
        l1_blob = net.NextScopedBlob(param + "_l1_blob")
        net.LpNorm([param], [l2_blob], p=2)
        net.LpNorm([param], [l1_blob], p=1)
        net.Scale([l2_blob], [l2_blob], scale=self.l2)
        net.Scale([l1_blob], [l1_blob], scale=self.l1)
        net.Add([l1_blob, l2_blob], [output_blob])
        return output_blob


class ElasticNetL1NormTrimmed(Regularizer):
    def __init__(self, l1, l2, k):
        super().__init__()
        self.l1 = l1
        self.l2 = l2
        self.k = k

    def _run_on_loss(self, net, param_init_net, param, grad=None):
        output_blob = net.NextScopedBlob(param + "_elastic_net_l1_trimmed_regularization")
        l2_blob = net.NextScopedBlob(param + "_l2_blob")
        net.LpNorm([param], [l2_blob], p=2)
        net.Scale([l2_blob], [l2_blob], scale=self.l2)

        l1_blob = net.NextScopedBlob(param + "_l1_blob")
        abs = net.Abs([param], [net.NextScopedBlob("abs")])
        sum_abs = net.SumElements([abs], [net.NextScopedBlob("sum_abs")], average=False)
        topk, _, _ = net.TopK([abs], [net.NextScopedBlob("topk"), net.NextScopedBlob("id"), net.NextScopedBlob("flat_id")], k=self.k)
        topk_sum = net.SumElements([topk], [net.NextScopedBlob("topk_sum")], average=False)
        net.Sub([sum_abs, topk_sum], [l1_blob])
        net.Scale([l1_blob], [l1_blob], scale=self.l1)

        net.Add([l1_blob, l2_blob], [output_blob])
        return output_blob


class MaxNorm(Regularizer):
    def __init__(self, norm=1.0, dtype=None):
        super().__init__()
        self.norm = norm
        self.dtype = dtype

    def _run_after_optimizer(self, net, param_init_net, param, grad):
        assert self.norm > 0, "norm should be bigger than 0."
        if isinstance(grad, core.GradientSlice):
            if self.dtype and self.dtype == 'fp16':
                net.Float16SparseNormalize(
                    [param, grad.indices],
                    [param],
                    use_max_norm=True,
                    norm=self.norm,
                )
            else:
                net.SparseNormalize(
                    [param, grad.indices],
                    [param],
                    use_max_norm=True,
                    norm=self.norm,
                )
        else:
            raise NotImplementedError("MaxNorm is not supported for dense parameters")


class ConstantNorm(Regularizer):
    def __init__(self, norm=1.0):
        super().__init__()
        self.norm = norm

    def _run_after_optimizer(self, net, param_init_net, param, grad):
        assert self.norm > 0, "norm should be bigger than 0."
        if isinstance(grad, core.GradientSlice):
            net.SparseNormalize(
                [param, grad.indices],
                [param],
                use_max_norm=False,
                norm=self.norm,
            )
        else:
            raise NotImplementedError(
                "ConstantNorm is not supported for dense parameters"
            )


class SparseLpNorm(Regularizer):
    def __init__(self, p, reg_lambda):
        super().__init__()
        assert p in (1.0, 2.0), "Sparse Lp regularization only implemented for p = 1.0 and p = 2.0."
        assert reg_lambda > 0, "factor ahead of regularization should be greater than 0."
        self.p = p
        self.reg_lambda = reg_lambda

    def _run_after_optimizer(self, net, param_init_net, param, grad):
        if isinstance(grad, core.GradientSlice):
            net.SparseLpRegularizer(
                [param, grad.indices],
                [param],
                p=self.p,
                reg_lambda=self.reg_lambda,
            )
        else:
            raise NotImplementedError("SparseLpNorm is not supported for dense parameters")


class SparseL1Norm(SparseLpNorm):
    def __init__(self, reg_lambda):
        super().__init__(p=1.0, reg_lambda=reg_lambda)


class SparseL2Norm(SparseLpNorm):
    def __init__(self, reg_lambda):
        super().__init__(p=2.0, reg_lambda=reg_lambda)


class LogBarrier(Regularizer):
    """
    Wright, S., & Nocedal, J. (1999). Numerical optimization. Springer Science,
    35(67-68), 7. Chapter 19
    """

    def __init__(self, reg_lambda, discount_policy="inv", discount_options=None):
        """
        The discount is a positive, decreasing weight, implemented here in the
        same way as a learning rate: it is specified by a learning rate policy
        and the corresponding options.
        """
        super().__init__()
        assert reg_lambda > 0, "factor ahead of regularization should be greater than 0"
        self.reg_lambda = reg_lambda
        self.discount_policy = discount_policy
        self.discount_options = discount_options or {"gamma": 1.0, "power": 1.0}

    def _run_on_loss(self, net, param_init_net, param, grad=None):
        iteration = utils.BuildUniqueMutexIter(param_init_net, net)
        # Since we are most likely to do a minimization, the discount is built
        # with a negative base_lr, yielding the usual barrier term
        # -discount_t * sum(log(param)).
        discount = net.NextScopedBlob(param + "_log_barrier_discount")
        net.LearningRate(
            [iteration],
            [discount],
            base_lr=-self.reg_lambda,
            policy=self.discount_policy,
            **self.discount_options
        )
        # TODO(xlwang): param might still be negative at the initialization time or
        # slightly negative due to the distributed training. Enforce its non-negativity
        # for now (at least above machine epsilon)
        param_non_neg = net.NextScopedBlob(param + "_non_neg")
        net.Clip([param], [param_non_neg], min=self.kEpsilon)
        param_log = net.NextScopedBlob(param + "_log")
        net.Log([param_non_neg], [param_log])
        param_log_sum = net.NextScopedBlob(param + "_log_sum")
        net.SumElements([param_log], [param_log_sum])
        output_blob = net.NextScopedBlob(param + "_log_barrier")
        net.Mul([param_log_sum, discount], [output_blob], broadcast=1)
        return output_blob

    def _run_after_optimizer(self, net, param_init_net, param, grad):
        self._ensure_clipped(net, param, grad, min=0, open_range=True)


class BoundedGradientProjection(Regularizer):
    """
    Wright, S., & Nocedal, J. (1999). Numerical optimization. Springer Science,
    35(67-68), 7. Chapter 16
    """

    def __init__(
        self, lb=None, ub=None, left_open=False, right_open=False, epsilon=None
    ):
        super().__init__()
        lb = float(lb) if lb is not None else None
        ub = float(ub) if ub is not None else None
        epsilon = float(epsilon) if epsilon is not None else self.kEpsilon
        assert epsilon > 0, "Bounded Gradient Projection with invalid eps={eps}".format(
            eps=epsilon
        )
        assert (
            (lb is None)
            or (ub is None)
            or (
                lb + (epsilon if left_open else 0.)
                <= ub - (epsilon if right_open else 0.)
            )
        ), (
            "Bounded Gradient Projection with invalid "
            "{lp}ub={ub}, lb={lb}{rp}, eps={eps}".format(
                lb=lb,
                ub=ub,
                lp="(" if left_open else "[",
                rp=")" if right_open else "]",
                eps=epsilon,
            )
        )
        self.left_open = left_open
        self.right_open = right_open
        self.kEpsilon = epsilon
        self.lb = lb
        self.ub = ub

    def _run_after_optimizer(self, net, param_init_net, param, grad):
        self._ensure_clipped(
            net,
            param,
            grad,
            min=self.lb,
            max=self.ub,
            left_open=self.left_open,
            right_open=self.right_open,
        )


class GroupL1Norm(Regularizer):
    """
    Scardapane, Simone, et al. "Group sparse regularization for deep neural networks."
    Neurocomputing 241 (2017): 81-89.

    This regularizer computes the l1 norm of a weight matrix based on groups.
    There are essentially three stages in the computation:
    1. Compute the l2 norm on all the members of each group
    2. Scale each l2 norm by the size of each group
    3. Compute the l1 norm of the scaled l2 norms
    """
    def __init__(self, reg_lambda, groups, stabilizing_val=0):
        """
        Args:
            reg_lambda: The weight of the regularization term.
            groups: A list of integers describing the size of each group.
                The length of the list is the number of groups.

        Optional Args:
            stabilizing_val: The computation of GroupL1Norm involves the Sqrt
                operator. When values are small, its gradient can be numerically
                unstable, causing gradient explosion. Adding this term stabilizes
                the gradient calculation. The recommended value is 1e-8, but it
                depends on the specific scenario. If the implementation of the
                gradient operator of Sqrt has taken stability into consideration,
                this term won't be necessary.
        """
        super().__init__()
        assert reg_lambda >= 0, "regularization weight should be 0 or positive"
        assert isinstance(groups, list), "groups needs to be a list"

        self.reg_lambda = reg_lambda
        self.groups = groups
        self.stabilizing_val = stabilizing_val

    def _run_on_loss(self, net, param_init_net, param, grad=None):
        """
        Args:
            param: The input blob to regularize. It should be a weight matrix
                blob with shape (output_dim, input_dim). input_dim should be
                equal to the sum of self.groups.

        Returns:
            group_l1_norm: The output blob after applying regularization.

        These are the steps of computation:
            1. square all elements
            2. sum by row
            3. sum by group (LengthsSum)
            4. square_root all elements
            5. normalize each group based on group size
            6. compute l1 norm of each group
            7. scale the result with the regularization lambda
        """
        squared = net.Sqr(param)
        reduced_sum = net.ReduceSum(squared, axes=[0], keepdims=0)
        lengths_sum = net.LengthsSum(
            [
                reduced_sum,
                net.GivenTensorIntFill(
                    [], 1, shape=[len(self.groups)], values=self.groups
                ),
            ]
        )

        if self.stabilizing_val:
            net.Add(
                [lengths_sum, net.ConstantFill([], 1, value=self.stabilizing_val)],
                [lengths_sum],
                broadcast=1,
            )

        sqrt = net.Sqrt(lengths_sum)

        # Here we combine step 5 and step 7 into one operator call to
        # improve efficiency: values = np.sqrt(self.groups) * self.reg_lambda
        l2_scaled = net.Mul(
            [
                sqrt,
                net.GivenTensorFill(
                    [],
                    shape=[len(self.groups)],
                    values=np.sqrt(self.groups) * self.reg_lambda
                )
            ],
            ['normalized_l2_norm_scaled']
        )

        group_l1_norm = net.LpNorm(l2_scaled, ['group_l1_norm'], p=1)

        return group_l1_norm