# @package optimizer
# Module caffe2.python.optimizer


import copy
import logging
from collections import defaultdict, namedtuple
from typing import Any, Dict

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, scope, utils, workspace
from caffe2.python.modeling import parameter_info
from past.builtins import basestring


_LEARNING_RATE_INJECTION = "lr_injection"

AuxOptimizerParams = namedtuple("AuxOptimizerParams", ["local", "shared"])
_optimizer_instance_count = defaultdict(int)

FP16_ENGINES = ["SIMD_Q_FP16", "SIMD_Q_STOC_FP16", "SIMD_Q_STOC_MKL_FP16"]

logger = logging.getLogger(__name__)

def reset_optimizer_instance_count():
    """
    Clears _optimizer_instance_count and keeps it empty. This is needed in
    situations where the optimizer instance count might not reset even though
    the workspace is reset.
    """
    _optimizer_instance_count.clear()


class Optimizer:
    def __init__(self):
        self._aux_params = AuxOptimizerParams(local=[], shared=[])
        self._instance_num = _optimizer_instance_count[self.__class__.__name__]
        _optimizer_instance_count[self.__class__.__name__] += 1
        self._lr_multiplier = None
        self._local_lr_multiplier = None
        self._local_lr_multiplier_on_gpu = False
        self._use_dedicated_lr_iteration_counter = False

    def __call__(self, net, param_init_net, param, grad=None):
        """
        Adds optimization operators to the net for the given parameter and its
        gradient. The parameter is specified either by 'param' being a
        ParameterInfo object (in which case param.grad must be set), or by
        'param' being a BlobReference and 'grad' being a BlobReference for its
        gradient.
        """
        if grad is None:
            assert isinstance(
                param, parameter_info.ParameterInfo
            ), "Expected parameter to be of type ParameterInfo, got {}".format(param)
            assert param.grad is not None
        else:
            if isinstance(param, basestring):
                param = core.BlobReference(param)
            param = parameter_info.ParameterInfo(param_id=None, param=param, grad=grad)

        self._run(net, param_init_net, param)
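
    # Illustrative call patterns (a sketch, not additional API): either pass a
    # ParameterInfo whose .grad is already set,
    #
    #     opt(net, param_init_net, param_info)
    #
    # or pass the parameter blob and its gradient blob explicitly,
    #
    #     opt(net, param_init_net, "fc_w", grad=core.BlobReference("fc_w_grad"))
    #
    # where `opt` is any concrete Optimizer subclass (e.g. SgdOptimizer) and
    # the blob names are placeholders.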

    def _run(self, net, param_init_net, param_info):
        raise Exception("Not Implemented")

    def get_cpu_blob_name(self, base_str, node_name=""):
        classname = self.__class__.__name__
        return "%s_%d_%s%s_cpu" % (classname, self._instance_num, base_str, node_name)

    def get_gpu_blob_name(self, base_str, gpu_id, node_name):
        classname = self.__class__.__name__
        return "%s_%d_%s%s_gpu%d" % (
            classname,
            self._instance_num,
            base_str,
            node_name,
            gpu_id,
        )

    @property
    def attributes(self):
        # return a dict that contains attributes related to init args only
        attr = copy.deepcopy(self.__dict__)
        del attr["_instance_num"]
        return attr

    @property
    def use_dedicated_lr_iteration_counter(self):
        return self._use_dedicated_lr_iteration_counter

    @use_dedicated_lr_iteration_counter.setter
    def use_dedicated_lr_iteration_counter(self, val):
        self._use_dedicated_lr_iteration_counter = val

    def make_unique_blob_name(self, base_str):
        """
        Returns a blob name that will be unique to the current device
        and optimizer instance.
        """
        current_scope = scope.CurrentDeviceScope()
        if current_scope is None:
            return self.get_cpu_blob_name(base_str)

        if core.IsGPUDeviceType(current_scope.device_type):
            return self.get_gpu_blob_name(
                base_str, current_scope.device_id, current_scope.node_name
            )
        else:
            return self.get_cpu_blob_name(base_str, current_scope.node_name)

    def build_lr(
        self,
        net,
        param_init_net,
        base_learning_rate,
        learning_rate_blob=None,
        policy="fixed",
        iter_val=0,
        **kwargs
    ):
        if learning_rate_blob is None:
            learning_rate_blob = self.make_unique_blob_name("lr")

        if self._use_dedicated_lr_iteration_counter:
            iteration = utils.BuildUniqueMutexIter(
                param_init_net,
                net,
                iter=utils.OPTIMIZER_ITERATION_LR_NAME,
                iter_mutex=utils.ITERATION_MUTEX_LR_NAME,
                iter_val=iter_val,
            )
            logger.info(f"Created dedicated learning rate iteration counter: {iteration}")
        else:
            iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)

        if not net.BlobIsDefined(learning_rate_blob):
            # There is one interesting thing here: since we are minimizing, we are
            # doing "descent" so the learning rate is set to be negative.
            lr = net.LearningRate(
                [iteration],
                learning_rate_blob,
                base_lr=-base_learning_rate,
                policy=policy,
                **kwargs
            )
        else:
            lr = net.GetBlobRef(learning_rate_blob)

        if self._lr_multiplier is not None:
            lr_multiplier = net.CopyFromCPUInput(
                self._lr_multiplier, self.make_unique_blob_name("lr_multiplier")
            )

            lr = net.Mul(
                [lr, lr_multiplier],
                self.make_unique_blob_name("scaled_lr"),
                broadcast=1,
            )

        if self._local_lr_multiplier is not None:
            current_scope = scope.CurrentDeviceScope()
            if (
                current_scope is not None
                and core.IsGPUDeviceType(current_scope.device_type)
                and not self._local_lr_multiplier_on_gpu
            ):
                local_lr_multiplier = net.CopyFromCPUInput(
                    self._local_lr_multiplier,
                    self.make_unique_blob_name("local_lr_multiplier"),
                )
            else:
                local_lr_multiplier = self._local_lr_multiplier

            lr = net.Mul(
                [lr, local_lr_multiplier],
                self.make_unique_blob_name("local_scaled_lr"),
                broadcast=1,
            )

        return lr, iteration
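
    # A typical _run() implementation (see SgdOptimizer below) obtains its
    # learning rate blob roughly like this (sketch):
    #
    #     lr, _ = self.build_lr(
    #         net, param_init_net,
    #         base_learning_rate=self.base_learning_rate,
    #         policy=self.policy,        # e.g. "fixed" or "step"
    #         **(self.init_kwargs)       # extra LearningRate args such as stepsize/gamma
    #     )
    #
    # and then feeds `lr` to the parameter-update operator; note that the blob
    # holds the negated base rate, as explained in the comment above.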

    def build_non_lr_iter(
        self,
        net,
        param_init_net,
        iter_val=0,
    ):
        assert (
            self._use_dedicated_lr_iteration_counter
        ), "This method should only be called when the dedicated learning rate iteration counter is used."

        iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)
        logger.info(f"Created iteration counter for non-learning-rate purposes: {iteration}")

        # We need to create a dummy learning rate operator to enforce that the
        # iteration counter blob is placed on the trainer nodes. Otherwise, the
        # Automatic Device Placement (ADP) algorithm for Hierarchical Training
        # (HT) will have trouble distributing blobs across group parameter
        # servers. Note that this learning rate operator is not used for any
        # other purpose.
        learning_rate_blob = self.make_unique_blob_name("iter_placement_hint")
        if not net.BlobIsDefined(learning_rate_blob):
            net.LearningRate(
                [iteration],
                learning_rate_blob,
                base_lr=1.0,
                policy="fixed",
            )

        return iteration

    def add_lr_multiplier(self, lr_multiplier):
        """
        Set the global learning rate multiplier. If a multiplier already
        existed, this will overwrite the existing multiplier. The multiplier is
        used for all future calls to _run(), unless it is overwritten.
        """
        self._lr_multiplier = lr_multiplier

    def _add_local_lr_multiplier(self, local_lr_multiplier, is_gpu_blob=False):
        """
        Set the local learning rate multiplier. This local multiplier is
        multiplied with the global learning rate multiplier if it exists. As
        with the global learning rate multiplier, this multiplier will be
        used for all future calls to _run(), so please call
        _clear_local_lr_multiplier() at the beginning of the optimizer's _run()
        before optionally calling this function.
        """
        self._local_lr_multiplier = local_lr_multiplier
        self._local_lr_multiplier_on_gpu = is_gpu_blob

    def _clear_local_lr_multiplier(self):
        self._local_lr_multiplier = None
        self._local_lr_multiplier_on_gpu = False
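
    # How these multipliers are consumed (see build_lr above): the LearningRate
    # output, built from -base_learning_rate, is multiplied first by the global
    # multiplier and then by the local one, roughly
    #
    #     lr = lr * self._lr_multiplier          # if add_lr_multiplier() was called
    #     lr = lr * self._local_lr_multiplier    # if a local (e.g. LARS) multiplier is set
    #
    # For example, a hypothetical warm-up schedule blob passed to
    # add_lr_multiplier() would scale the learning rate of every parameter this
    # optimizer instance touches.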

    @staticmethod
    def dedup(net, sparse_dedup_aggregator, grad):
        assert isinstance(
            grad, core.GradientSlice
        ), "Dedup only works for sparse gradient, got {}".format(grad)
        if sparse_dedup_aggregator:
            return net.DeduplicateGradientSlices(
                grad, aggregator=sparse_dedup_aggregator
            )
        else:
            return grad

    def get_auxiliary_parameters(self):
        """Returns the auxiliary parameters.

        Returns:
            aux_params: An AuxOptimizerParams namedtuple.

            aux_params.local stores a list of blobs. Each blob is a local
            auxiliary parameter. A local auxiliary parameter is a parameter in
            parallel to a learning rate parameter. Take Adagrad as an example:
            the local auxiliary parameter is the squared sum parameter, because
            every learning rate has a squared sum associated with it.

            aux_params.shared also stores a list of blobs. Each blob is a shared
            auxiliary parameter. A shared auxiliary parameter is a parameter
            that is shared across all the learning rate parameters. Take Adam as
            an example: the iteration parameter is a shared parameter, because
            all the learning rates share the same iteration parameter.
        """
        return self._aux_params
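
    # For example (schematic): after an AdagradOptimizer has processed
    # parameters w1 and w2,
    #
    #     aux = opt.get_auxiliary_parameters()
    #     aux.local   # -> [w1_squared_sum, w2_squared_sum]  (one accumulator per param)
    #     aux.shared  # -> e.g. a shared iteration counter blob, when one is registered
    #
    # Actual blob names are derived from the parameter names and the optimizer
    # instance, so the names above are placeholders.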

    # TODO(xlwang): In transfer learning, parameters initialized from a
    # pretrained model might require a different learning rate than parameters
    # initialized otherwise. To this end, we implement a Python solution here
    # where `base_learning_rate` is scaled by `scale` via `scale_learning_rate`;
    # alternatively, the same effect could be achieved by rewriting the
    # LearningRate operator in C++.
    # Note that it is the responsibility of each specific optimizer to decide
    # what logic should be used for `scale_learning_rate`.
    def scale_learning_rate(self, *args, **kwargs):
        raise NotImplementedError(
            "Optimizers need to implement the `scale_learning_rate` method."
        )

    def create_lars_inputs(self, param_init_net, weight_decay, trust, lr_max):
        wd = param_init_net.ConstantFill(
            [], "weight_decay", shape=[1], value=weight_decay
        )
        trust = param_init_net.ConstantFill([], "trust", shape=[1], value=trust)
        lr_max = param_init_net.ConstantFill([], "lr_max", shape=[1], value=lr_max)
        return wd, trust, lr_max


class SgdOptimizer(Optimizer):
    def __init__(
        self,
        base_learning_rate=0.01,
        policy="fixed",
        momentum=0.0,
        nesterov=True,
        sparse_dedup_aggregator=None,
        lars=None,
        **kwargs
    ):
        super().__init__()
        self.base_learning_rate = base_learning_rate
        self.policy = policy
        self.momentum = momentum
        self.nesterov = nesterov
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.lars = lars
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad
        if self.base_learning_rate == 0:
            return
        assert (
            self.base_learning_rate > 0
        ), "Expect positive base learning rate, got {}".format(self.base_learning_rate)

        self._clear_local_lr_multiplier()

        # TODO(zqq): support LARS for sparse parameters
        if self.lars is not None and not isinstance(grad, core.GradientSlice):
            assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
                self.lars
            )
            wd, trust, lr_max = self.create_lars_inputs(
                param_init_net, 0.0, 1.0, np.finfo(np.float32).max
            )
            lr_lars_multiplier = net.Lars(
                [param, grad, wd, trust, lr_max],
                self.make_unique_blob_name(str(param) + "_lars"),
                offset=self.lars,
                lr_min=0.0,
            )
            current_scope = scope.CurrentDeviceScope()
            self._add_local_lr_multiplier(
                lr_lars_multiplier,
                is_gpu_blob=(
                    current_scope is not None
                    and core.IsGPUDeviceType(current_scope.device_type)
                ),
            )

        # We need negative sign for LR when used directly with WeightedSum
        # below.
        lr_sign = -1 if self.momentum else 1
        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=self.base_learning_rate * lr_sign,
            policy=self.policy,
            **(self.init_kwargs)
        )

        dev = scope.CurrentDeviceScope()
        if dev is None:
            dev = core.DeviceOption(caffe2_pb2.CPU)

        # Each GPU/CPU must have its own ONE blob, thus modify the name
        # to include device information.
        ONE = param_init_net.ConstantFill(
            [],
            "ONE_{}_{}{}".format(dev.device_type, dev.device_id, dev.node_name),
            shape=[1],
            value=1.0,
        )

        self._aux_params.shared.append(ONE)

        if self.momentum > 0:
            momentum_data = param_init_net.ConstantFill(
                param, str(param) + "_momentum", value=0.0
            )
            self._aux_params.local.append(momentum_data)

        if isinstance(grad, core.GradientSlice):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            if self.momentum > 0.0:
                net.SparseMomentumSGDUpdate(
                    [grad.values, momentum_data, lr, param, grad.indices],
                    [grad.values, momentum_data, param],
                    momentum=self.momentum,
                    nesterov=self.nesterov,
                )
            else:
                net.ScatterWeightedSum(
                    [param, ONE, grad.indices, grad.values, lr], param
                )
        else:
            if self.momentum > 0.0:
                net.MomentumSGDUpdate(
                    [grad, momentum_data, lr, param],
                    [grad, momentum_data, param],
                    momentum=self.momentum,
                    nesterov=self.nesterov,
                )
            else:
                coeff = lr

                net.WeightedSum([param, ONE, grad, coeff], param)

    def scale_learning_rate(self, scale):
        self.base_learning_rate *= scale
        return
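

# Typical usage sketch (the build_* helpers are defined later in this module;
# values below are placeholders):
#
#     from caffe2.python import model_helper, optimizer
#
#     model = model_helper.ModelHelper(name="example")
#     # ... build the forward net and add gradient operators ...
#     optimizer.build_sgd(
#         model,
#         base_learning_rate=0.01,
#         policy="step",
#         stepsize=1,
#         gamma=0.999,
#         momentum=0.9,
#     )
#
# build_sgd constructs an SgdOptimizer with these arguments and applies it to
# every parameter of the model.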


class MultiPrecisionSgdOptimizer(SgdOptimizer):
    def __init__(
        self,
        base_learning_rate=0.1,
        momentum=0.0,
        policy="fixed",
        nesterov=True,
        sparse_dedup_aggregator=None,
        **kwargs
    ):
        super().__init__(
            base_learning_rate=base_learning_rate,
            policy=policy,
            momentum=momentum,
            nesterov=nesterov,
            sparse_dedup_aggregator=sparse_dedup_aggregator,
            **kwargs
        )

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        param_fp32 = (
            param_info.blob_copy[core.DataType.FLOAT]
            if param_info.blob_copy is not None
            else None
        )

        # If we have a straight fp32 parameter, run the base class
        if param_fp32 is None:
            return SgdOptimizer._run(self, net, param_init_net, param_info)

        grad = param_info.grad
        if self.base_learning_rate == 0:
            return
        assert (
            self.base_learning_rate > 0
        ), "Expect positive base learning rate, got {}".format(self.base_learning_rate)

        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=-self.base_learning_rate,
            policy=self.policy,
            **(self.init_kwargs)
        )

        momentum_data = param_init_net.ConstantFill(
            param_fp32, str(param) + "_momentum", value=0.0
        )
        self._aux_params.local.append(momentum_data)

        assert not isinstance(
            grad, core.GradientSlice
        ), "MultiPrecisionSgd does not support sparse gradients"

        # Copy gradient to fp32
        grad_fp32 = net.HalfToFloat(grad, grad + "_fp32")

        # update (fused) in fp32
        net.MomentumSGDUpdate(
            [grad_fp32, momentum_data, lr, param_fp32],
            [grad_fp32, momentum_data, param_fp32],
            momentum=self.momentum,
            nesterov=self.nesterov,
        )

        # Copy updated param back to fp16
        net.FloatToHalf(param_fp32, param)


class FP16SgdOptimizer(SgdOptimizer):
    def __init__(
        self,
        base_learning_rate=0.1,
        momentum=0.0,
        policy="fixed",
        nesterov=True,
        weight_decay=0.0001,
        sparse_dedup_aggregator=None,
        **kwargs
    ):
        super().__init__(
            base_learning_rate=base_learning_rate,
            policy=policy,
            momentum=momentum,
            nesterov=nesterov,
            sparse_dedup_aggregator=sparse_dedup_aggregator,
            **kwargs
        )
        self.weight_decay = weight_decay

    def _run(self, net, param_init_net, param_info, fp32_update=False):
        fp32_update_flag = 0
        param_name = str(param_info.blob)

        # This should only be triggered in FP16 training by SpatialBN, which
        # requires FP32 params in CuDNN.
        if param_name.find("spatbn") != -1:
            fp32_update = True

        if fp32_update:
            # doing a 32bit update
            # Have to assume param_info.blob is FP32 as there is no way
            # (that I currently know of) to query a blob's type in Python
            fp32_update_flag = 1
            param = param_info.blob
            param_fp32 = param_info.blob
        else:
            if param_info.blob_copy is None:
                # doing a 32bit update
                # Have to assume param_info.blob is FP32 as there is no way
                # (that I currently know of) to query a blob's type in Python
                fp32_update_flag = 1
                param = param_info.blob
                param_fp32 = param_info.blob
            else:
                if core.DataType.FLOAT in param_info.blob_copy:
                    param = param_info.blob
                    param_fp32 = param_info.blob_copy[core.DataType.FLOAT]
                elif core.DataType.FLOAT16 in param_info.blob_copy:
                    param = param_info.blob_copy[core.DataType.FLOAT16]
                    param_fp32 = param_info.blob
                else:
                    raise AssertionError(
                        "Unrecognized parameter format to be updated "
                        "by FP16 Optimizer. Parameter: {}".format(param_info.name)
                    )

        grad = param_info.grad

        if self.base_learning_rate == 0:
            return
        assert (
            self.base_learning_rate > 0
        ), "Expect positive base learning rate, got {}".format(self.base_learning_rate)

        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=-self.base_learning_rate,
            policy=self.policy,
            **(self.init_kwargs)
        )

        momentum_data_fp32 = param_init_net.ConstantFill(
            param_fp32, str(param) + "_momentum_fp32", value=0.0
        )

        momentum_data = param_init_net.FloatToHalf(
            momentum_data_fp32, str(param) + "_momentum"
        )

        self._aux_params.local.append(momentum_data)

        assert not isinstance(
            grad, core.GradientSlice
        ), "FP16Sgd does not support sparse gradients"

        if fp32_update_flag == 0:
            net.FP16MomentumSGDUpdate(
                [grad, momentum_data, lr, param],
                [grad, momentum_data, param],
                momentum=self.momentum,
                nesterov=self.nesterov,
                weight_decay=self.weight_decay,
            )
        else:
            # flag set to 1, therefore doing FP32 update
            net.FP32MomentumSGDUpdate(
                [grad, momentum_data_fp32, lr, param],
                [grad, momentum_data_fp32, param],
                momentum=self.momentum,
                nesterov=self.nesterov,
                weight_decay=self.weight_decay,
            )


class WeightDecayBuilder(Optimizer):
    def __init__(self, weight_decay):
        self.weight_decay = weight_decay

    def _run(self, net, param_init_net, param_info):
        dev = scope.CurrentDeviceScope()
        if dev is None:
            dev = core.DeviceOption(caffe2_pb2.CPU)

        ONE = param_init_net.ConstantFill(
            [], "ONE_{}_{}".format(dev.device_type, dev.device_id), shape=[1], value=1.0
        )
        WD = param_init_net.ConstantFill(
            [],
            "wd_{}_{}".format(dev.device_type, dev.device_id),
            shape=[1],
            value=self.weight_decay,
        )

        if isinstance(param_info.grad, core.GradientSlice):
            raise ValueError("Weight decay does not yet support sparse gradients")
        else:
            net.WeightedSum(
                [param_info.grad, ONE, param_info.blob, WD], param_info.grad
            )
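

# In effect (illustrative), WeightDecayBuilder rewrites the gradient in place as
#
#     grad <- 1.0 * grad + weight_decay * param
#
# via the WeightedSum op above, i.e. L2 regularization applied before the
# actual parameter-update operator runs.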


class AdagradOptimizer(Optimizer):
    def __init__(
        self,
        alpha=0.01,
        epsilon=1e-4,
        decay=1,
        weight_decay=0.0,
        policy="fixed",
        sparse_dedup_aggregator=None,
        rowWise=False,
        engine="",
        lars=None,
        output_effective_lr=False,
        output_effective_lr_and_update=False,
        pruning_options=None,
        swa_options=None,
        ema_options=None,
        weight_scale=None,
        counter_halflife=-1,
        use_dedicated_lr_iteration_counter=False,
        **kwargs
    ):
        super().__init__()
        self.alpha = alpha
        self.epsilon = epsilon
        self.decay = decay
        self.weight_decay = float(weight_decay)
        self.policy = policy
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.rowWise = rowWise
        self.engine = engine
        self.lars = lars
        self.output_effective_lr = output_effective_lr
        self.output_effective_lr_and_update = output_effective_lr_and_update
        self.counter_halflife = counter_halflife
        self.init_kwargs = kwargs
        self.weight_scale = weight_scale
        self.use_dedicated_lr_iteration_counter = use_dedicated_lr_iteration_counter

        self._process_pruning_options(pruning_options)
        self._process_swa_options(swa_options)
        self._process_ema_options(ema_options)

    def set_mapping_for_param2ema_teacher_param(self, param_mapping: Dict[str, Any]) -> None:
        self.param2ema_teacher_param = param_mapping

    def _process_swa_options(self, swa_options):
        self.swa_enabled = True if swa_options else False
        if self.swa_enabled:
            self.swa_avg_start_it = swa_options.get("swa_avg_start_it", None)
            self.swa_avg_end_it = swa_options.get("swa_avg_end_it", None)
            self.swa_feedback_start_it = swa_options.get("swa_feedback_start_it", None)
            self.swa_feedback_step = swa_options.get("swa_feedback_step", None)
            self.swa_feedback_end_it = swa_options.get("swa_feedback_end_it", None)

    def _process_ema_options(self, ema_options):
        logger.info(f"ema_options: {str(ema_options)}")
        self.ema_enabled = ema_options and ema_options.get("ema_alpha", None) is not None
        self.ema_teacher_enabled = ema_options and ema_options.get("ema_teacher_alpha", None) is not None
        self.param2ema_teacher_param = {}
        if self.ema_enabled or self.ema_teacher_enabled:
            self.ema_start = ema_options.get("ema_start", None)
            self.ema_end = ema_options.get("ema_end", None)
            self.ema_step = ema_options.get("ema_step", None)
            self.ema_alpha = ema_options.get("ema_alpha", None)
            self.ema_teacher_alpha = ema_options.get("ema_teacher_alpha", None)
            self.ema_teacher_module_name = ema_options.get(
                "ema_teacher_module_name", "ema_teacher_arch"
            )

    def _process_pruning_options(self, pruning_options):
        self.use_mask = False

        if pruning_options is None:
            pruning_options = {}
        else:
            assert isinstance(pruning_options, dict), (
                "pruning_options can only "
                "be provided as a dictionary, currently: {}".format(pruning_options)
            )

        self.mask_tensor = pruning_options.get("mask_tensor", None)
        self.mask_db_path = pruning_options.get("mask_db_path", None)
        self.mask_db_type = pruning_options.get("mask_db_type", None)
        self.mask_blob_name = pruning_options.get("mask_blob_name", None)
        self.prune_delays = pruning_options.get("prune_delays", [])
        self.prune_ratios = pruning_options.get("prune_ratios", [])
        self.prune_block_size = pruning_options.get("prune_block_size", 1)

        if self.mask_tensor is not None:
            assert (
                type(self.mask_tensor) is np.ndarray
            ), "mask_tensor must be a numpy array!"
            assert self.mask_db_path is None, (
                "mask can be provided through either a numpy array "
                "or a db path, not both"
            )
            assert self.mask_db_type is None, (
                "mask can be provided through either a numpy array "
                "or a db path, not both"
            )
            assert self.mask_blob_name is None, (
                "mask can be provided through either a numpy array "
                "or a db path, not both"
            )
            self.use_mask = True

        if self.mask_db_path is not None or self.mask_db_type is not None:
            assert self.mask_db_path is not None, (
                "when mask is provided through db, "
                "db path, db type, and blob name are all needed"
            )
            assert self.mask_db_type is not None, (
                "when mask is provided through db, "
                "db path, db type, and blob name are all needed"
            )
            assert self.mask_tensor is None, (
                "mask can be provided through either a numpy array "
                "or a db path, not both"
            )
            self.use_mask = True

        if self.prune_delays:
            assert self.prune_ratios is not None and len(self.prune_delays) == len(
                self.prune_ratios
            ), "Prune Delays and prune ratios should be of the same length"
            assert (
                self.mask_tensor is None
            ), "Mask Tensor should be None with prune ratios"
            assert (
                self.mask_db_path is None
            ), "Mask DB Path should be None with prune ratios"
            self.use_mask = True

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.alpha <= 0:
            return

        self._clear_local_lr_multiplier()

        if self.lars is not None and not isinstance(grad, core.GradientSlice):
            assert (
                self.weight_decay == 0
            ), "weight decay is not implemented for LARS yet"
            assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
                self.lars
            )
            wd, trust, lr_max = self.create_lars_inputs(
                param_init_net, 0.0, 1.0, np.finfo(np.float32).max
            )
            lr_lars_multiplier = net.Lars(
                [param, grad, wd, trust, lr_max],
                self.make_unique_blob_name(str(param) + "_lars"),
                offset=self.lars,
                lr_min=0.0,
            )

            current_scope = scope.CurrentDeviceScope()
            self._add_local_lr_multiplier(
                lr_lars_multiplier,
                is_gpu_blob=(
                    current_scope is not None
                    and core.IsGPUDeviceType(current_scope.device_type)
                ),
            )

        lr, lr_iteration = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=self.alpha,
            policy=self.policy,
            **(self.init_kwargs)
        )
        iteration = (
            self.build_non_lr_iter(net, param_init_net, iter_val=0)
            if self._use_dedicated_lr_iteration_counter
            else lr_iteration
        )

        if self.counter_halflife > 0:
            self._aux_params.shared.append(iteration)

        if self.rowWise:
            logger.debug(
                "Using engine {} for rowWise Adagrad to train param {}".format(
                    self.engine, param
                )
            )

            shapes, types = workspace.InferShapesAndTypes([param_init_net])
            if str(param) not in shapes:
                # Type/shape inference is not available for this param; fall
                # back on Shape/Slice logic
                shape = param_init_net.Shape(param, str(param) + "_shape")
                num_rows = param_init_net.Slice(
                    [shape], str(shape) + "_numrows", starts=[0], ends=[1]
                )
                param_squared_sum = param_init_net.ConstantFill(
                    num_rows,
                    str(param) + "_avg_squared_sum",
                    input_as_shape=1,
                    value=0.0,
                )
            else:
                param_squared_sum = param_init_net.ConstantFill(
                    [],
                    str(param) + "_avg_squared_sum",
                    shape=[shapes[str(param)][0]],
                    value=0.0,
                )
        else:
            logger.debug(
                "Using engine {} for regular Adagrad to train param {}".format(
                    self.engine, param
                )
            )

            if self.engine in FP16_ENGINES:
                assert (
                    self.weight_decay == 0
                ), "weight decay is not tested for engine: {}".format(self.engine)

                shapes, types = workspace.InferShapesAndTypes([param_init_net])
                assert str(param) in shapes, shapes
                shape = shapes[str(param)]

                param_squared_sum = param_init_net.Float16ConstantFill(
                    [], str(param) + "_squared_sum", value=0.0, shape=shape
                )
            else:
                param_squared_sum = param_init_net.ConstantFill(
                    [param], str(param) + "_squared_sum", value=0.0
                )

        if self.use_mask is True:
            assert (
                self.weight_decay == 0
            ), "weight decay is not implemented for use_mask yet"

            if self.mask_tensor is not None:
                if not isinstance(grad, core.GradientSlice):
                    mask_blob = param_init_net.GivenTensorFill(
                        [],
                        [str(param) + "_mask"],
                        values=self.mask_tensor,
                        shape=self.mask_tensor.shape,
                    )
                else:
                    self.mask_tensor = self.mask_tensor.astype(np.uint8)
                    mask_blob = param_init_net.GivenTensorBoolFill(
                        [],
                        [str(param) + "_mask"],
                        values=self.mask_tensor,
                        shape=self.mask_tensor.shape,
                    )
                    mask_blob = param_init_net.Cast(mask_blob, to=core.DataType.UINT8)
                    mask_changed_blob = param_init_net.ConstantFill(
                        [],
                        [str(param) + "_mask_changed_blob"],
                        value=False,
                        dtype=core.DataType.BOOL,
                        shape=[1],
                    )
            elif (
                self.mask_db_path is not None or self.mask_db_type is not None
            ):  # mask is provided through a db file
                # if mask_blob_name is not given, use the param name to derive the mask name
                self.mask_blob_name = self.mask_blob_name or str(param) + "_mask"

                mask_blob = param_init_net.Load(
                    [],
                    self.mask_blob_name,
                    db=self.mask_db_path,
                    db_type=self.mask_db_type,
                    absolute_path=True,
                )

                if isinstance(grad, core.GradientSlice):
                    mask_changed_blob = param_init_net.ConstantFill(
                        [],
                        [str(param) + "_mask_changed_blob"],
                        value=False,
                        dtype=core.DataType.BOOL,
                        shape=[1],
                    )
            elif self.prune_delays:
                last_mask_updated_iter = param_init_net.ConstantFill(
                    [],
                    [str(param) + "_last_mask_updated_iter"],
                    value=-1,
                    dtype=core.DataType.INT64,
                    shape=[1],
                )

                if isinstance(grad, core.GradientSlice):
                    raise AssertionError(
                        "Prune Delays and Prune Ratios are currently not supported "
                        "for sparse operators"
                    )
                else:
                    mask_blob = param_init_net.GivenTensorFill(
                        [],
                        [str(param) + "_empty_mask"],
                        values=[],
                        dtype=core.DataType.FLOAT,
                        shape=[0],
                    )
            else:
                raise NotImplementedError(
                    "If a mask is used, it needs to be provided through a numpy "
                    "array, a db file, or a prune delay iter"
                )

        self._aux_params.local.append(param_squared_sum)
        if self.counter_halflife > 0:
            shapes, types = workspace.InferShapesAndTypes([param_init_net])
            if str(param) not in shapes:
                shape = param_init_net.Shape(param, str(param) + "_shape")
                num_rows = param_init_net.Slice(
                    [shape], str(shape) + "_numrows", starts=[0], ends=[1]
                )
                update_counter = param_init_net.ConstantFill(
                    num_rows,
                    str(param) + "_update_counter",
                    input_as_shape=1,
                    value=0.0,
                    dtype=core.DataType.DOUBLE,
                )
                prev_update_iter = param_init_net.ConstantFill(
                    num_rows,
                    str(param) + "_prev_update_iter",
                    input_as_shape=1,
                    value=0,
                    dtype=core.DataType.INT64,
                )
            else:
                update_counter = param_init_net.ConstantFill(
                    [],
                    str(param) + "_update_counter",
                    shape=[shapes[str(param)][0]],
                    value=0.0,
                    dtype=core.DataType.DOUBLE,
                )
                prev_update_iter = param_init_net.ConstantFill(
                    [],
                    str(param) + "_prev_update_iter",
                    shape=[shapes[str(param)][0]],
                    value=0,
                    dtype=core.DataType.INT64,
                )
            self._aux_params.local.append(update_counter)
            self._aux_params.local.append(prev_update_iter)

        if self.rowWise:
            assert isinstance(grad, core.GradientSlice), (
                "If SparseAdagrad with rowWise=True, gradient must be "
                "a GradientSlice. Please ensure that rowWise is not enabled "
                "for the dense Adagrad optimizer, as it is not supported."
            )

        shapes, _ = workspace.InferShapesAndTypes([param_init_net])
        param_shape = shapes[str(param)]
        weight_decay = 0.0
        if isinstance(grad, core.GradientSlice):
            if len(param_shape) == 1:
                weight_decay = 0.0
                logger.warning(
                    "SKIPPING weight decay on 1d sparse param: {}.shape is {}".format(
                        str(param), param_shape
                    )
                )
            else:
                weight_decay = self.weight_decay
        else:
            # Skip weight decay for 1d parameters
            if len(param_shape) == 1:
                weight_decay = 0.0
                logger.warning(
                    "SKIPPING weight decay on 1d dense param: {}.shape is {}".format(
                        str(param), param_shape
                    )
                )
            else:
                weight_decay = self.weight_decay
        logger.debug(
            "weight_decay for {} (shape:{}): {}".format(
                str(param), param_shape, weight_decay
            )
        )

        if isinstance(grad, core.GradientSlice):
            assert (
                self.decay == 1.0
            ), "Decay is not implemented for SparseAdagrad and must be set to 1"
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)

            input_args = [param, param_squared_sum, grad.indices, grad.values, lr]
            output_args = [param, param_squared_sum]
            if self.rowWise:
                if self.use_mask is True:
                    op = "MaskedRowWiseSparseAdagrad"
                    assert (
                        weight_decay == 0
                    ), "weight decay is not implemented for {} yet".format(op)
                    input_args += [mask_blob, mask_changed_blob]
                else:
                    if self.counter_halflife > 0:
                        input_args += [update_counter]
                    op = "RowWiseSparseAdagrad"
            else:
                if self.use_mask is True:
                    op = "MaskedSparseAdagrad"
                    assert (
                        weight_decay == 0
                    ), "weight decay is not implemented for {} yet".format(op)
                    input_args += [mask_blob, mask_changed_blob]
                else:
                    op = "SparseAdagrad"
            logger.debug("using {} for {}".format(op, str(param)))

            if self.prune_delays:
                input_args += [iteration, last_mask_updated_iter]
                output_args += [mask_blob, last_mask_updated_iter]

            if weight_decay > 0 and self.counter_halflife == -1:
                net.__getattr__(op)(
                    input_args,
                    output_args,
                    epsilon=self.epsilon,
                    weight_decay=weight_decay,
                    engine=self.engine,
                )
            elif weight_decay > 0 and self.counter_halflife != -1:
                net.__getattr__(op)(
                    input_args,
                    output_args,
                    epsilon=self.epsilon,
                    weight_decay=weight_decay,
                    engine=self.engine,
                    counter_halflife=self.counter_halflife,
                )
            else:
                net.__getattr__(op)(
                    input_args, output_args, epsilon=self.epsilon, engine=self.engine
                )
            if self.counter_halflife > 0:
                net.RowWiseCounter(
                    [prev_update_iter, update_counter, grad.indices, iteration],
                    [prev_update_iter, update_counter],
                    counter_halflife=self.counter_halflife,
                )
        else:
            input_args = [param, param_squared_sum, grad, lr]
            output_args = [param, param_squared_sum]

            if self.output_effective_lr_and_update:
                assert (
                    self.use_mask is False
                ), "MaskedAdagrad doesn't support outputting effective_lr_and_update"
                output_args.append(str(param) + "_effective_lr")
                output_args.append(str(param) + "_update")
            elif self.output_effective_lr:
                assert (
                    self.use_mask is False
                ), "MaskedAdagrad doesn't support outputting effective_lr"
                output_args.append(str(param) + "_effective_lr")

            if self.use_mask is True:
                input_args += [mask_blob]

            if self.prune_delays:
                input_args += [iteration, last_mask_updated_iter]
                output_args += [mask_blob, last_mask_updated_iter]

            if self.use_mask:
                assert (
                    weight_decay == 0
                ), "weight decay is not implemented for use_mask yet"
                net.MaskedAdagrad(
                    input_args,
                    output_args,
                    epsilon=self.epsilon,
                    decay=float(self.decay),
                    block_size=self.prune_block_size,
                    delays=self.prune_delays,
                    prune_ratios=self.prune_ratios,
                    engine=self.engine,
                )
            else:
                if weight_decay > 0:
                    net.Adagrad(
                        input_args,
                        output_args,
                        epsilon=self.epsilon,
                        decay=float(self.decay),
                        weight_decay=weight_decay,
                        engine=self.engine,
                    )
                else:
                    net.Adagrad(
                        input_args,
                        output_args,
                        epsilon=self.epsilon,
                        decay=float(self.decay),
                        engine=self.engine,
                    )

                if self.swa_enabled:
                    param_swa = str(param) + "_swa"
                    if not param_init_net.BlobIsDefined(param_swa):
                        param_init_net.ConstantFill([param], param_swa, value=0.0)
                        self._aux_params.local.append(param_swa)

                    net.SWA(
                        [param, param_swa, iteration],
                        [param, param_swa],
                        avg_start=self.swa_avg_start_it,
                        avg_end=self.swa_avg_end_it,
                        feedback_start=self.swa_feedback_start_it,
                        feedback_step=self.swa_feedback_step,
                        feedback_end=self.swa_feedback_end_it,
                    )

        if self.ema_enabled:
            param_ema = str(param) + "_ema"
            if not param_init_net.BlobIsDefined(param_ema):
                param_init_net.ConstantFill([param], param_ema, value=0.0)
                self._aux_params.local.append(param_ema)

            net.EMA(
                [param, param_ema, iteration],
                [param, param_ema],
                ema_start=self.ema_start,
                ema_end=self.ema_end,
                ema_step=self.ema_step,
                ema_alpha=self.ema_alpha,
            )

        if self.ema_teacher_enabled:
            if param in self.param2ema_teacher_param:
                param_ema_teacher = self.param2ema_teacher_param[param]
                if not param_init_net.BlobIsDefined(param_ema_teacher):
                    param_init_net.ConstantFill([param], param_ema_teacher, value=0.0)
                    self._aux_params.local.append(param_ema_teacher)

                net.EMA(
                    [param, param_ema_teacher, iteration],
                    [param, param_ema_teacher],
                    ema_start=self.ema_start,
                    ema_end=self.ema_end,
                    ema_step=self.ema_step,
                    ema_alpha=self.ema_teacher_alpha,
                )

        if self.weight_scale:
            net.WeightScale(
                [param, iteration],
                [param],
                stepsize=self.weight_scale.stepsize,
                upper_bound_iter=self.weight_scale.upper_bound_iter,
                scale=float(self.weight_scale.scale),
            )
            if self.weight_scale.to_aux:
                net.WeightScale(
                    [param_squared_sum, iteration],
                    [param_squared_sum],
                    stepsize=self.weight_scale.stepsize,
                    upper_bound_iter=self.weight_scale.upper_bound_iter,
                    scale=float(self.weight_scale.scale),
                )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return
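

# Illustrative configuration (keys match what the _process_*_options methods
# above read; values are placeholders):
#
#     adagrad = AdagradOptimizer(
#         alpha=0.01,
#         epsilon=1e-4,
#         pruning_options={"mask_db_path": "/path/to/db", "mask_db_type": "minidb"},
#         ema_options={"ema_start": 0, "ema_end": 10000, "ema_step": 1, "ema_alpha": 0.999},
#     )
#
# See the option-processing methods above for the full set of supported keys
# and their validation.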


class WngradOptimizer(Optimizer):
    def __init__(
        self,
        alpha=1.0,
        epsilon=1e-9,
        policy="fixed",
        sparse_dedup_aggregator=None,
        engine="",
        moment_init=100.0,
        lars=None,
        output_effective_lr=False,
        output_effective_lr_and_update=False,
        **kwargs
    ):
        super().__init__()
        self.alpha = alpha
        self.epsilon = epsilon
        self.policy = policy
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.engine = engine
        self.moment_init = moment_init
        self.lars = lars
        self.output_effective_lr = output_effective_lr
        self.output_effective_lr_and_update = output_effective_lr_and_update
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.alpha <= 0:
            return

        self._clear_local_lr_multiplier()

        if self.lars is not None and not isinstance(grad, core.GradientSlice):
            assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
                self.lars
            )
            wd, trust, lr_max = self.create_lars_inputs(
                param_init_net, 0.0, 1.0, np.finfo(np.float32).max
            )
            lr_lars_multiplier = net.Lars(
                [param, grad, wd, trust, lr_max],
                self.make_unique_blob_name(str(param) + "_lars"),
                offset=self.lars,
                lr_min=0.0,
            )
            current_scope = scope.CurrentDeviceScope()
            self._add_local_lr_multiplier(
                lr_lars_multiplier,
                is_gpu_blob=(
                    current_scope is not None
                    and core.IsGPUDeviceType(current_scope.device_type)
                ),
            )

        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=self.alpha,
            policy=self.policy,
            **(self.init_kwargs)
        )

        moment = param_init_net.ConstantFill(
            [], str(param) + "_moment", shape=[1], value=self.moment_init
        )

        self._aux_params.local.append(moment)

        if isinstance(grad, core.GradientSlice):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            net.SparseWngrad(
                [param, moment, grad.indices, grad.values, lr],
                [param, moment],
                epsilon=self.epsilon,
                engine=self.engine,
            )
        else:
            output_args = [param, moment]
            if self.output_effective_lr_and_update:
                output_args.append(str(param) + "_effective_lr")
                output_args.append(str(param) + "_update")
            elif self.output_effective_lr:
                output_args.append(str(param) + "_effective_lr")

            net.Wngrad(
                [param, moment, grad, lr],
                output_args,
                epsilon=self.epsilon,
                engine=self.engine,
            )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return


class StormOptimizer(Optimizer):
    def __init__(
        self,
        lr=0.1,
        momentum=10.0,
        beta=0.1,
        grad_sq_init=0.01,
        policy="fixed",
        sparse_dedup_aggregator=None,
        lars=None,
        **kwargs
    ):
        """Constructor function to add STORM Optimizer

        Args:
            lr: learning rate scaling (called k in the original paper)
            momentum: momentum scaling (called c in the original paper)
            beta: initial value of the denominator in the adaptive learning
              rate (called w in the original paper)
            grad_sq_init: initial value of the gradient squared accumulator.
            policy: specifies how the learning rate should be applied, options
              are 'fixed', 'step', 'exp', etc.
            sparse_dedup_aggregator: specifies the deduplication strategy for
              gradient slices. Works while using sparse gradients. Options
              include 'mean' and 'sum'.
            lars: lars offset.
        """
        super().__init__()
        self.lr = lr
        self.momentum = momentum
        self.beta = beta
        self.grad_sq_init = grad_sq_init
        self.policy = policy
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.lars = lars
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.lr <= 0:
            return

        self._clear_local_lr_multiplier()

        if self.lars is not None and not isinstance(grad, core.GradientSlice):
            assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
                self.lars
            )
            wd, trust, lr_max = self.create_lars_inputs(
                param_init_net, 0.0, 1.0, np.finfo(np.float32).max
            )
            lr_lars_multiplier = net.Lars(
                [param, grad, wd, trust, lr_max],
                self.make_unique_blob_name(str(param) + "_lars"),
                offset=self.lars,
                lr_min=0.0,
            )
            current_scope = scope.CurrentDeviceScope()
            self._add_local_lr_multiplier(
                lr_lars_multiplier,
                is_gpu_blob=(
                    current_scope is not None
                    and core.IsGPUDeviceType(current_scope.device_type)
                ),
            )

        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=self.lr,
            policy=self.policy,
            **(self.init_kwargs)
        )

        moment = param_init_net.ConstantFill(param, str(param) + "_moment", value=0.0)
        self._aux_params.local.append(moment)

        grad_sq_sum = param_init_net.ConstantFill(
            [], str(param) + "_grad_sq_sum", shape=[1], value=self.grad_sq_init
        )
        self._aux_params.local.append(grad_sq_sum)

        if isinstance(grad, core.GradientSlice):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            net.SparseStorm(
                [param, moment, grad_sq_sum, grad.values, grad.indices, lr],
                [param, moment, grad_sq_sum],
                momentum=self.momentum,
                beta=self.beta,
            )
        else:
            net.Storm(
                [param, moment, grad_sq_sum, grad, lr],
                [param, moment, grad_sq_sum],
                momentum=self.momentum,
                beta=self.beta,
            )

    def scale_learning_rate(self, scale):
        self.lr *= scale


1399
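# Illustrative usage sketch for StormOptimizer via the build_storm() helper
# defined later in this module.  Assumes a model_helper.ModelHelper whose
# gradients were added with AddGradientOperators; names are placeholders.
#
#     from caffe2.python import model_helper, optimizer
#
#     model = model_helper.ModelHelper(name="storm_example")
#     # ... forward ops, loss, model.AddGradientOperators([loss]) ...
#     optimizer.build_storm(
#         model,
#         base_learning_rate=0.1,   # lr scaling, k in the STORM paper
#         momentum=10.0,            # momentum scaling, c in the STORM paper
#         beta=0.1,                 # initial denominator value
#         grad_sq_init=0.01,
#     )
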
class AdadeltaOptimizer(Optimizer):
    def __init__(
        self,
        alpha=0.01,
        epsilon=1e-4,
        decay=0.95,
        policy="fixed",
        sparse_dedup_aggregator=None,
        engine="",
        **kwargs
    ):
        """Constructor function to add Adadelta Optimizer

        Args:
            alpha: learning rate
            epsilon: attribute of Adadelta to avoid numerical issues
            decay: attribute of Adadelta to decay the squared gradient sum
            policy: specifies how learning rate should be applied, options are
              "fixed", "step", "exp", etc.
            sparse_dedup_aggregator: specifies deduplication strategy for
              gradient slices. Works while using sparse gradients. Options
              include "mean" and "sum".
            engine: the engine used, options include "", "CUDNN", etc.
        """
        super().__init__()
        self.alpha = alpha
        self.epsilon = epsilon
        self.decay = decay
        self.policy = policy
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.engine = engine
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.alpha <= 0:
            return

        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=self.alpha,
            policy=self.policy,
            **(self.init_kwargs)
        )

        moment = param_init_net.ConstantFill(
            [param], str(param) + "_squared_moment", value=0.0
        )

        moment_update = param_init_net.ConstantFill(
            [param], str(param) + "_squared_moment_update", value=0.0
        )

        self._aux_params.local.append(moment)
        self._aux_params.local.append(moment_update)

        if isinstance(grad, core.GradientSlice):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            net.SparseAdadelta(
                [param, moment, moment_update, grad.indices, grad.values, lr],
                [param, moment, moment_update],
                epsilon=self.epsilon,
                decay=self.decay,
                engine=self.engine,
            )
        else:
            net.Adadelta(
                [param, moment, moment_update, grad, lr],
                [param, moment, moment_update],
                epsilon=self.epsilon,
                decay=self.decay,
                engine=self.engine,
            )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return


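# Illustrative usage sketch for AdadeltaOptimizer via the build_adadelta()
# helper defined later in this module.  Assumes a ModelHelper with gradients
# already added; values below simply restate the defaults.
#
#     from caffe2.python import model_helper, optimizer
#
#     model = model_helper.ModelHelper(name="adadelta_example")
#     # ... forward ops, loss, model.AddGradientOperators([loss]) ...
#     optimizer.build_adadelta(
#         model,
#         base_learning_rate=0.01,
#         epsilon=1e-4,
#         decay=0.95,
#     )
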
class FtrlOptimizer(Optimizer):
    def __init__(
        self,
        alpha=0.01,
        beta=1e-4,
        lambda1=0,
        lambda2=0,
        sparse_dedup_aggregator=None,
        engine="",
    ):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.engine = engine

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.alpha <= 0:
            return

        nz = param_init_net.ConstantFill(
            [param], str(param) + "_ftrl_nz", extra_shape=[2], value=0.0
        )
        self._aux_params.local.append(nz)
        if isinstance(grad, core.GradientSlice):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            net.SparseFtrl(
                [param, nz, grad.indices, grad.values],
                [param, nz],
                engine=self.engine,
                alpha=self.alpha,
                beta=self.beta,
                lambda1=self.lambda1,
                lambda2=self.lambda2,
            )
        else:
            net.Ftrl(
                [param, nz, grad],
                [param, nz],
                engine=self.engine,
                alpha=self.alpha,
                beta=self.beta,
                lambda1=self.lambda1,
                lambda2=self.lambda2,
            )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return


class GFtrlOptimizer(Optimizer):
    """Group Lasso FTRL Optimizer."""

    def __init__(
        self,
        alpha=0.01,
        beta=1e-4,
        lambda1=0,
        lambda2=0,
        sparse_dedup_aggregator=None,
        engine="",
    ):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.engine = engine

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.alpha <= 0:
            return

        nz = param_init_net.ConstantFill(
            [param], str(param) + "_gftrl_nz", extra_shape=[2], value=0.0
        )
        self._aux_params.local.append(nz)
        net.GFtrl(
            [param, nz, grad],
            [param, nz],
            engine=self.engine,
            alpha=self.alpha,
            beta=self.beta,
            lambda1=self.lambda1,
            lambda2=self.lambda2,
        )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return


class AdamOptimizer(Optimizer):
    def __init__(
        self,
        alpha=0.001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        policy="fixed",
        use_lr_adaption=False,
        lr_alpha=0.01,
        normalized_lr_adaption=True,
        sparse_dedup_aggregator=None,
        rowWise=False,
        engine="",
        enableRAdam=False,
        use_smart_decay=False,  # See https://fburl.com/2jdiwrhy for context.
        **kwargs
    ):
        super().__init__()
        self.alpha = alpha
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.policy = policy
        self.use_lr_adaption = use_lr_adaption
        self.lr_alpha = lr_alpha
        self.normalized_lr_adaption = normalized_lr_adaption
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.rowWise = rowWise
        self.engine = engine
        self.enableRAdam = enableRAdam
        if use_smart_decay:
            if rowWise:
                raise NotImplementedError(('Smart decay is not implemented for rowWise Adam.  '
                                           'Set rowWise or use_smart_decay to False.'))
            if enableRAdam:
                raise NotImplementedError(('Smart decay is not implemented for RAdam.  '
                                           'Set enableRAdam or use_smart_decay to False.'))
            if use_lr_adaption:
                raise NotImplementedError(('Smart decay is not implemented with lr_adaption.  '
                                           'Set use_lr_adaption or use_smart_decay to False.'))

        self.use_smart_decay = use_smart_decay
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.alpha <= 0:
            return

        lr, iteration = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=self.alpha,
            policy=self.policy,
            **(self.init_kwargs)
        )

        m1 = param_init_net.ConstantFill([param], param + "_first_moment", value=0.0)

        if self.rowWise:
            shapes, types = workspace.InferShapesAndTypes([param_init_net])
            m2 = param_init_net.ConstantFill(
                [], param + "_avg_second_moment", shape=[shapes[param][0]], value=0.0
            )
        else:
            m2 = param_init_net.ConstantFill(
                [param], param + "_second_moment", value=0.0
            )

        # Initialize "minibatch in which this parameter was last seen" for smart decay.
        if self.use_smart_decay:
            shapes, _ = workspace.InferShapesAndTypes([param_init_net])
            last_seen = param_init_net.ConstantFill(
                [], param + "_last_seen", shape=[shapes[param][0]], value=0, dtype=core.DataType.INT64
            )
            self._aux_params.local.append(last_seen)

        self._aux_params.shared.append(iteration)
        self._aux_params.local.append(m1)
        self._aux_params.local.append(m2)

        if self.rowWise:
            assert isinstance(grad, core.GradientSlice), (
                "If SparseAdam is used with rowWise=True, the gradient must be "
                "a GradientSlice. Please ensure that rowWise is not enabled "
                "for the dense Adam optimizer, as it is not supported."
            )

        output_blobs = [param, m1, m2]

        if self.use_smart_decay:
            output_blobs.append(last_seen)

        if self.use_lr_adaption:
            effective_grad = str(param) + "_effective_grad"
            output_blobs.append(effective_grad)

        if isinstance(grad, core.GradientSlice):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            if self.rowWise:
                op = "RowWiseSparseAdam"
            elif self.use_smart_decay:
                op = "SmartDecaySparseAdam"
            else:
                op = "SparseAdam"

            # Currently, only SparseAdam supports RAdam; the other Adam ops may add support later.
            if op == "SparseAdam":
                net.__getattr__(op)(
                    [param, m1, m2, grad.indices, grad.values, lr, iteration],
                    output_blobs,
                    beta1=self.beta1,
                    beta2=self.beta2,
                    epsilon=self.epsilon,
                    enableRAdam=self.enableRAdam,
                )
            elif op == "SmartDecaySparseAdam":
                net.__getattr__(op)(
                    [param, m1, m2, last_seen, grad.indices, grad.values, lr, iteration],
                    output_blobs,
                    beta1=self.beta1,
                    beta2=self.beta2,
                    epsilon=self.epsilon,
                )
            else:
                assert (
                    not self.enableRAdam
                ), "Currently, RowWiseSparseAdam does not support RAdam!"
                net.__getattr__(op)(
                    [param, m1, m2, grad.indices, grad.values, lr, iteration],
                    output_blobs,
                    beta1=self.beta1,
                    beta2=self.beta2,
                    epsilon=self.epsilon,
                )

            if self.use_lr_adaption:
                net.LearningRateAdaption(
                    [lr, grad.values, effective_grad],
                    [lr],
                    lr_alpha=self.lr_alpha,
                    normalized_lr_adaption=self.normalized_lr_adaption,
                )

        else:
            net.Adam(
                [param, m1, m2, grad, lr, iteration],
                output_blobs,
                beta1=self.beta1,
                beta2=self.beta2,
                epsilon=self.epsilon,
            )
            if self.use_lr_adaption:
                net.LearningRateAdaption(
                    [lr, grad, effective_grad],
                    [lr],
                    lr_alpha=self.lr_alpha,
                    normalized_lr_adaption=self.normalized_lr_adaption,
                )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return


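# Illustrative usage sketch for AdamOptimizer via the build_adam() helper
# defined later in this module.  Assumes a ModelHelper with gradients already
# added; rowWise=True is only valid for sparse (GradientSlice) gradients, as
# asserted in _run above.
#
#     from caffe2.python import model_helper, optimizer
#
#     model = model_helper.ModelHelper(name="adam_example")
#     # ... forward ops, loss, model.AddGradientOperators([loss]) ...
#     optimizer.build_adam(
#         model,
#         base_learning_rate=0.001,
#         beta1=0.9,
#         beta2=0.999,
#         epsilon=1e-8,
#     )
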
class DecayAdagradOptimizer(Optimizer):
    def __init__(
        self,
        alpha=0.01,
        beta1=0.0,
        beta2=0.999,
        epsilon=0.1,
        weight_decay=0.0,
        ema_options=None,
        bias_correction_first=True,
        policy="fixed",
        engine="",
        **kwargs
    ):
        super().__init__()
        self.alpha = alpha
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.weight_decay = weight_decay
        self.bias_correction_first = bias_correction_first
        self.policy = policy
        self.engine = engine
        self.init_kwargs = kwargs
        self._process_ema_options(ema_options)

    def set_mapping_for_param2ema_teacher_param(self, param_mapping: Dict[str, Any]) -> None:
        self.param2ema_teacher_param = param_mapping

    def _process_ema_options(self, ema_options):
        self.ema_enabled = bool(ema_options and "ema_alpha" in ema_options)
        self.ema_teacher_enabled = bool(ema_options and "ema_teacher_alpha" in ema_options)
        self.param2ema_teacher_param = {}
        if self.ema_enabled or self.ema_teacher_enabled:
            self.ema_start = ema_options.get("ema_start", None)
            self.ema_end = ema_options.get("ema_end", None)
            self.ema_step = ema_options.get("ema_step", None)
            self.ema_alpha = ema_options.get("ema_alpha", None)
            self.ema_teacher_alpha = ema_options.get("ema_teacher_alpha", None)
            self.ema_teacher_module_name = ema_options.get(
                "ema_teacher_module_name", "ema_teacher_arch"
            )

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        if self.alpha <= 0:
            return

        lr, iteration = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=self.alpha,
            policy=self.policy,
            **(self.init_kwargs)
        )

        if isinstance(grad, core.GradientSlice):
            # hack for position weighted.
            param_squared_sum = param_init_net.ConstantFill([param], param + "_squared_sum", value=0.0)
            self._aux_params.local.append(param_squared_sum)
            output_blobs = [param, param_squared_sum]
            net.SparseAdagrad(
                [param, param_squared_sum, grad.indices, grad.values, lr],
                output_blobs,
                epsilon=self.epsilon,
            )
        else:
            m1 = param_init_net.ConstantFill([param], param + "_first_moment", value=0.0)
            m2 = param_init_net.ConstantFill([param], param + "_second_moment", value=0.0)
            self._aux_params.shared.append(iteration)
            self._aux_params.local.append(m1)
            self._aux_params.local.append(m2)
            output_blobs = [param, m1, m2]
            net.DecayAdagrad(
                [param, m1, m2, grad, lr, iteration],
                output_blobs,
                beta1=self.beta1,
                beta2=self.beta2,
                epsilon=self.epsilon,
                weight_decay=self.weight_decay,
                bias_correction_first=self.bias_correction_first,
            )

            if self.ema_enabled:
                param_ema = str(param) + "_ema"
                if not param_init_net.BlobIsDefined(param_ema):
                    param_init_net.ConstantFill([param], param_ema, value=0.0)
                    self._aux_params.local.append(param_ema)

                net.EMA(
                    [param, param_ema, iteration],
                    [param, param_ema],
                    ema_start=self.ema_start,
                    ema_end=self.ema_end,
                    ema_step=self.ema_step,
                    ema_alpha=self.ema_alpha,
                )

            if self.ema_teacher_enabled:
                if param in self.param2ema_teacher_param:
                    param_ema_teacher = self.param2ema_teacher_param[param]
                    if not param_init_net.BlobIsDefined(param_ema_teacher):
                        param_init_net.ConstantFill([param], param_ema_teacher, value=0.0)
                        self._aux_params.local.append(param_ema_teacher)

                    net.EMA(
                        [param, param_ema_teacher, iteration],
                        [param, param_ema_teacher],
                        ema_start=self.ema_start,
                        ema_end=self.ema_end,
                        ema_step=self.ema_step,
                        ema_alpha=self.ema_teacher_alpha,
                    )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return


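# Illustrative sketch of the ema_options dict consumed by _process_ema_options()
# above.  Only the keys are taken from the code; the numeric values are
# placeholders, and the start/end/step arguments are simply forwarded to the EMA
# operator (see that op for their exact semantics).
#
#     ema_options = {
#         "ema_start": 10000,           # forwarded to net.EMA as ema_start
#         "ema_end": 200000,            # forwarded to net.EMA as ema_end
#         "ema_step": 1,                # forwarded to net.EMA as ema_step
#         "ema_alpha": 0.999,           # presence enables the plain EMA branch
#         "ema_teacher_alpha": 0.9999,  # presence enables the EMA-teacher branch
#     }
#     optimizer.build_decay_adagrad(
#         model,
#         base_learning_rate=0.01,
#         ema_options=ema_options,
#     )
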
class YellowFinOptimizer(Optimizer):
    """YellowFin: An automatic tuner for momentum SGD

    See https://arxiv.org/abs/1706.03471 for more details. This implementation
    maintains a separate learning rate and momentum for each parameter."""

    def __init__(
        self,
        alpha=0.1,
        mu=0.0,
        beta=0.999,
        curv_win_width=20,
        zero_debias=True,
        epsilon=0.1 ** 6,
        policy="fixed",
        sparse_dedup_aggregator=None,
        **kwargs
    ):
        super().__init__()
        self.alpha = alpha
        self.mu = mu
        self.beta = beta
        self.curv_win_width = curv_win_width
        self.zero_debias = zero_debias
        self.epsilon = epsilon
        self.policy = policy
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):

        # Note: This is the number of persistent scalars kept by the YellowFin
        #       optimizer. It must always match the number of scalars actually
        #       used, i.e. the number expected by the YellowFin operator.
        SCALARS_MEMORY_SIZE = 5

        param = param_info.blob
        grad = param_info.grad
        moment = param_init_net.ConstantFill([param], param + "_moment", value=0.0)
        curv_win = param_init_net.ConstantFill(
            [], param + "_curv_win", shape=[self.curv_win_width], value=0.0
        )
        g_avg = param_init_net.ConstantFill([param], param + "_g_avg", value=0.0)
        g2_avg = param_init_net.ConstantFill([param], param + "_g2_avg", value=0.0)
        lr_avg = param_init_net.ConstantFill(
            [], param + "_lr_avg", shape=[1], value=self.alpha
        )
        mu_avg = param_init_net.ConstantFill(
            [], param + "_mu_avg", shape=[1], value=self.mu
        )
        scalars_memory = param_init_net.ConstantFill(
            [], param + "_scalars_memory", shape=[SCALARS_MEMORY_SIZE], value=0.0
        )

        assert self.alpha > 0
        assert not isinstance(
            grad, core.GradientSlice
        ), "YellowFin does not support sparse gradients"

        iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=0)

        self._aux_params.shared.append(iteration)
        self._aux_params.local.append(moment)
        self._aux_params.local.append(lr_avg)
        self._aux_params.local.append(mu_avg)
        self._aux_params.local.append(curv_win)
        self._aux_params.local.append(g_avg)
        self._aux_params.local.append(g2_avg)
        self._aux_params.local.append(scalars_memory)

        yf_in_out_args = [
            param,
            moment,
            lr_avg,
            mu_avg,
            curv_win,
            g_avg,
            g2_avg,
            scalars_memory,
        ]

        net.YellowFin(
            yf_in_out_args + [grad, iteration],
            yf_in_out_args,
            beta=self.beta,
            epsilon=self.epsilon,
            curv_win_width=self.curv_win_width,
            zero_debias=self.zero_debias,
        )

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return


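# Illustrative usage sketch for YellowFinOptimizer via the build_yellowfin()
# helper defined later in this module.  Dense gradients only, as asserted in
# _run above; names are placeholders.
#
#     from caffe2.python import model_helper, optimizer
#
#     model = model_helper.ModelHelper(name="yellowfin_example")
#     # ... forward ops, loss, model.AddGradientOperators([loss]) ...
#     optimizer.build_yellowfin(model, base_learning_rate=0.1, mu=0.0)
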
class RmsPropOptimizer(Optimizer):
    def __init__(
        self,
        alpha=0.01,
        decay=0.9,
        momentum=0.0,
        epsilon=1e-5,
        policy="fixed",
        engine="",
        **kwargs
    ):
        super().__init__()
        self.alpha = alpha
        self.decay = decay
        self.momentum = momentum
        self.epsilon = epsilon
        self.policy = policy
        self.engine = engine
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad

        assert self.alpha > 0
        assert not isinstance(
            grad, core.GradientSlice
        ), "RmsPropOptimizer doesn't support sparse gradients"

        dev = scope.CurrentDeviceScope()
        if dev is None:
            dev = core.DeviceOption(caffe2_pb2.CPU)

        ONE = param_init_net.ConstantFill(
            [], "ONE_{}_{}".format(dev.device_type, dev.device_id), shape=[1], value=1.0
        )

        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=-self.alpha,
            policy=self.policy,
            **(self.init_kwargs)
        )

        grad_o = param_init_net.ConstantFill(
            [param], str(param) + "_grad_o", value=0.0
        )

        ms = param_init_net.ConstantFill(
            [param], str(param) + "_mean_squares", value=0.0
        )

        mom = param_init_net.ConstantFill([param], str(param) + "_momentum", value=0.0)

        self._aux_params.local.append(ms)
        self._aux_params.local.append(mom)

        net.RmsProp(
            [grad, ms, mom, ONE],
            [grad_o, ms, mom],
            decay=self.decay,
            momentum=self.momentum,
            epsilon=self.epsilon,
            engine=self.engine,
        )

        net.MomentumSGDUpdate([grad_o, mom, lr, param], [grad_o, mom, param])

    def scale_learning_rate(self, scale):
        self.alpha *= scale
        return


def _get_param_to_device(model):
    # Infer blob devices by going through the net and param_init_net
    # ops and observing the device used to create or use the blob.
    param_to_device = core.InferBlobDevices(model.net)
    param_to_device.update(core.InferBlobDevices(model.param_init_net))
    return param_to_device


def get_param_device(param_name, grad, param_to_device=None, default_device=None):
    device = default_device
    param_to_device = param_to_device or {}
    # We first check whether the parameter's device has been inferred. If not,
    # we check the gradient. This can happen if the parameter is not the output
    # of any op but was created directly, e.g. via FeedBlob.
    if param_name in param_to_device:
        device = param_to_device[param_name]
    else:
        if isinstance(grad, core.GradientSlice):
            if str(grad.values) in param_to_device:
                device = param_to_device[str(grad.values)]
            elif str(grad.indices) in param_to_device:
                device = param_to_device[str(grad.indices)]
        else:
            grad_name = str(grad)
            if grad_name in param_to_device:
                device = param_to_device[grad_name]

    assert device is not None, "Cannot infer device for {}: no op creates it".format(
        param_name
    )
    return device


def get_lr_injection():
    """
    Gets current value for lr_injection, a multiplier for all base
    learning rates.
    Must set allow_lr_injection=True when building optimizer, as it
    relies on synchronization over CPU.
    """
    return workspace.FetchBlob(_LEARNING_RATE_INJECTION)


def set_lr_injection(lr_injection_value):
    """
    Sets lr_injection, a multiplier for all base learning rates.
    Must set allow_lr_injection=True when building optimizer, as it
    relies on synchronization over CPU.
    """
    workspace.FeedBlob(
        _LEARNING_RATE_INJECTION,
        np.array([float(lr_injection_value)], dtype=np.float32),
    )


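# Illustrative sketch of the lr_injection round trip.  Assumes an optimizer was
# built with allow_lr_injection=True and that the init net has been run so the
# lr_injection blob exists in the workspace; names are placeholders.
#
#     from caffe2.python import optimizer, workspace
#
#     optimizer.build_sgd(model, base_learning_rate=0.1, allow_lr_injection=True)
#     workspace.RunNetOnce(model.param_init_net)
#     # ... during training ...
#     optimizer.set_lr_injection(0.5)   # halve every base learning rate
#     assert float(optimizer.get_lr_injection()[0]) == 0.5
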
def _calc_norm_ratio(model, params, name_scope, param_to_device, max_gradient_norm):
    with core.NameScope(name_scope):
        grad_squared_sums = []
        for i, param in enumerate(params):
            device = get_param_device(str(param.blob), param.grad, param_to_device)

            with core.DeviceScope(device):
                grad = (
                    param.grad
                    if not isinstance(param.grad, core.GradientSlice)
                    else param.grad.values
                )

                grad_squared_sum_name = "grad_{}_squared_sum".format(i)
                grad_squared_sum = model.net.SumSqrElements(grad, grad_squared_sum_name)
                grad_squared_sum_cpu = model.net.EnsureCPUOutput(grad_squared_sum)
                grad_squared_sums.append(grad_squared_sum_cpu)

        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
            grad_squared_full_sum = model.net.Sum(
                grad_squared_sums, "grad_squared_full_sum"
            )
            global_norm = model.net.Pow(
                grad_squared_full_sum, "global_norm", exponent=0.5
            )
            clip_norm = model.param_init_net.ConstantFill(
                [], "clip_norm", shape=[], value=float(max_gradient_norm)
            )
            max_norm = model.net.Max([global_norm, clip_norm], "max_norm")
            norm_ratio = model.net.Div([clip_norm, max_norm], "norm_ratio")
            return norm_ratio


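# The norm ratio computed above is, in numpy terms (a sketch of the same math,
# not an alternative implementation):
#
#     import numpy as np
#
#     def norm_ratio(grads, max_gradient_norm):
#         global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
#         return max_gradient_norm / max(global_norm, max_gradient_norm)
#
# i.e. the multiplier is 1.0 while the global gradient norm stays below
# max_gradient_norm and shrinks proportionally once it exceeds it.
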
def _build(
    model,
    optimizer,
    weights_only=False,
    use_param_info_optim=True,
    max_gradient_norm=None,
    allow_lr_injection=False,
):
    param_to_device = _get_param_to_device(model)

    # Validate there are no duplicate params
    model.Validate()

    params = []
    for param_info in model.GetOptimizationParamInfo():
        if weights_only and param_info.blob not in model.weights:
            continue
        params.append(param_info)

    lr_multiplier = None
    if max_gradient_norm is not None:
        lr_multiplier = _calc_norm_ratio(
            model,
            params,
            "norm_clipped_grad_update",
            param_to_device,
            max_gradient_norm,
        )

    if allow_lr_injection:
        if not model.net.BlobIsDefined(_LEARNING_RATE_INJECTION):
            lr_injection = model.param_init_net.ConstantFill(
                [], _LEARNING_RATE_INJECTION, shape=[1], value=1.0
            )
        else:
            lr_injection = _LEARNING_RATE_INJECTION

        if lr_multiplier is None:
            lr_multiplier = lr_injection
        else:
            lr_multiplier = model.net.Mul(
                [lr_multiplier, lr_injection], "lr_multiplier", broadcast=1
            )
    optimizer.add_lr_multiplier(lr_multiplier)

    for param_info in params:
        param_name = str(param_info.blob)
        device = get_param_device(param_name, param_info.grad, param_to_device)
        with core.DeviceScope(device):
            if param_info.optimizer and use_param_info_optim:
                param_info.optimizer(model.net, model.param_init_net, param_info)
            else:
                optimizer(model.net, model.param_init_net, param_info)
    return optimizer


def add_weight_decay(model, weight_decay):
    """Adds a decay to weights in the model.

    This is a form of L2 regularization.

    Args:
        weight_decay: strength of the regularization
    """
    _build(
        model,
        WeightDecayBuilder(weight_decay=weight_decay),
        weights_only=True,
        use_param_info_optim=False,
    )


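# Illustrative usage sketch: because add_weight_decay() builds with
# weights_only=True, the decay is applied only to blobs registered in
# model.weights, and it composes with whichever optimizer is built afterwards.
#
#     optimizer.add_weight_decay(model, weight_decay=1e-4)
#     optimizer.build_sgd(model, base_learning_rate=0.1)
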
def build_sgd(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    sgd_optimizer = SgdOptimizer(base_learning_rate, **kwargs)
    return _build(
        model,
        sgd_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


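# Illustrative usage sketch for build_sgd() with global gradient-norm clipping
# (max_gradient_norm goes through _calc_norm_ratio above) and runtime LR
# injection enabled.  Extra kwargs such as policy/stepsize/gamma are forwarded
# to the LearningRate op via build_lr; the values below are placeholders.
#
#     optimizer.build_sgd(
#         model,
#         base_learning_rate=0.1,
#         policy="step",
#         stepsize=10000,
#         gamma=0.5,
#         max_gradient_norm=5.0,
#         allow_lr_injection=True,
#     )
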
def build_multi_precision_sgd(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    multi_prec_sgd_optimizer = MultiPrecisionSgdOptimizer(base_learning_rate, **kwargs)
    return _build(
        model,
        multi_prec_sgd_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_fp16_sgd(model, base_learning_rate, **kwargs):
    fp16_sgd_optimizer = FP16SgdOptimizer(base_learning_rate, **kwargs)
    return _build(model, fp16_sgd_optimizer)


def build_ftrl(model, engine="SIMD", **kwargs):
    if engine == "SIMD":
        assert core.IsOperator("Ftrl_ENGINE_SIMD")
        assert core.IsOperator("SparseFtrl_ENGINE_SIMD")
    ftrl_optimizer = FtrlOptimizer(engine=engine, **kwargs)
    return _build(model, ftrl_optimizer)


def build_gftrl(model, engine="", **kwargs):
    if engine == "SIMD":
        assert core.IsOperator("GFtrl_ENGINE_SIMD")
    gftrl_optimizer = GFtrlOptimizer(engine=engine, **kwargs)
    return _build(model, gftrl_optimizer)


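# Illustrative usage sketch for build_ftrl().  The default engine "SIMD"
# requires the SIMD-specialized Ftrl/SparseFtrl operators to be registered (see
# the asserts above); pass engine="" to fall back to the reference operators.
# Parameter values are placeholders.
#
#     optimizer.build_ftrl(
#         model,
#         engine="",        # or "SIMD" when the specialized ops are available
#         alpha=0.01,
#         beta=1e-4,
#         lambda1=0.001,    # L1 regularization strength
#         lambda2=0.001,    # L2 regularization strength
#     )
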
def build_adagrad(
    model,
    base_learning_rate,
    parameters=None,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    adagrad_optimizer = AdagradOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        adagrad_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_wngrad(
    model,
    base_learning_rate,
    parameters=None,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    wngrad_optimizer = WngradOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        wngrad_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_storm(
    model,
    base_learning_rate,
    parameters=None,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    storm_optimizer = StormOptimizer(lr=base_learning_rate, **kwargs)
    return _build(
        model,
        storm_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_adadelta(
    model,
    base_learning_rate,
    parameters=None,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    adadelta_optimizer = AdadeltaOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        adadelta_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_adam(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    adam_optimizer = AdamOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        adam_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_decay_adagrad(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    decay_adagrad_optimizer = DecayAdagradOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        decay_adagrad_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_yellowfin(model, base_learning_rate=0.1, **kwargs):
    yellowfin_optimizer = YellowFinOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(model, yellowfin_optimizer)


def build_rms_prop(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    rms_prop_optimizer = RmsPropOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        rms_prop_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )