import math
import types
import warnings
import weakref
from functools import wraps, partial
from collections import Counter
from bisect import bisect_right

from torch import inf

from .optimizer import Optimizer

__all__ = ['LambdaLR', 'MultiplicativeLR', 'StepLR', 'MultiStepLR', 'ConstantLR', 'LinearLR',
           'ExponentialLR', 'SequentialLR', 'CosineAnnealingLR', 'ChainedScheduler', 'ReduceLROnPlateau',
           'CyclicLR', 'CosineAnnealingWarmRestarts', 'OneCycleLR', 'PolynomialLR', 'LRScheduler']

EPOCH_DEPRECATION_WARNING = (
    "The epoch parameter in `scheduler.step()` was not necessary and is being "
    "deprecated where possible. Please use `scheduler.step()` to step the "
    "scheduler. During the deprecation, if epoch is different from None, the "
    "closed form is used instead of the new chainable form, where available. "
    "Please open an issue if you are unable to replicate your use case: "
    "https://github.com/pytorch/pytorch/issues/new/choose."
)
def _check_verbose_deprecated_warning(verbose):
    """Raises a warning when verbose is not the default value."""
    if verbose != "deprecated":
        warnings.warn("The verbose parameter is deprecated. Please use get_last_lr() "
                      "to access the learning rate.", UserWarning)
        return verbose
    return False
class LRScheduler:

    def __init__(self, optimizer, last_epoch=-1, verbose="deprecated"):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   f"in param_groups[{i}] when resuming an optimizer")
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`.
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.verbose = _check_verbose_deprecated_warning(verbose)

        self._initial_step()
    def _initial_step(self):
        """Initialize step counts and perform a step."""
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """Return last computed learning rate by current scheduler."""
        return self._last_lr

    def get_lr(self):
        # Compute learning rate using chainable form of the scheduler
        raise NotImplementedError
    def print_lr(self, is_verbose, group, lr, epoch=None):
        """Display the current learning rate."""
        if is_verbose:
            if epoch is None:
                print(f'Adjusting learning rate of group {group} to {lr:.4e}.')
            else:
                epoch_str = ("%.2f" if isinstance(epoch, float) else
                             "%d") % epoch
                print(f'Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}.')
    def step(self, epoch=None):
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = self._get_closed_form_lr()
                else:
                    values = self.get_lr()

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
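    # Illustrative usage sketch for the step() contract above (assumes `model`,
    # `data_loader`, `optimizer` and `scheduler` already exist; `compute_loss`
    # is a placeholder helper, not part of this module):
    #
    #     for epoch in range(num_epochs):
    #         for batch in data_loader:
    #             optimizer.zero_grad()
    #             loss = compute_loss(model, batch)
    #             loss.backward()
    #             optimizer.step()
    #         scheduler.step()  # call after optimizer.step(), once per epoch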
# Including _LRScheduler for backwards compatibility
# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
class _LRScheduler(LRScheduler):
    pass


class _enable_get_lr_call:

    def __init__(self, o):
        self.o = o

    def __enter__(self):
        self.o._get_lr_called_within_step = True
        return self

    def __exit__(self, type, value, traceback):
        self.o._get_lr_called_within_step = False
class LambdaLR(LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose="deprecated"):
        self.optimizer = optimizer

        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError(f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}")
            self.lr_lambdas = list(lr_lambda)
        super().__init__(optimizer, last_epoch, verbose)
    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
        """

        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)

        for idx, fn in enumerate(self.lr_lambdas):
            if not isinstance(fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()

        return state_dict
    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """

        lr_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['lr_lambdas'] = lr_lambdas

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
class MultiplicativeLR(LRScheduler):
    """Multiply the learning rate of each parameter group by the factor given
    in the specified function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> lmbda = lambda epoch: 0.95
        >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose="deprecated"):
        self.optimizer = optimizer

        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError(f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}")
            self.lr_lambdas = list(lr_lambda)
        super().__init__(optimizer, last_epoch, verbose)
    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)

        for idx, fn in enumerate(self.lr_lambdas):
            if not isinstance(fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()

        return state_dict
    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        lr_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['lr_lambdas'] = lr_lambdas

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch > 0:
            return [group['lr'] * lmbda(self.last_epoch)
                    for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)]
        else:
            return [group['lr'] for group in self.optimizer.param_groups]
class StepLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma every
    step_size epochs. Notice that such decay can happen simultaneously with
    other changes to the learning rate from outside this scheduler. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose="deprecated"):
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]
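# Worked example for the closed form above: with base_lr = 0.05, step_size = 30
# and gamma = 0.1, the lr is 0.05 * 0.1 ** (epoch // 30), i.e. 0.05 for epochs
# 0-29, 0.005 for 30-59 and 0.0005 for 60-89, matching the docstring example.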
class MultiStepLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma once the
    number of epoch reaches one of the milestones. Notice that such decay can
    happen simultaneously with other changes to the learning rate from outside
    this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose="deprecated"):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        milestones = sorted(self.milestones.elements())
        return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
                for base_lr in self.base_lrs]
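# Note on the closed form above: `bisect_right(milestones, last_epoch)` counts
# how many milestones have already been passed, so with milestones=[30, 80] and
# gamma=0.1 the exponent is 0 before epoch 30, 1 from 30 to 79, and 2 from 80 on.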
class ConstantLR(LRScheduler):
    """Decays the learning rate of each parameter group by a small constant factor until the
    number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can
    happen simultaneously with other changes to the learning rate from outside this scheduler.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
        total_iters (int): The number of steps that the scheduler decays the learning rate.
            Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025   if epoch == 0
        >>> # lr = 0.025   if epoch == 1
        >>> # lr = 0.025   if epoch == 2
        >>> # lr = 0.025   if epoch == 3
        >>> # lr = 0.05    if epoch >= 4
        >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose="deprecated"):
        if factor > 1.0 or factor < 0:
            raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')

        self.factor = factor
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] * self.factor for group in self.optimizer.param_groups]

        if self.last_epoch != self.total_iters:
            return [group['lr'] for group in self.optimizer.param_groups]

        return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
                for base_lr in self.base_lrs]
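# Note on the closed form above: the boolean `(self.last_epoch >= self.total_iters)`
# switches the multiplier from `factor` back to 1, so the lr is base_lr * factor
# while last_epoch < total_iters and base_lr afterwards.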
class LinearLR(LRScheduler):
    """Decays the learning rate of each parameter group by linearly changing small
    multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
    Notice that such decay can happen simultaneously with other changes to the learning rate
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply learning rate in the first epoch.
            The multiplication factor changes towards end_factor in the following epochs.
            Default: 1./3.
        end_factor (float): The number we multiply learning rate at the end of linear changing
            process. Default: 1.0.
        total_iters (int): The number of iterations that multiplicative factor reaches to 1.
            Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025    if epoch == 0
        >>> # lr = 0.03125  if epoch == 1
        >>> # lr = 0.0375   if epoch == 2
        >>> # lr = 0.04375  if epoch == 3
        >>> # lr = 0.05     if epoch >= 4
        >>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
                 verbose="deprecated"):
        if start_factor > 1.0 or start_factor <= 0:
            raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.')

        if end_factor > 1.0 or end_factor < 0:
            raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')

        self.start_factor = start_factor
        self.end_factor = end_factor
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]

        if self.last_epoch > self.total_iters:
            return [group['lr'] for group in self.optimizer.param_groups]

        return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
                (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * (self.start_factor +
                (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
                for base_lr in self.base_lrs]
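# Worked example for the closed form above: with base_lr = 0.05, start_factor = 0.5,
# end_factor = 1.0 and total_iters = 4, the factor grows linearly from 0.5 to 1.0,
# giving 0.025, 0.03125, 0.0375, 0.04375 and finally 0.05, matching the docstring example.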
class ExponentialLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1, verbose="deprecated"):
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * self.gamma ** self.last_epoch
                for base_lr in self.base_lrs]
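# Minimal usage sketch for ExponentialLR (assumes `optimizer` already exists;
# `train` and `validate` are placeholder helpers):
#
#     scheduler = ExponentialLR(optimizer, gamma=0.9)
#     for epoch in range(100):
#         train(...)
#         validate(...)
#         scheduler.step()  # lr is multiplied by 0.9 after every epoch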
class SequentialLR(LRScheduler):
    """Receives the list of schedulers that is expected to be called sequentially during
    optimization process and milestone points that provides exact intervals to reflect
    which scheduler is supposed to be called at a given epoch.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        schedulers (list): List of chained schedulers.
        milestones (list): List of integers that reflects milestone points.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): Does nothing.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1     if epoch == 0
        >>> # lr = 0.1     if epoch == 1
        >>> # lr = 0.9     if epoch == 2
        >>> # lr = 0.81    if epoch == 3
        >>> # lr = 0.729   if epoch == 4
        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
        >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose="deprecated"):
        for scheduler_idx in range(len(schedulers)):
            if schedulers[scheduler_idx].optimizer != optimizer:
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in."
                )

            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    f"got schedulers at index {0} and {scheduler_idx} to be different."
                )
        if (len(milestones) != len(schedulers) - 1):
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the "
                f"number of milestones to be equal to {len(milestones)}"
            )
        _check_verbose_deprecated_warning(verbose)
        self._schedulers = schedulers
        self._milestones = milestones
        self.last_epoch = last_epoch + 1
        self.optimizer = optimizer

        # Reset learning rates back to initial values
        for group in self.optimizer.param_groups:
            group["lr"] = group["initial_lr"]

        # "Undo" the step performed by other schedulers
        for scheduler in self._schedulers:
            scheduler.last_epoch -= 1

        # Perform the initial step for only the first scheduler
        self._schedulers[0]._initial_step()

        self._last_lr = schedulers[0].get_last_lr()

    def step(self):
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        scheduler = self._schedulers[idx]
        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            scheduler.step(0)
        else:
            scheduler.step()

        self._last_lr = scheduler.get_last_lr()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [None] * len(self._schedulers)

        for idx, s in enumerate(self._schedulers):
            state_dict['_schedulers'][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)
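# Note on the constructor above: each wrapped scheduler already performed its
# implicit initial step when it was constructed, so SequentialLR rewinds
# `last_epoch` on all of them and re-runs `_initial_step()` only for the first
# scheduler; the learning rate at epoch 0 therefore comes from schedulers[0].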
class PolynomialLR(LRScheduler):
    """Decays the learning rate of each parameter group using a polynomial function
    in the given total_iters. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
        power (float): The power of the polynomial. Default: 1.0.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Assuming optimizer uses lr = 0.001 for all groups
        >>> # lr = 0.001     if epoch == 0
        >>> # lr = 0.00075   if epoch == 1
        >>> # lr = 0.00050   if epoch == 2
        >>> # lr = 0.00025   if epoch == 3
        >>> # lr = 0.0       if epoch >= 4
        >>> scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose="deprecated"):
        self.total_iters = total_iters
        self.power = power
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
            return [group["lr"] for group in self.optimizer.param_groups]

        decay_factor = ((1.0 - self.last_epoch / self.total_iters) /
                        (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power
        return [group["lr"] * decay_factor for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [
            base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power
            for base_lr in self.base_lrs
        ]
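# Numeric check of the closed form above: with base_lr = 0.001, total_iters = 4
# and power = 1.0, the factor (1 - epoch / 4) yields 0.001, 0.00075, 0.0005,
# 0.00025 and 0.0 for epochs 0 through 4, matching the docstring example.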
class CosineAnnealingLR(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
    is defined recursively, the learning rate can be simultaneously modified
    outside this scheduler by other operators. If the learning rate is set
    solely by this scheduler, the learning rate at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose="deprecated"):
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] for group in self.optimizer.param_groups]
        elif self._step_count == 1 and self.last_epoch > 0:
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2
                    for base_lr, group in
                    zip(self.base_lrs, self.optimizer.param_groups)]
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            return [group['lr'] + (base_lr - self.eta_min) *
                    (1 - math.cos(math.pi / self.T_max)) / 2
                    for base_lr, group in
                    zip(self.base_lrs, self.optimizer.param_groups)]
        return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
                for base_lr in self.base_lrs]
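# Numeric check of the closed form above: with base_lr = 0.1, eta_min = 0 and
# T_max = 10, the lr is 0.1 at epoch 0, 0.05 at epoch 5 (cos(pi/2) = 0) and
# 0.0 at epoch 10 (cos(pi) = -1).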
class ChainedScheduler(LRScheduler):
    """Chains list of learning rate schedulers. It takes a list of chainable learning
    rate schedulers and performs consecutive step() functions belonging to them by just
    one call.

    Args:
        schedulers (list): List of chained schedulers.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.09     if epoch == 0
        >>> # lr = 0.081    if epoch == 1
        >>> # lr = 0.729    if epoch == 2
        >>> # lr = 0.6561   if epoch == 3
        >>> # lr = 0.59049  if epoch >= 4
        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, schedulers):
        for scheduler_idx in range(1, len(schedulers)):
            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                raise ValueError(
                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
                    f"got schedulers at index {0} and {scheduler_idx} to be different"
                )
        self._schedulers = list(schedulers)
        self.optimizer = schedulers[0].optimizer
        self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]

    def step(self):
        for scheduler in self._schedulers:
            scheduler.step()
        self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [None] * len(self._schedulers)

        for idx, s in enumerate(self._schedulers):
            state_dict['_schedulers'][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)
class ReduceLROnPlateau(LRScheduler):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): The number of allowed epochs with no improvement after
            which the learning rate will be reduced.
            For example, consider the case of having no patience (`patience = 0`).
            In the first epoch, a baseline is established and is always considered good as there's no previous baseline.
            In the second epoch, if the performance is worse than the baseline,
            we have what is considered an intolerable epoch.
            Since the count of intolerable epochs (1) is greater than the patience level (0),
            the learning rate is reduced at the end of this epoch.
            From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch
            if the performance is worse than the baseline. If the performance improves or remains the same,
            the learning rate is not adjusted.
            Default: 10.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """
    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, eps=1e-8, verbose="deprecated"):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience

        self.verbose = _check_verbose_deprecated_warning(verbose)
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0
    def step(self, metrics, epoch=None):
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
    def _reduce_lr(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr

    @property
    def in_cooldown(self):
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and epsilon_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

    def state_dict(self):
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
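# Worked example for `is_better` above: in the default 'min'/'rel' configuration
# with threshold = 1e-4, a new value only counts as an improvement if it is below
# best * (1 - 1e-4); e.g. with best = 1.0, a loss of 0.99995 is not an
# improvement, while 0.9998 is.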
class CyclicLR(LRScheduler):
    r"""Sets the learning rate of each parameter group according to
    cyclical learning rate policy (CLR). The policy cycles the learning
    rate between two boundaries with a constant frequency, as detailed in
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a per-iteration
    or per-cycle basis.

    Cyclical learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This class has three built-in policies, as put forth in the paper:

    * "triangular": A basic triangular cycle without amplitude scaling.
    * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
    * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
      at each cycle iteration.

    This implementation was adapted from the github repo: `bckenstler/CLR`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            scaling function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. If step_size_down is None,
            it is set to step_size_up. Default: None
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, then 'mode' is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            The momentum at any cycle is the difference of max_momentum
            and some scaling of the amplitude; therefore
            base_momentum may not actually be reached depending on
            scaling function. Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """
    def __init__(self,
                 optimizer,
                 base_lr,
                 max_lr,
                 step_size_up=2000,
                 step_size_down=None,
                 mode='triangular',
                 gamma=1.,
                 scale_fn=None,
                 scale_mode='cycle',
                 cycle_momentum=True,
                 base_momentum=0.8,
                 max_momentum=0.9,
                 last_epoch=-1,
                 verbose="deprecated"):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        base_lrs = self._format_param('base_lr', optimizer, base_lr)
        if last_epoch == -1:
            for lr, group in zip(base_lrs, optimizer.param_groups):
                group['lr'] = lr

        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

        step_size_up = float(step_size_up)
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

        if mode not in ['triangular', 'triangular2', 'exp_range'] \
                and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        self._scale_fn_ref = None
        self._scale_fn_custom = scale_fn
        self.scale_mode = scale_mode
        self._init_scale_fn()

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if 'momentum' not in optimizer.defaults and 'betas' not in optimizer.defaults:
                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')

            self.use_beta1 = 'betas' in self.optimizer.defaults
            self.base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(self.max_momentums, self.base_momentums, optimizer.param_groups):
                    if self.use_beta1:
                        group['betas'] = (m_momentum, *group['betas'][1:])
                    else:
                        group['momentum'] = m_momentum
                    group['max_momentum'] = m_momentum
                    group['base_momentum'] = b_momentum

        super().__init__(optimizer, last_epoch, verbose)
        self.base_lrs = base_lrs
    def _init_scale_fn(self):
        if self._scale_fn_custom is not None:
            return
        if self.mode == 'triangular':
            self._scale_fn_ref = self._triangular_scale_fn
            self.scale_mode = 'cycle'
        elif self.mode == 'triangular2':
            self._scale_fn_ref = self._triangular2_scale_fn
            self.scale_mode = 'cycle'
        elif self.mode == 'exp_range':
            self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma)
            self.scale_mode = 'iterations'

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError(f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}")
            return param
        else:
            return [param] * len(optimizer.param_groups)

    def scale_fn(self, x):
        if self._scale_fn_custom is not None:
            return self._scale_fn_custom(x)
        else:
            return self._scale_fn_ref(x)  # static method

    @staticmethod
    def _triangular_scale_fn(x):
        return 1.

    @staticmethod
    def _triangular2_scale_fn(x):
        return 1 / (2. ** (x - 1))

    @staticmethod
    def _exp_range_scale_fn(gamma, x):
        return gamma ** x
"""Calculates the learning rate at batch index. This function treats
1338
`self.last_epoch` as the last batch index.
1340
If `self.cycle_momentum` is ``True``, this function has a side effect of
1341
updating the optimizer's momentum.
1344
if not self._get_lr_called_within_step:
1345
warnings.warn("To get the last learning rate computed by the scheduler, "
1346
"please use `get_last_lr()`.", UserWarning)
1348
cycle = math.floor(1 + self.last_epoch / self.total_size)
1349
x = 1. + self.last_epoch / self.total_size - cycle
1350
if x <= self.step_ratio:
1351
scale_factor = x / self.step_ratio
1353
scale_factor = (x - 1) / (self.step_ratio - 1)
1356
for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
1357
base_height = (max_lr - base_lr) * scale_factor
1358
if self.scale_mode == 'cycle':
1359
lr = base_lr + base_height * self.scale_fn(cycle)
1361
lr = base_lr + base_height * self.scale_fn(self.last_epoch)
1364
if self.cycle_momentum:
1366
for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
1367
base_height = (max_momentum - base_momentum) * scale_factor
1368
if self.scale_mode == 'cycle':
1369
momentum = max_momentum - base_height * self.scale_fn(cycle)
1371
momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
1372
momentums.append(momentum)
1373
for param_group, momentum in zip(self.optimizer.param_groups, momentums):
1375
param_group['betas'] = (momentum, *param_group['betas'][1:])
1377
param_group['momentum'] = momentum
1381
    def state_dict(self):
        state = super().state_dict()
        # We are dropping the `_scale_fn_ref` attribute because it is a
        # `weakref.WeakMethod` and can't be pickled.
        state.pop('_scale_fn_ref')
        fn = state.pop('_scale_fn_custom')
        state['_scale_fn_custom'] = None
        if fn is not None and not isinstance(fn, types.FunctionType):
            # The _scale_fn_custom will only be saved if it is a callable object
            # and not if it is a function or lambda.
            state['_scale_fn_custom'] = fn.__dict__.copy()

        return state

    def load_state_dict(self, state_dict):
        fn = state_dict.pop('_scale_fn_custom')
        super().load_state_dict(state_dict)
        if fn is not None:
            self._scale_fn_custom.__dict__.update(fn)
        self._init_scale_fn()
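# Worked example for the cycle bookkeeping in get_lr above: with
# step_size_up = step_size_down = 2000 (total_size = 4000, step_ratio = 0.5),
# batch index 1000 gives cycle = 1 and x = 0.25, so scale_factor = 0.5 and the
# 'triangular' lr sits halfway between base_lr and max_lr on the rising edge.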
class CosineAnnealingWarmRestarts(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose="deprecated"):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError(f"Expected positive integer T_0, but got {T_0}")
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
        if not isinstance(eta_min, (float, int)):
            raise ValueError(f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}")
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = last_epoch
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]
    def step(self, epoch=None):
        """Step could be called after every batch update

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step()  # scheduler.step(27), instead of scheduler(20)
        """

        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError(f"Expected non-negative epoch, but got {epoch}")
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group['lr'] = lr

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
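# Worked example for the restart bookkeeping in step() above: with T_0 = 10 and
# T_mult = 2, the cosine segments last 10, 20, 40, ... epochs; calling step(25)
# lands in the second segment with T_i = 20 and T_cur = 15.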
class OneCycleLR(LRScheduler):
    r"""Sets the learning rate of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
    claims that "unpublished work has shown even better results by using only two phases". To
    mimic the behaviour of the original paper instead, set ``three_phase=True``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by providing
            a value for epochs and steps_per_epoch.
            Default: None
        epochs (int): The number of epochs to train for. This is used along
            with steps_per_epoch in order to infer the total number of steps in the cycle
            if a value for total_steps is not provided.
            Default: None
        steps_per_epoch (int): The number of steps per epoch to train for. This is
            used along with epochs in order to infer the total number of steps in the
            cycle if a value for total_steps is not provided.
            Default: None
        pct_start (float): The percentage of the cycle (in number of steps) spent
            increasing the learning rate.
            Default: 0.3
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
            linear annealing.
            Default: 'cos'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.85
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.95
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor
            Default: 25
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor
            Default: 1e4
        three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
            learning rate according to 'final_div_factor' instead of modifying the second
            phase (the first two phases will be symmetrical about the step indicated by
            'pct_start').
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
            ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
            learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         optimizer.step()
        >>>         scheduler.step()

    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """
    def __init__(self,
                 optimizer,
                 max_lr,
                 total_steps=None,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 cycle_momentum=True,
                 base_momentum=0.85,
                 max_momentum=0.95,
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose="deprecated"):

        # Validate optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        # Validate total_steps
        if total_steps is None and epochs is None and steps_per_epoch is None:
            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
        elif total_steps is not None:
            if total_steps <= 0 or not isinstance(total_steps, int):
                raise ValueError(f"Expected positive integer total_steps, but got {total_steps}")
            self.total_steps = total_steps
        else:
            if epochs <= 0 or not isinstance(epochs, int):
                raise ValueError(f"Expected positive integer epochs, but got {epochs}")
            if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
                raise ValueError(f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}")
            self.total_steps = epochs * steps_per_epoch

        if three_phase:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': 'max_lr',
                    'end_lr': 'initial_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]
        else:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'max_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]

        # Validate pct_start
        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
            raise ValueError(f"Expected float between 0 and 1 pct_start, but got {pct_start}")

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError(f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}")
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        # Initialize learning rate variables
        max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(self.optimizer.param_groups):
                group['initial_lr'] = max_lrs[idx] / div_factor
                group['max_lr'] = max_lrs[idx]
                group['min_lr'] = group['initial_lr'] / final_div_factor

        # Initialize momentum variables
        self.cycle_momentum = cycle_momentum
        if self.cycle_momentum:
            if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
            self.use_beta1 = 'betas' in self.optimizer.defaults
            max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
                    if self.use_beta1:
                        group['betas'] = (m_momentum, *group['betas'][1:])
                    else:
                        group['momentum'] = m_momentum
                    group['max_momentum'] = m_momentum
                    group['base_momentum'] = b_momentum

        super().__init__(optimizer, last_epoch, verbose)
    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError(f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}")
            return param
        else:
            return [param] * len(optimizer.param_groups)

    @staticmethod
    def _annealing_cos(start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    @staticmethod
    def _annealing_linear(start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        lrs = []
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError("Tried to step {} times. The specified number of total steps is {}"
                             .format(step_num, self.total_steps))

        for group in self.optimizer.param_groups:
            start_step = 0
            for i, phase in enumerate(self._schedule_phases):
                end_step = phase['end_step']
                if step_num <= end_step or i == len(self._schedule_phases) - 1:
                    pct = (step_num - start_step) / (end_step - start_step)
                    computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
                    if self.cycle_momentum:
                        computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
                    break
                start_step = phase['end_step']

            lrs.append(computed_lr)
            if self.cycle_momentum:
                if self.use_beta1:
                    group['betas'] = (computed_momentum, *group['betas'][1:])
                else:
                    group['momentum'] = computed_momentum

        return lrs