pytorch

Форк
0
/
lr_scheduler.py 
1805 строк · 76.5 Кб
1
import types
2
import math
3
from torch import inf
4
from functools import wraps, partial
5
import warnings
6
import weakref
7
from collections import Counter
8
from bisect import bisect_right
9

10
from .optimizer import Optimizer
11

12
__all__ = ['LambdaLR', 'MultiplicativeLR', 'StepLR', 'MultiStepLR', 'ConstantLR', 'LinearLR',
13
           'ExponentialLR', 'SequentialLR', 'CosineAnnealingLR', 'ChainedScheduler', 'ReduceLROnPlateau',
14
           'CyclicLR', 'CosineAnnealingWarmRestarts', 'OneCycleLR', 'PolynomialLR', 'LRScheduler']
15

16
EPOCH_DEPRECATION_WARNING = (
17
    "The epoch parameter in `scheduler.step()` was not necessary and is being "
18
    "deprecated where possible. Please use `scheduler.step()` to step the "
19
    "scheduler. During the deprecation, if epoch is different from None, the "
20
    "closed form is used instead of the new chainable form, where available. "
21
    "Please open an issue if you are unable to replicate your use case: "
22
    "https://github.com/pytorch/pytorch/issues/new/choose."
23
)
24

25
def _check_verbose_deprecated_warning(verbose):
26
    """Raises a warning when verbose is not the default value."""
27
    if verbose != "deprecated":
28
        warnings.warn("The verbose parameter is deprecated. Please use get_last_lr() "
29
                      "to access the learning rate.", UserWarning)
30
        return verbose
31
    return False
32

33
class LRScheduler:
34

35
    def __init__(self, optimizer, last_epoch=-1, verbose="deprecated"):
36

37
        # Attach optimizer
38
        if not isinstance(optimizer, Optimizer):
39
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
40
        self.optimizer = optimizer
41

42
        # Initialize epoch and base learning rates
43
        if last_epoch == -1:
44
            for group in optimizer.param_groups:
45
                group.setdefault('initial_lr', group['lr'])
46
        else:
47
            for i, group in enumerate(optimizer.param_groups):
48
                if 'initial_lr' not in group:
49
                    raise KeyError("param 'initial_lr' is not specified "
50
                                   f"in param_groups[{i}] when resuming an optimizer")
51
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
52
        self.last_epoch = last_epoch
53

54
        # Following https://github.com/pytorch/pytorch/issues/20124
55
        # We would like to ensure that `lr_scheduler.step()` is called after
56
        # `optimizer.step()`
57
        def with_counter(method):
58
            if getattr(method, '_with_counter', False):
59
                # `optimizer.step()` has already been replaced, return.
60
                return method
61

62
            # Keep a weak reference to the optimizer instance to prevent
63
            # cyclic references.
64
            instance_ref = weakref.ref(method.__self__)
65
            # Get the unbound method for the same purpose.
66
            func = method.__func__
67
            cls = instance_ref().__class__
68
            del method
69

70
            @wraps(func)
71
            def wrapper(*args, **kwargs):
72
                instance = instance_ref()
73
                instance._step_count += 1
74
                wrapped = func.__get__(instance, cls)
75
                return wrapped(*args, **kwargs)
76

77
            # Note that the returned function here is no longer a bound method,
78
            # so attributes like `__func__` and `__self__` no longer exist.
79
            wrapper._with_counter = True
80
            return wrapper
81

82
        self.optimizer.step = with_counter(self.optimizer.step)
83
        self.verbose = _check_verbose_deprecated_warning(verbose)
84

85
        self._initial_step()
86

87
    def _initial_step(self):
88
        """Initialize step counts and performs a step"""
89
        self.optimizer._step_count = 0
90
        self._step_count = 0
91
        self.step()
92

93
    def state_dict(self):
94
        """Returns the state of the scheduler as a :class:`dict`.
95

96
        It contains an entry for every variable in self.__dict__ which
97
        is not the optimizer.
98
        """
99
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
100

101
    def load_state_dict(self, state_dict):
102
        """Loads the schedulers state.
103

104
        Args:
105
            state_dict (dict): scheduler state. Should be an object returned
106
                from a call to :meth:`state_dict`.
107
        """
108
        self.__dict__.update(state_dict)
109

110
    def get_last_lr(self):
111
        """ Return last computed learning rate by current scheduler.
112
        """
113
        return self._last_lr
114

115
    def get_lr(self):
116
        # Compute learning rate using chainable form of the scheduler
117
        raise NotImplementedError
118

119
    def print_lr(self, is_verbose, group, lr, epoch=None):
120
        """Display the current learning rate.
121
        """
122
        if is_verbose:
123
            if epoch is None:
124
                print(f'Adjusting learning rate of group {group} to {lr:.4e}.')
125
            else:
126
                epoch_str = ("%.2f" if isinstance(epoch, float) else
127
                             "%.5d") % epoch
128
                print(f'Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}.')
129

130

131
    def step(self, epoch=None):
132
        # Raise a warning if old pattern is detected
133
        # https://github.com/pytorch/pytorch/issues/20124
134
        if self._step_count == 1:
135
            if not hasattr(self.optimizer.step, "_with_counter"):
136
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
137
                              "initialization. Please, make sure to call `optimizer.step()` before "
138
                              "`lr_scheduler.step()`. See more details at "
139
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
140

141
            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
142
            elif self.optimizer._step_count < 1:
143
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
144
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
145
                              "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
146
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
147
                              "See more details at "
148
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
149
        self._step_count += 1
150

151
        with _enable_get_lr_call(self):
152
            if epoch is None:
153
                self.last_epoch += 1
154
                values = self.get_lr()
155
            else:
156
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
157
                self.last_epoch = epoch
158
                if hasattr(self, "_get_closed_form_lr"):
159
                    values = self._get_closed_form_lr()
160
                else:
161
                    values = self.get_lr()
162

163
        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
164
            param_group, lr = data
165
            param_group['lr'] = lr
166

167
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
168

169

170
# Including _LRScheduler for backwards compatibility
171
# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
172
class _LRScheduler(LRScheduler):
173
    pass
174

175

176
class _enable_get_lr_call:
177

178
    def __init__(self, o):
179
        self.o = o
180

181
    def __enter__(self):
182
        self.o._get_lr_called_within_step = True
183
        return self
184

185
    def __exit__(self, type, value, traceback):
186
        self.o._get_lr_called_within_step = False
187

188

189
class LambdaLR(LRScheduler):
190
    """Sets the learning rate of each parameter group to the initial lr
191
    times a given function. When last_epoch=-1, sets initial lr as lr.
192

193
    Args:
194
        optimizer (Optimizer): Wrapped optimizer.
195
        lr_lambda (function or list): A function which computes a multiplicative
196
            factor given an integer parameter epoch, or a list of such
197
            functions, one for each group in optimizer.param_groups.
198
        last_epoch (int): The index of last epoch. Default: -1.
199
        verbose (bool): If ``True``, prints a message to stdout for
200
            each update. Default: ``False``.
201

202
            .. deprecated:: 2.2
203
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
204
                learning rate.
205

206
    Example:
207
        >>> # xdoctest: +SKIP
208
        >>> # Assuming optimizer has two groups.
209
        >>> lambda1 = lambda epoch: epoch // 30
210
        >>> lambda2 = lambda epoch: 0.95 ** epoch
211
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
212
        >>> for epoch in range(100):
213
        >>>     train(...)
214
        >>>     validate(...)
215
        >>>     scheduler.step()
216
    """
217

218
    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose="deprecated"):
219
        self.optimizer = optimizer
220

221
        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
222
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
223
        else:
224
            if len(lr_lambda) != len(optimizer.param_groups):
225
                raise ValueError(f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}")
226
            self.lr_lambdas = list(lr_lambda)
227
        super().__init__(optimizer, last_epoch, verbose)
228

229
    def state_dict(self):
230
        """Returns the state of the scheduler as a :class:`dict`.
231

232
        It contains an entry for every variable in self.__dict__ which
233
        is not the optimizer.
234
        The learning rate lambda functions will only be saved if they are callable objects
235
        and not if they are functions or lambdas.
236

237
        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
238
        """
239

240
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
241
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
242

243
        for idx, fn in enumerate(self.lr_lambdas):
244
            if not isinstance(fn, types.FunctionType):
245
                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
246

247
        return state_dict
248

249
    def load_state_dict(self, state_dict):
250
        """Loads the schedulers state.
251

252
        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
253

254
        Args:
255
            state_dict (dict): scheduler state. Should be an object returned
256
                from a call to :meth:`state_dict`.
257
        """
258

259
        lr_lambdas = state_dict.pop('lr_lambdas')
260
        self.__dict__.update(state_dict)
261
        # Restore state_dict keys in order to prevent side effects
262
        # https://github.com/pytorch/pytorch/issues/32756
263
        state_dict['lr_lambdas'] = lr_lambdas
264

265
        for idx, fn in enumerate(lr_lambdas):
266
            if fn is not None:
267
                self.lr_lambdas[idx].__dict__.update(fn)
268

269
    def get_lr(self):
270
        if not self._get_lr_called_within_step:
271
            warnings.warn("To get the last learning rate computed by the scheduler, "
272
                          "please use `get_last_lr()`.")
273

274
        return [base_lr * lmbda(self.last_epoch)
275
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
276

277

278
class MultiplicativeLR(LRScheduler):
279
    """Multiply the learning rate of each parameter group by the factor given
280
    in the specified function. When last_epoch=-1, sets initial lr as lr.
281

282
    Args:
283
        optimizer (Optimizer): Wrapped optimizer.
284
        lr_lambda (function or list): A function which computes a multiplicative
285
            factor given an integer parameter epoch, or a list of such
286
            functions, one for each group in optimizer.param_groups.
287
        last_epoch (int): The index of last epoch. Default: -1.
288
        verbose (bool): If ``True``, prints a message to stdout for
289
            each update. Default: ``False``.
290

291
            .. deprecated:: 2.2
292
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
293
                learning rate.
294

295
    Example:
296
        >>> # xdoctest: +SKIP
297
        >>> lmbda = lambda epoch: 0.95
298
        >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
299
        >>> for epoch in range(100):
300
        >>>     train(...)
301
        >>>     validate(...)
302
        >>>     scheduler.step()
303
    """
304

305
    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose="deprecated"):
306
        self.optimizer = optimizer
307

308
        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
309
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
310
        else:
311
            if len(lr_lambda) != len(optimizer.param_groups):
312
                raise ValueError(f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}")
313
            self.lr_lambdas = list(lr_lambda)
314
        super().__init__(optimizer, last_epoch, verbose)
315

316
    def state_dict(self):
317
        """Returns the state of the scheduler as a :class:`dict`.
318

319
        It contains an entry for every variable in self.__dict__ which
320
        is not the optimizer.
321
        The learning rate lambda functions will only be saved if they are callable objects
322
        and not if they are functions or lambdas.
323
        """
324
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
325
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
326

327
        for idx, fn in enumerate(self.lr_lambdas):
328
            if not isinstance(fn, types.FunctionType):
329
                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
330

331
        return state_dict
332

333
    def load_state_dict(self, state_dict):
334
        """Loads the schedulers state.
335

336
        Args:
337
            state_dict (dict): scheduler state. Should be an object returned
338
                from a call to :meth:`state_dict`.
339
        """
340
        lr_lambdas = state_dict.pop('lr_lambdas')
341
        self.__dict__.update(state_dict)
342
        # Restore state_dict keys in order to prevent side effects
343
        # https://github.com/pytorch/pytorch/issues/32756
344
        state_dict['lr_lambdas'] = lr_lambdas
345

346
        for idx, fn in enumerate(lr_lambdas):
347
            if fn is not None:
348
                self.lr_lambdas[idx].__dict__.update(fn)
349

350
    def get_lr(self):
351
        if not self._get_lr_called_within_step:
352
            warnings.warn("To get the last learning rate computed by the scheduler, "
353
                          "please use `get_last_lr()`.", UserWarning)
354

355
        if self.last_epoch > 0:
356
            return [group['lr'] * lmbda(self.last_epoch)
357
                    for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)]
358
        else:
359
            return [group['lr'] for group in self.optimizer.param_groups]
360

361

362
class StepLR(LRScheduler):
363
    """Decays the learning rate of each parameter group by gamma every
364
    step_size epochs. Notice that such decay can happen simultaneously with
365
    other changes to the learning rate from outside this scheduler. When
366
    last_epoch=-1, sets initial lr as lr.
367

368
    Args:
369
        optimizer (Optimizer): Wrapped optimizer.
370
        step_size (int): Period of learning rate decay.
371
        gamma (float): Multiplicative factor of learning rate decay.
372
            Default: 0.1.
373
        last_epoch (int): The index of last epoch. Default: -1.
374
        verbose (bool): If ``True``, prints a message to stdout for
375
            each update. Default: ``False``.
376

377
            .. deprecated:: 2.2
378
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
379
                learning rate.
380

381
    Example:
382
        >>> # xdoctest: +SKIP
383
        >>> # Assuming optimizer uses lr = 0.05 for all groups
384
        >>> # lr = 0.05     if epoch < 30
385
        >>> # lr = 0.005    if 30 <= epoch < 60
386
        >>> # lr = 0.0005   if 60 <= epoch < 90
387
        >>> # ...
388
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
389
        >>> for epoch in range(100):
390
        >>>     train(...)
391
        >>>     validate(...)
392
        >>>     scheduler.step()
393
    """
394

395
    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose="deprecated"):
396
        self.step_size = step_size
397
        self.gamma = gamma
398
        super().__init__(optimizer, last_epoch, verbose)
399

400
    def get_lr(self):
401
        if not self._get_lr_called_within_step:
402
            warnings.warn("To get the last learning rate computed by the scheduler, "
403
                          "please use `get_last_lr()`.", UserWarning)
404

405
        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
406
            return [group['lr'] for group in self.optimizer.param_groups]
407
        return [group['lr'] * self.gamma
408
                for group in self.optimizer.param_groups]
409

410
    def _get_closed_form_lr(self):
411
        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
412
                for base_lr in self.base_lrs]
413

414

415
class MultiStepLR(LRScheduler):
416
    """Decays the learning rate of each parameter group by gamma once the
417
    number of epoch reaches one of the milestones. Notice that such decay can
418
    happen simultaneously with other changes to the learning rate from outside
419
    this scheduler. When last_epoch=-1, sets initial lr as lr.
420

421
    Args:
422
        optimizer (Optimizer): Wrapped optimizer.
423
        milestones (list): List of epoch indices. Must be increasing.
424
        gamma (float): Multiplicative factor of learning rate decay.
425
            Default: 0.1.
426
        last_epoch (int): The index of last epoch. Default: -1.
427
        verbose (bool): If ``True``, prints a message to stdout for
428
            each update. Default: ``False``.
429

430
            .. deprecated:: 2.2
431
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
432
                learning rate.
433

434
    Example:
435
        >>> # xdoctest: +SKIP
436
        >>> # Assuming optimizer uses lr = 0.05 for all groups
437
        >>> # lr = 0.05     if epoch < 30
438
        >>> # lr = 0.005    if 30 <= epoch < 80
439
        >>> # lr = 0.0005   if epoch >= 80
440
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
441
        >>> for epoch in range(100):
442
        >>>     train(...)
443
        >>>     validate(...)
444
        >>>     scheduler.step()
445
    """
446

447
    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose="deprecated"):
448
        self.milestones = Counter(milestones)
449
        self.gamma = gamma
450
        super().__init__(optimizer, last_epoch, verbose)
451

452
    def get_lr(self):
453
        if not self._get_lr_called_within_step:
454
            warnings.warn("To get the last learning rate computed by the scheduler, "
455
                          "please use `get_last_lr()`.", UserWarning)
456

457
        if self.last_epoch not in self.milestones:
458
            return [group['lr'] for group in self.optimizer.param_groups]
459
        return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
460
                for group in self.optimizer.param_groups]
461

462
    def _get_closed_form_lr(self):
463
        milestones = sorted(self.milestones.elements())
464
        return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
465
                for base_lr in self.base_lrs]
466

467

468
class ConstantLR(LRScheduler):
469
    """Decays the learning rate of each parameter group by a small constant factor until the
470
    number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can
471
    happen simultaneously with other changes to the learning rate from outside this scheduler.
472
    When last_epoch=-1, sets initial lr as lr.
473

474
    Args:
475
        optimizer (Optimizer): Wrapped optimizer.
476
        factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
477
        total_iters (int): The number of steps that the scheduler decays the learning rate.
478
            Default: 5.
479
        last_epoch (int): The index of the last epoch. Default: -1.
480
        verbose (bool): If ``True``, prints a message to stdout for
481
            each update. Default: ``False``.
482

483
            .. deprecated:: 2.2
484
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
485
                learning rate.
486

487
    Example:
488
        >>> # xdoctest: +SKIP
489
        >>> # Assuming optimizer uses lr = 0.05 for all groups
490
        >>> # lr = 0.025   if epoch == 0
491
        >>> # lr = 0.025   if epoch == 1
492
        >>> # lr = 0.025   if epoch == 2
493
        >>> # lr = 0.025   if epoch == 3
494
        >>> # lr = 0.05    if epoch >= 4
495
        >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=4)
496
        >>> for epoch in range(100):
497
        >>>     train(...)
498
        >>>     validate(...)
499
        >>>     scheduler.step()
500
    """
501

502
    def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose="deprecated"):
503
        if factor > 1.0 or factor < 0:
504
            raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')
505

506
        self.factor = factor
507
        self.total_iters = total_iters
508
        super().__init__(optimizer, last_epoch, verbose)
509

510
    def get_lr(self):
511
        if not self._get_lr_called_within_step:
512
            warnings.warn("To get the last learning rate computed by the scheduler, "
513
                          "please use `get_last_lr()`.", UserWarning)
514

515
        if self.last_epoch == 0:
516
            return [group['lr'] * self.factor for group in self.optimizer.param_groups]
517

518
        if self.last_epoch != self.total_iters:
519
            return [group['lr'] for group in self.optimizer.param_groups]
520

521
        return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]
522

523
    def _get_closed_form_lr(self):
524
        return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
525
                for base_lr in self.base_lrs]
526

527

528
class LinearLR(LRScheduler):
529
    """Decays the learning rate of each parameter group by linearly changing small
530
    multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
531
    Notice that such decay can happen simultaneously with other changes to the learning rate
532
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
533

534
    Args:
535
        optimizer (Optimizer): Wrapped optimizer.
536
        start_factor (float): The number we multiply learning rate in the first epoch.
537
            The multiplication factor changes towards end_factor in the following epochs.
538
            Default: 1./3.
539
        end_factor (float): The number we multiply learning rate at the end of linear changing
540
            process. Default: 1.0.
541
        total_iters (int): The number of iterations that multiplicative factor reaches to 1.
542
            Default: 5.
543
        last_epoch (int): The index of the last epoch. Default: -1.
544
        verbose (bool): If ``True``, prints a message to stdout for
545
            each update. Default: ``False``.
546

547
            .. deprecated:: 2.2
548
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
549
                learning rate.
550

551
    Example:
552
        >>> # xdoctest: +SKIP
553
        >>> # Assuming optimizer uses lr = 0.05 for all groups
554
        >>> # lr = 0.025    if epoch == 0
555
        >>> # lr = 0.03125  if epoch == 1
556
        >>> # lr = 0.0375   if epoch == 2
557
        >>> # lr = 0.04375  if epoch == 3
558
        >>> # lr = 0.05    if epoch >= 4
559
        >>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)
560
        >>> for epoch in range(100):
561
        >>>     train(...)
562
        >>>     validate(...)
563
        >>>     scheduler.step()
564
    """
565

566
    def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
567
                 verbose="deprecated"):
568
        if start_factor > 1.0 or start_factor <= 0:
569
            raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.')
570

571
        if end_factor > 1.0 or end_factor < 0:
572
            raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')
573

574
        self.start_factor = start_factor
575
        self.end_factor = end_factor
576
        self.total_iters = total_iters
577
        super().__init__(optimizer, last_epoch, verbose)
578

579
    def get_lr(self):
580
        if not self._get_lr_called_within_step:
581
            warnings.warn("To get the last learning rate computed by the scheduler, "
582
                          "please use `get_last_lr()`.", UserWarning)
583

584
        if self.last_epoch == 0:
585
            return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]
586

587
        if self.last_epoch > self.total_iters:
588
            return [group['lr'] for group in self.optimizer.param_groups]
589

590
        return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
591
                (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
592
                for group in self.optimizer.param_groups]
593

594
    def _get_closed_form_lr(self):
595
        return [base_lr * (self.start_factor +
596
                (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
597
                for base_lr in self.base_lrs]
598

599

600
class ExponentialLR(LRScheduler):
601
    """Decays the learning rate of each parameter group by gamma every epoch.
602
    When last_epoch=-1, sets initial lr as lr.
603

604
    Args:
605
        optimizer (Optimizer): Wrapped optimizer.
606
        gamma (float): Multiplicative factor of learning rate decay.
607
        last_epoch (int): The index of last epoch. Default: -1.
608
        verbose (bool): If ``True``, prints a message to stdout for
609
            each update. Default: ``False``.
610

611
            .. deprecated:: 2.2
612
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
613
                learning rate.
614
    """
615

616
    def __init__(self, optimizer, gamma, last_epoch=-1, verbose="deprecated"):
617
        self.gamma = gamma
618
        super().__init__(optimizer, last_epoch, verbose)
619

620
    def get_lr(self):
621
        if not self._get_lr_called_within_step:
622
            warnings.warn("To get the last learning rate computed by the scheduler, "
623
                          "please use `get_last_lr()`.", UserWarning)
624

625
        if self.last_epoch == 0:
626
            return [group['lr'] for group in self.optimizer.param_groups]
627
        return [group['lr'] * self.gamma
628
                for group in self.optimizer.param_groups]
629

630
    def _get_closed_form_lr(self):
631
        return [base_lr * self.gamma ** self.last_epoch
632
                for base_lr in self.base_lrs]
633

634

635
class SequentialLR(LRScheduler):
636
    """Receives the list of schedulers that is expected to be called sequentially during
637
    optimization process and milestone points that provides exact intervals to reflect
638
    which scheduler is supposed to be called at a given epoch.
639

640
    Args:
641
        optimizer (Optimizer): Wrapped optimizer.
642
        schedulers (list): List of chained schedulers.
643
        milestones (list): List of integers that reflects milestone points.
644
        last_epoch (int): The index of last epoch. Default: -1.
645
        verbose (bool): Does nothing.
646

647
            .. deprecated:: 2.2
648
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
649
                learning rate.
650

651
    Example:
652
        >>> # xdoctest: +SKIP
653
        >>> # Assuming optimizer uses lr = 1. for all groups
654
        >>> # lr = 0.1     if epoch == 0
655
        >>> # lr = 0.1     if epoch == 1
656
        >>> # lr = 0.9     if epoch == 2
657
        >>> # lr = 0.81    if epoch == 3
658
        >>> # lr = 0.729   if epoch == 4
659
        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
660
        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
661
        >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
662
        >>> for epoch in range(100):
663
        >>>     train(...)
664
        >>>     validate(...)
665
        >>>     scheduler.step()
666
    """
667

668
    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose="deprecated"):
669
        for scheduler_idx in range(len(schedulers)):
670
            if schedulers[scheduler_idx].optimizer != optimizer:
671
                raise ValueError(
672
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
673
                    f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in."
674
                )
675

676
            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
677
                raise ValueError(
678
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
679
                    f"got schedulers at index {0} and {scheduler_idx} to be different."
680
                )
681
        if (len(milestones) != len(schedulers) - 1):
682
            raise ValueError(
683
                "Sequential Schedulers expects number of schedulers provided to be one more "
684
                f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the "
685
                f"number of milestones to be equal to {len(milestones)}"
686
            )
687
        _check_verbose_deprecated_warning(verbose)
688
        self._schedulers = schedulers
689
        self._milestones = milestones
690
        self.last_epoch = last_epoch + 1
691
        self.optimizer = optimizer
692

693
        # Reset learning rates back to initial values
694
        for group in self.optimizer.param_groups:
695
            group["lr"] = group["initial_lr"]
696

697
        # "Undo" the step performed by other schedulers
698
        for scheduler in self._schedulers:
699
            scheduler.last_epoch -= 1
700

701
        # Perform the initial step for only the first scheduler
702
        self._schedulers[0]._initial_step()
703

704
        self._last_lr = schedulers[0].get_last_lr()
705

706
    def step(self):
707
        self.last_epoch += 1
708
        idx = bisect_right(self._milestones, self.last_epoch)
709
        scheduler = self._schedulers[idx]
710
        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
711
            scheduler.step(0)
712
        else:
713
            scheduler.step()
714

715
        self._last_lr = scheduler.get_last_lr()
716

717
    def state_dict(self):
718
        """Returns the state of the scheduler as a :class:`dict`.
719

720
        It contains an entry for every variable in self.__dict__ which
721
        is not the optimizer.
722
        The wrapped scheduler states will also be saved.
723
        """
724
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
725
        state_dict['_schedulers'] = [None] * len(self._schedulers)
726

727
        for idx, s in enumerate(self._schedulers):
728
            state_dict['_schedulers'][idx] = s.state_dict()
729

730
        return state_dict
731

732
    def load_state_dict(self, state_dict):
733
        """Loads the schedulers state.
734

735
        Args:
736
            state_dict (dict): scheduler state. Should be an object returned
737
                from a call to :meth:`state_dict`.
738
        """
739
        _schedulers = state_dict.pop('_schedulers')
740
        self.__dict__.update(state_dict)
741
        # Restore state_dict keys in order to prevent side effects
742
        # https://github.com/pytorch/pytorch/issues/32756
743
        state_dict['_schedulers'] = _schedulers
744

745
        for idx, s in enumerate(_schedulers):
746
            self._schedulers[idx].load_state_dict(s)
747

748

749
class PolynomialLR(LRScheduler):
750
    """Decays the learning rate of each parameter group using a polynomial function
751
    in the given total_iters. When last_epoch=-1, sets initial lr as lr.
752

753
    Args:
754
        optimizer (Optimizer): Wrapped optimizer.
755
        total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
756
        power (float): The power of the polynomial. Default: 1.0.
757
        verbose (bool): If ``True``, prints a message to stdout for
758
            each update. Default: ``False``.
759

760
            .. deprecated:: 2.2
761
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
762
                learning rate.
763

764
    Example:
765
        >>> # xdoctest: +SKIP("undefined vars")
766
        >>> # Assuming optimizer uses lr = 0.001 for all groups
767
        >>> # lr = 0.001     if epoch == 0
768
        >>> # lr = 0.00075   if epoch == 1
769
        >>> # lr = 0.00050   if epoch == 2
770
        >>> # lr = 0.00025   if epoch == 3
771
        >>> # lr = 0.0       if epoch >= 4
772
        >>> scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0)
773
        >>> for epoch in range(100):
774
        >>>     train(...)
775
        >>>     validate(...)
776
        >>>     scheduler.step()
777
    """
778
    def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose="deprecated"):
779
        self.total_iters = total_iters
780
        self.power = power
781
        super().__init__(optimizer, last_epoch, verbose)
782

783
    def get_lr(self):
784
        if not self._get_lr_called_within_step:
785
            warnings.warn("To get the last learning rate computed by the scheduler, "
786
                          "please use `get_last_lr()`.", UserWarning)
787

788
        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
789
            return [group["lr"] for group in self.optimizer.param_groups]
790

791
        decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power
792
        return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
793

794
    def _get_closed_form_lr(self):
795
        return [
796
            (
797
                base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power
798
            )
799
            for base_lr in self.base_lrs
800
        ]
801

802

803
class CosineAnnealingLR(LRScheduler):
804
    r"""Set the learning rate of each parameter group using a cosine annealing
805
    schedule, where :math:`\eta_{max}` is set to the initial lr and
806
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
807

808
    .. math::
809
        \begin{aligned}
810
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
811
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
812
            & T_{cur} \neq (2k+1)T_{max}; \\
813
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
814
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
815
            & T_{cur} = (2k+1)T_{max}.
816
        \end{aligned}
817

818
    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
819
    is defined recursively, the learning rate can be simultaneously modified
820
    outside this scheduler by other operators. If the learning rate is set
821
    solely by this scheduler, the learning rate at each step becomes:
822

823
    .. math::
824
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
825
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
826

827
    It has been proposed in
828
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
829
    implements the cosine annealing part of SGDR, and not the restarts.
830

831
    Args:
832
        optimizer (Optimizer): Wrapped optimizer.
833
        T_max (int): Maximum number of iterations.
834
        eta_min (float): Minimum learning rate. Default: 0.
835
        last_epoch (int): The index of last epoch. Default: -1.
836
        verbose (bool): If ``True``, prints a message to stdout for
837
            each update. Default: ``False``.
838

839
            .. deprecated:: 2.2
840
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
841
                learning rate.
842

843
    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
844
        https://arxiv.org/abs/1608.03983
845
    """
846

847
    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose="deprecated"):
848
        self.T_max = T_max
849
        self.eta_min = eta_min
850
        super().__init__(optimizer, last_epoch, verbose)
851

852
    def get_lr(self):
853
        if not self._get_lr_called_within_step:
854
            warnings.warn("To get the last learning rate computed by the scheduler, "
855
                          "please use `get_last_lr()`.", UserWarning)
856

857
        if self.last_epoch == 0:
858
            return [group['lr'] for group in self.optimizer.param_groups]
859
        elif self._step_count == 1 and self.last_epoch > 0:
860
            return [self.eta_min + (base_lr - self.eta_min) *
861
                    (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2
862
                    for base_lr, group in
863
                    zip(self.base_lrs, self.optimizer.param_groups)]
864
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
865
            return [group['lr'] + (base_lr - self.eta_min) *
866
                    (1 - math.cos(math.pi / self.T_max)) / 2
867
                    for base_lr, group in
868
                    zip(self.base_lrs, self.optimizer.param_groups)]
869
        return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
870
                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
871
                (group['lr'] - self.eta_min) + self.eta_min
872
                for group in self.optimizer.param_groups]
873

874
    def _get_closed_form_lr(self):
875
        return [self.eta_min + (base_lr - self.eta_min) *
876
                (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
877
                for base_lr in self.base_lrs]
878

879

880
class ChainedScheduler(LRScheduler):
881
    """Chains list of learning rate schedulers. It takes a list of chainable learning
882
    rate schedulers and performs consecutive step() functions belonging to them by just
883
    one call.
884

885
    Args:
886
        schedulers (list): List of chained schedulers.
887

888
    Example:
889
        >>> # xdoctest: +SKIP
890
        >>> # Assuming optimizer uses lr = 1. for all groups
891
        >>> # lr = 0.09     if epoch == 0
892
        >>> # lr = 0.081    if epoch == 1
893
        >>> # lr = 0.729    if epoch == 2
894
        >>> # lr = 0.6561   if epoch == 3
895
        >>> # lr = 0.59049  if epoch >= 4
896
        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
897
        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
898
        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
899
        >>> for epoch in range(100):
900
        >>>     train(...)
901
        >>>     validate(...)
902
        >>>     scheduler.step()
903
    """
904

905
    def __init__(self, schedulers):
906
        for scheduler_idx in range(1, len(schedulers)):
907
            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
908
                raise ValueError(
909
                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
910
                    f"got schedulers at index {0} and {scheduler_idx} to be different"
911
                )
912
        self._schedulers = list(schedulers)
913
        self.optimizer = schedulers[0].optimizer
914
        self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]
915

916
    def step(self):
917
        for scheduler in self._schedulers:
918
            scheduler.step()
919
        self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]
920

921
    def state_dict(self):
922
        """Returns the state of the scheduler as a :class:`dict`.
923

924
        It contains an entry for every variable in self.__dict__ which
925
        is not the optimizer.
926
        The wrapped scheduler states will also be saved.
927
        """
928
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
929
        state_dict['_schedulers'] = [None] * len(self._schedulers)
930

931
        for idx, s in enumerate(self._schedulers):
932
            state_dict['_schedulers'][idx] = s.state_dict()
933

934
        return state_dict
935

936
    def load_state_dict(self, state_dict):
937
        """Loads the schedulers state.
938

939
        Args:
940
            state_dict (dict): scheduler state. Should be an object returned
941
                from a call to :meth:`state_dict`.
942
        """
943
        _schedulers = state_dict.pop('_schedulers')
944
        self.__dict__.update(state_dict)
945
        # Restore state_dict keys in order to prevent side effects
946
        # https://github.com/pytorch/pytorch/issues/32756
947
        state_dict['_schedulers'] = _schedulers
948

949
        for idx, s in enumerate(_schedulers):
950
            self._schedulers[idx].load_state_dict(s)
951

952

953
class ReduceLROnPlateau(LRScheduler):
954
    """Reduce learning rate when a metric has stopped improving.
955
    Models often benefit from reducing the learning rate by a factor
956
    of 2-10 once learning stagnates. This scheduler reads a metrics
957
    quantity and if no improvement is seen for a 'patience' number
958
    of epochs, the learning rate is reduced.
959

960
    Args:
961
        optimizer (Optimizer): Wrapped optimizer.
962
        mode (str): One of `min`, `max`. In `min` mode, lr will
963
            be reduced when the quantity monitored has stopped
964
            decreasing; in `max` mode it will be reduced when the
965
            quantity monitored has stopped increasing. Default: 'min'.
966
        factor (float): Factor by which the learning rate will be
967
            reduced. new_lr = lr * factor. Default: 0.1.
968
        patience (int): The number of allowed epochs with no improvement after
969
            which the learning rate will be reduced.
970
            For example, consider the case of having no patience (`patience = 0`).
971
            In the first epoch, a baseline is established and is always considered good as there's no previous baseline.
972
            In the second epoch, if the performance is worse than the baseline,
973
            we have what is considered an intolerable epoch.
974
            Since the count of intolerable epochs (1) is greater than the patience level (0),
975
            the learning rate is reduced at the end of this epoch.
976
            From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch
977
            if the performance is worse than the baseline. If the performance improves or remains the same,
978
            the learning rate is not adjusted.
979
            Default: 10.
980
        threshold (float): Threshold for measuring the new optimum,
981
            to only focus on significant changes. Default: 1e-4.
982
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
983
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
984
            mode or best * ( 1 - threshold ) in `min` mode.
985
            In `abs` mode, dynamic_threshold = best + threshold in
986
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
987
        cooldown (int): Number of epochs to wait before resuming
988
            normal operation after lr has been reduced. Default: 0.
989
        min_lr (float or list): A scalar or a list of scalars. A
990
            lower bound on the learning rate of all param groups
991
            or each group respectively. Default: 0.
992
        eps (float): Minimal decay applied to lr. If the difference
993
            between new and old lr is smaller than eps, the update is
994
            ignored. Default: 1e-8.
995
        verbose (bool): If ``True``, prints a message to stdout for
996
            each update. Default: ``False``.
997

998
            .. deprecated:: 2.2
999
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
1000
                learning rate.
1001

1002
    Example:
1003
        >>> # xdoctest: +SKIP
1004
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
1005
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
1006
        >>> for epoch in range(10):
1007
        >>>     train(...)
1008
        >>>     val_loss = validate(...)
1009
        >>>     # Note that step should be called after validate()
1010
        >>>     scheduler.step(val_loss)
1011
    """
1012

1013
    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
1014
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
1015
                 min_lr=0, eps=1e-8, verbose="deprecated"):
1016

1017
        if factor >= 1.0:
1018
            raise ValueError('Factor should be < 1.0.')
1019
        self.factor = factor
1020

1021
        # Attach optimizer
1022
        if not isinstance(optimizer, Optimizer):
1023
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
1024
        self.optimizer = optimizer
1025

1026
        if isinstance(min_lr, (list, tuple)):
1027
            if len(min_lr) != len(optimizer.param_groups):
1028
                raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
1029
            self.min_lrs = list(min_lr)
1030
        else:
1031
            self.min_lrs = [min_lr] * len(optimizer.param_groups)
1032

1033
        self.patience = patience
1034

1035
        self.verbose = _check_verbose_deprecated_warning(verbose)
1036
        self.cooldown = cooldown
1037
        self.cooldown_counter = 0
1038
        self.mode = mode
1039
        self.threshold = threshold
1040
        self.threshold_mode = threshold_mode
1041
        self.best = None
1042
        self.num_bad_epochs = None
1043
        self.mode_worse = None  # the worse value for the chosen mode
1044
        self.eps = eps
1045
        self.last_epoch = 0
1046
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
1047
        self._init_is_better(mode=mode, threshold=threshold,
1048
                             threshold_mode=threshold_mode)
1049
        self._reset()
1050

1051
    def _reset(self):
1052
        """Resets num_bad_epochs counter and cooldown counter."""
1053
        self.best = self.mode_worse
1054
        self.cooldown_counter = 0
1055
        self.num_bad_epochs = 0
1056

1057
    def step(self, metrics, epoch=None):
1058
        # convert `metrics` to float, in case it's a zero-dim Tensor
1059
        current = float(metrics)
1060
        if epoch is None:
1061
            epoch = self.last_epoch + 1
1062
        else:
1063
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
1064
        self.last_epoch = epoch
1065

1066
        if self.is_better(current, self.best):
1067
            self.best = current
1068
            self.num_bad_epochs = 0
1069
        else:
1070
            self.num_bad_epochs += 1
1071

1072
        if self.in_cooldown:
1073
            self.cooldown_counter -= 1
1074
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown
1075

1076
        if self.num_bad_epochs > self.patience:
1077
            self._reduce_lr(epoch)
1078
            self.cooldown_counter = self.cooldown
1079
            self.num_bad_epochs = 0
1080

1081
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
1082

1083
    def _reduce_lr(self, epoch):
1084
        for i, param_group in enumerate(self.optimizer.param_groups):
1085
            old_lr = float(param_group['lr'])
1086
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
1087
            if old_lr - new_lr > self.eps:
1088
                param_group['lr'] = new_lr
1089

1090
    @property
1091
    def in_cooldown(self):
1092
        return self.cooldown_counter > 0
1093

1094
    def is_better(self, a, best):
1095
        if self.mode == 'min' and self.threshold_mode == 'rel':
1096
            rel_epsilon = 1. - self.threshold
1097
            return a < best * rel_epsilon
1098

1099
        elif self.mode == 'min' and self.threshold_mode == 'abs':
1100
            return a < best - self.threshold
1101

1102
        elif self.mode == 'max' and self.threshold_mode == 'rel':
1103
            rel_epsilon = self.threshold + 1.
1104
            return a > best * rel_epsilon
1105

1106
        else:  # mode == 'max' and epsilon_mode == 'abs':
1107
            return a > best + self.threshold
1108

1109
    def _init_is_better(self, mode, threshold, threshold_mode):
1110
        if mode not in {'min', 'max'}:
1111
            raise ValueError('mode ' + mode + ' is unknown!')
1112
        if threshold_mode not in {'rel', 'abs'}:
1113
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
1114

1115
        if mode == 'min':
1116
            self.mode_worse = inf
1117
        else:  # mode == 'max':
1118
            self.mode_worse = -inf
1119

1120
        self.mode = mode
1121
        self.threshold = threshold
1122
        self.threshold_mode = threshold_mode
1123

1124
    def state_dict(self):
1125
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
1126

1127
    def load_state_dict(self, state_dict):
1128
        self.__dict__.update(state_dict)
1129
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
1130

1131

1132
class CyclicLR(LRScheduler):
1133
    r"""Sets the learning rate of each parameter group according to
1134
    cyclical learning rate policy (CLR). The policy cycles the learning
1135
    rate between two boundaries with a constant frequency, as detailed in
1136
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
1137
    The distance between the two boundaries can be scaled on a per-iteration
1138
    or per-cycle basis.
1139

1140
    Cyclical learning rate policy changes the learning rate after every batch.
1141
    `step` should be called after a batch has been used for training.
1142

1143
    This class has three built-in policies, as put forth in the paper:
1144

1145
    * "triangular": A basic triangular cycle without amplitude scaling.
1146
    * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
1147
    * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
1148
      at each cycle iteration.
1149

1150
    This implementation was adapted from the github repo: `bckenstler/CLR`_
1151

1152
    Args:
1153
        optimizer (Optimizer): Wrapped optimizer.
1154
        base_lr (float or list): Initial learning rate which is the
1155
            lower boundary in the cycle for each parameter group.
1156
        max_lr (float or list): Upper learning rate boundaries in the cycle
1157
            for each parameter group. Functionally,
1158
            it defines the cycle amplitude (max_lr - base_lr).
1159
            The lr at any cycle is the sum of base_lr
1160
            and some scaling of the amplitude; therefore
1161
            max_lr may not actually be reached depending on
1162
            scaling function.
1163
        step_size_up (int): Number of training iterations in the
1164
            increasing half of a cycle. Default: 2000
1165
        step_size_down (int): Number of training iterations in the
1166
            decreasing half of a cycle. If step_size_down is None,
1167
            it is set to step_size_up. Default: None
1168
        mode (str): One of {triangular, triangular2, exp_range}.
1169
            Values correspond to policies detailed above.
1170
            If scale_fn is not None, this argument is ignored.
1171
            Default: 'triangular'
1172
        gamma (float): Constant in 'exp_range' scaling function:
1173
            gamma**(cycle iterations)
1174
            Default: 1.0
1175
        scale_fn (function): Custom scaling policy defined by a single
1176
            argument lambda function, where
1177
            0 <= scale_fn(x) <= 1 for all x >= 0.
1178
            If specified, then 'mode' is ignored.
1179
            Default: None
1180
        scale_mode (str): {'cycle', 'iterations'}.
1181
            Defines whether scale_fn is evaluated on
1182
            cycle number or cycle iterations (training
1183
            iterations since start of cycle).
1184
            Default: 'cycle'
1185
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
1186
            to learning rate between 'base_momentum' and 'max_momentum'.
1187
            Default: True
1188
        base_momentum (float or list): Lower momentum boundaries in the cycle
1189
            for each parameter group. Note that momentum is cycled inversely
1190
            to learning rate; at the peak of a cycle, momentum is
1191
            'base_momentum' and learning rate is 'max_lr'.
1192
            Default: 0.8
1193
        max_momentum (float or list): Upper momentum boundaries in the cycle
1194
            for each parameter group. Functionally,
1195
            it defines the cycle amplitude (max_momentum - base_momentum).
1196
            The momentum at any cycle is the difference of max_momentum
1197
            and some scaling of the amplitude; therefore
1198
            base_momentum may not actually be reached depending on
1199
            scaling function. Note that momentum is cycled inversely
1200
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
1201
            and learning rate is 'base_lr'
1202
            Default: 0.9
1203
        last_epoch (int): The index of the last batch. This parameter is used when
1204
            resuming a training job. Since `step()` should be invoked after each
1205
            batch instead of after each epoch, this number represents the total
1206
            number of *batches* computed, not the total number of epochs computed.
1207
            When last_epoch=-1, the schedule is started from the beginning.
1208
            Default: -1
1209
        verbose (bool): If ``True``, prints a message to stdout for
1210
            each update. Default: ``False``.
1211

1212
            .. deprecated:: 2.2
1213
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
1214
                learning rate.
1215

1216
    Example:
1217
        >>> # xdoctest: +SKIP
1218
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
1219
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
1220
        >>> data_loader = torch.utils.data.DataLoader(...)
1221
        >>> for epoch in range(10):
1222
        >>>     for batch in data_loader:
1223
        >>>         train_batch(...)
1224
        >>>         scheduler.step()
1225

1226

1227
    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
1228
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
1229
    """
1230

1231
    def __init__(self,
1232
                 optimizer,
1233
                 base_lr,
1234
                 max_lr,
1235
                 step_size_up=2000,
1236
                 step_size_down=None,
1237
                 mode='triangular',
1238
                 gamma=1.,
1239
                 scale_fn=None,
1240
                 scale_mode='cycle',
1241
                 cycle_momentum=True,
1242
                 base_momentum=0.8,
1243
                 max_momentum=0.9,
1244
                 last_epoch=-1,
1245
                 verbose="deprecated"):
1246

1247
        # Attach optimizer
1248
        if not isinstance(optimizer, Optimizer):
1249
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
1250
        self.optimizer = optimizer
1251

1252
        base_lrs = self._format_param('base_lr', optimizer, base_lr)
1253
        if last_epoch == -1:
1254
            for lr, group in zip(base_lrs, optimizer.param_groups):
1255
                group['lr'] = lr
1256

1257
        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)
1258

1259
        step_size_up = float(step_size_up)
1260
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
1261
        self.total_size = step_size_up + step_size_down
1262
        self.step_ratio = step_size_up / self.total_size
1263

1264
        if mode not in ['triangular', 'triangular2', 'exp_range'] \
1265
                and scale_fn is None:
1266
            raise ValueError('mode is invalid and scale_fn is None')
1267

1268
        self.mode = mode
1269
        self.gamma = gamma
1270

1271
        self._scale_fn_ref = None
1272
        self._scale_fn_custom = scale_fn
1273
        self.scale_mode = scale_mode
1274
        self._init_scale_fn()
1275

1276
        self.cycle_momentum = cycle_momentum
1277
        if cycle_momentum:
1278
            if 'momentum' not in optimizer.defaults and 'betas' not in optimizer.defaults:
1279
                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
1280

1281
            self.use_beta1 = 'betas' in self.optimizer.defaults
1282
            self.base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
1283
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
1284
            if last_epoch == -1:
1285
                for m_momentum, b_momentum, group in zip(self.max_momentums, self.base_momentums, optimizer.param_groups):
1286
                    if self.use_beta1:
1287
                        group['betas'] = (m_momentum, *group['betas'][1:])
1288
                    else:
1289
                        group['momentum'] = m_momentum
1290
                    group['max_momentum'] = m_momentum
1291
                    group['base_momentum'] = b_momentum
1292

1293
        super().__init__(optimizer, last_epoch, verbose)
1294
        self.base_lrs = base_lrs
1295

1296
    def _init_scale_fn(self):
1297
        if self._scale_fn_custom is not None:
1298
            return
1299
        if self.mode == 'triangular':
1300
            self._scale_fn_ref = self._triangular_scale_fn
1301
            self.scale_mode = 'cycle'
1302
        elif self.mode == 'triangular2':
1303
            self._scale_fn_ref = self._triangular2_scale_fn
1304
            self.scale_mode = 'cycle'
1305
        elif self.mode == 'exp_range':
1306
            self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma)
1307
            self.scale_mode = 'iterations'
1308

1309
    def _format_param(self, name, optimizer, param):
1310
        """Return correctly formatted lr/momentum for each param group."""
1311
        if isinstance(param, (list, tuple)):
1312
            if len(param) != len(optimizer.param_groups):
1313
                raise ValueError(f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}")
1314
            return param
1315
        else:
1316
            return [param] * len(optimizer.param_groups)
1317

1318
    def scale_fn(self, x):
1319
        if self._scale_fn_custom is not None:
1320
            return self._scale_fn_custom(x)
1321
        else:
1322
            return self._scale_fn_ref(x)  # static method
1323

1324
    @staticmethod
1325
    def _triangular_scale_fn(x):
1326
        return 1.
1327

1328
    @staticmethod
1329
    def _triangular2_scale_fn(x):
1330
        return 1 / (2. ** (x - 1))
1331

1332
    @staticmethod
1333
    def _exp_range_scale_fn(gamma, x):
1334
        return gamma ** x
1335
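
    # Illustration (not part of the scheduler API): how the built-in scale functions
    # shrink the cycle amplitude, given the argument each mode passes to them
    # (cycle number for 'triangular'/'triangular2', iteration count for 'exp_range'):
    #   _triangular_scale_fn(c)        -> 1.0 for every cycle (constant amplitude)
    #   _triangular2_scale_fn(c)       -> 1.0, 0.5, 0.25, ... for c = 1, 2, 3, ...
    #   _exp_range_scale_fn(0.999, i)  -> 0.999**i, roughly 0.37 after 1000 iterations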

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """

        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        cycle = math.floor(1 + self.last_epoch / self.total_size)
        x = 1. + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == 'cycle':
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == 'cycle':
                    momentum = max_momentum - base_height * self.scale_fn(cycle)
                else:
                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                if self.use_beta1:
                    param_group['betas'] = (momentum, *param_group['betas'][1:])
                else:
                    param_group['momentum'] = momentum

        return lrs

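    # Worked example (illustrative only) of the computation in ``get_lr`` above,
    # assuming step_size_up=2000, step_size_down=2000 (total_size=4000,
    # step_ratio=0.5), base_lr=0.001, max_lr=0.006 and mode='triangular':
    #   last_epoch=1000 -> cycle=1, x=0.25, scale_factor=0.5, lr=0.001 + 0.005*0.5 = 0.0035
    #   last_epoch=2000 -> cycle=1, x=0.50, scale_factor=1.0, lr=0.006 (the peak)
    #   last_epoch=3000 -> cycle=1, x=0.75, scale_factor=0.5, lr=0.0035 (descending)
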
    def state_dict(self):
        state = super().state_dict()
        # `_scale_fn_ref` is dropped from the saved state; it is rebuilt by
        # `_init_scale_fn()` when the state is loaded.
        state.pop('_scale_fn_ref')
        fn = state.pop('_scale_fn_custom')
        state['_scale_fn_custom'] = None
        if fn is not None and not isinstance(fn, types.FunctionType):
            # The _scale_fn_custom will only be saved if it is a callable object
            # and not if it is a function or lambda.
            state['_scale_fn_custom'] = fn.__dict__.copy()

        return state

    def load_state_dict(self, state_dict):
        fn = state_dict.pop('_scale_fn_custom')
        super().load_state_dict(state_dict)
        if fn is not None:
            self._scale_fn_custom.__dict__.update(fn)
        self._init_scale_fn()


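# A minimal usage sketch (illustrative only, not part of this module): CyclicLR with
# a custom scaling policy instead of one of the built-in modes. `model`, `data_loader`
# and `train_batch` are assumed to be defined in user code.
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
#     scheduler = torch.optim.lr_scheduler.CyclicLR(
#         optimizer, base_lr=0.001, max_lr=0.01,
#         scale_fn=lambda x: 0.95 ** x,  # amplitude shrinks ~5% per cycle
#         scale_mode='cycle')
#     for batch in data_loader:
#         train_batch(batch)
#         scheduler.step()

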
class CosineAnnealingWarmRestarts(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose="deprecated"):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError(f"Expected positive integer T_0, but got {T_0}")
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
        if not isinstance(eta_min, (float, int)):
            raise ValueError(f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}")
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = last_epoch
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

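    # Worked example (illustrative only) of the formula in ``get_lr`` above, with
    # base_lr=0.1, eta_min=0 and T_i=10:
    #   T_cur=0  -> lr = 0.1   (right after a restart)
    #   T_cur=5  -> lr = 0.05  (cos(pi/2) = 0, halfway through the period)
    #   T_cur=10 -> lr = 0.0   (cos(pi) = -1, end of the period)
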
    def step(self, epoch=None):
        """Step could be called after every batch update.

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # same as scheduler.step(27), instead of scheduler.step(20)
        """

        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError(f"Expected non-negative epoch, but got {epoch}")
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** n
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
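        # Worked example (illustrative only): with T_0=10 and T_mult=2, restarts
        # occur after 10, 30 (=10+20) and 70 (=10+20+40) epochs. Calling step(35.0)
        # therefore lands in the third period: n = int(log2(35/10 * 1 + 1)) = 2,
        # so T_i = 10 * 2**2 = 40 and T_cur = 35 - 10 * (2**2 - 1) = 5.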
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                # Do not return a truthy value here: that would silently suppress
                # any exception raised inside the ``with`` block.
                self.o._get_lr_called_within_step = False

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group['lr'] = lr

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


class OneCycleLR(LRScheduler):
    r"""Sets the learning rate of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
    claims that "unpublished work has shown even better results by using only two phases". To
    mimic the behaviour of the original paper instead, set ``three_phase=True``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by providing
            a value for epochs and steps_per_epoch.
            Default: None
        epochs (int): The number of epochs to train for. This is used along
            with steps_per_epoch in order to infer the total number of steps in the cycle
            if a value for total_steps is not provided.
            Default: None
        steps_per_epoch (int): The number of steps per epoch to train for. This is
            used along with epochs in order to infer the total number of steps in the
            cycle if a value for total_steps is not provided.
            Default: None
        pct_start (float): The percentage of the cycle (in number of steps) spent
            increasing the learning rate.
            Default: 0.3
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
            linear annealing.
            Default: 'cos'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.85
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.95
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor
            Default: 25
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor
            Default: 1e4
        three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
            learning rate according to 'final_div_factor' instead of modifying the second
            phase (the first two phases will be symmetrical about the step indicated by
            'pct_start').
            Default: False
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         optimizer.step()
        >>>         scheduler.step()


    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """
    def __init__(self,
                 optimizer,
                 max_lr,
                 total_steps=None,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 cycle_momentum=True,
                 base_momentum=0.85,
                 max_momentum=0.95,
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose="deprecated"):

        # Validate optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        # Validate total_steps
        if total_steps is None and epochs is None and steps_per_epoch is None:
            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
        elif total_steps is not None:
            if total_steps <= 0 or not isinstance(total_steps, int):
                raise ValueError(f"Expected positive integer total_steps, but got {total_steps}")
            self.total_steps = total_steps
        else:
            if epochs <= 0 or not isinstance(epochs, int):
                raise ValueError(f"Expected positive integer epochs, but got {epochs}")
            if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
                raise ValueError(f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}")
            self.total_steps = epochs * steps_per_epoch

        if three_phase:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': 'max_lr',
                    'end_lr': 'initial_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]
        else:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'max_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]

        # Validate pct_start
        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
            raise ValueError(f"Expected float between 0 and 1 for pct_start, but got {pct_start}")

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError(f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}")
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        # Initialize learning rate variables
        max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(self.optimizer.param_groups):
                group['initial_lr'] = max_lrs[idx] / div_factor
                group['max_lr'] = max_lrs[idx]
                group['min_lr'] = group['initial_lr'] / final_div_factor
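        # Illustration (not part of the initialization logic): with max_lr=0.01 and
        # the default div_factor=25 and final_div_factor=1e4, each group starts at
        # initial_lr = 0.01 / 25 = 4e-4 and ends the cycle at min_lr = 4e-4 / 1e4 = 4e-8.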

        # Initialize momentum variables
        self.cycle_momentum = cycle_momentum
        if self.cycle_momentum:
            if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
            self.use_beta1 = 'betas' in self.optimizer.defaults
            max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
                    if self.use_beta1:
                        group['betas'] = (m_momentum, *group['betas'][1:])
                    else:
                        group['momentum'] = m_momentum
                    group['max_momentum'] = m_momentum
                    group['base_momentum'] = b_momentum

        super().__init__(optimizer, last_epoch, verbose)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError(f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}")
            return param
        else:
            return [param] * len(optimizer.param_groups)

    @staticmethod
    def _annealing_cos(start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    @staticmethod
    def _annealing_linear(start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start
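
    # Quick sanity check (illustrative only) for both annealing helpers with
    # start=0.1 and end=0.001:
    #   pct=0.0 -> 0.1    (both return the start value)
    #   pct=0.5 -> 0.0505 (both return the midpoint, since cos(pi/2) + 1 = 1)
    #   pct=1.0 -> 0.001  (both return the end value)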

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        lrs = []
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError(f"Tried to step {step_num} times. The specified number of "
                             f"total steps is {self.total_steps}")

        for group in self.optimizer.param_groups:
            start_step = 0
            for i, phase in enumerate(self._schedule_phases):
                end_step = phase['end_step']
                if step_num <= end_step or i == len(self._schedule_phases) - 1:
                    pct = (step_num - start_step) / (end_step - start_step)
                    computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
                    if self.cycle_momentum:
                        computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
                    break
                start_step = phase['end_step']

            lrs.append(computed_lr)
            if self.cycle_momentum:
                if self.use_beta1:
                    group['betas'] = (computed_momentum, *group['betas'][1:])
                else:
                    group['momentum'] = computed_momentum

        return lrs
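
    # Worked example (illustrative only): with total_steps=100, pct_start=0.3 and
    # three_phase=False, the two phases end at steps 29 and 99. At step_num=50 the
    # second phase applies, so pct = (50 - 29) / (99 - 29) = 0.3 and the learning
    # rate is annealed 30% of the way from max_lr down to min_lr.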
