1""" Scheduler Factory
2Hacked together by / Copyright 2021 Ross Wightman
3"""
4from typing import List, Optional, Union5
6from torch.optim import Optimizer7
8from .cosine_lr import CosineLRScheduler9from .multistep_lr import MultiStepLRScheduler10from .plateau_lr import PlateauLRScheduler11from .poly_lr import PolyLRScheduler12from .step_lr import StepLRScheduler13from .tanh_lr import TanhLRScheduler14
15
def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
    """ cfg/argparse to kwargs helper
    Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
    """
    eval_metric = getattr(cfg, 'eval_metric', 'top1')
    if decreasing_metric is not None:
        plateau_mode = 'min' if decreasing_metric else 'max'
    else:
        plateau_mode = 'min' if 'loss' in eval_metric else 'max'
    kwargs = dict(
        sched=cfg.sched,
        num_epochs=getattr(cfg, 'epochs', 100),
        decay_epochs=getattr(cfg, 'decay_epochs', 30),
        decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]),
        warmup_epochs=getattr(cfg, 'warmup_epochs', 5),
        cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0),
        patience_epochs=getattr(cfg, 'patience_epochs', 10),
        decay_rate=getattr(cfg, 'decay_rate', 0.1),
        min_lr=getattr(cfg, 'min_lr', 0.),
        warmup_lr=getattr(cfg, 'warmup_lr', 1e-5),
        warmup_prefix=getattr(cfg, 'warmup_prefix', False),
        noise=getattr(cfg, 'lr_noise', None),
        noise_pct=getattr(cfg, 'lr_noise_pct', 0.67),
        noise_std=getattr(cfg, 'lr_noise_std', 1.),
        noise_seed=getattr(cfg, 'seed', 42),
        cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.),
        cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1),
        cycle_limit=getattr(cfg, 'lr_cycle_limit', 1),
        k_decay=getattr(cfg, 'lr_k_decay', 1.0),
        plateau_mode=plateau_mode,
        step_on_epochs=not getattr(cfg, 'sched_on_updates', False),
    )
    return kwargs

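# Illustrative sketch (not part of the original file): `cfg` can be any object with
# dot-attribute access, e.g. an argparse Namespace; `sched` is the only attribute
# read without a default.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(sched='cosine', epochs=200, min_lr=1e-6, eval_metric='loss')
#   kwargs = scheduler_kwargs(cfg)
#   # kwargs['num_epochs'] == 200, kwargs['plateau_mode'] == 'min'
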
def create_scheduler(
        args,
        optimizer: Optimizer,
        updates_per_epoch: int = 0,
):
    # thin wrapper: converts argparse-style `args` via scheduler_kwargs() and
    # forwards them to create_scheduler_v2()
    return create_scheduler_v2(
        optimizer=optimizer,
        **scheduler_kwargs(args),
        updates_per_epoch=updates_per_epoch,
    )

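# Illustrative call (an assumption, mirroring how a timm-style training script would
# typically wire this up; not part of the original file):
#
#   lr_scheduler, num_epochs = create_scheduler(args, optimizer)
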
def create_scheduler_v2(
        optimizer: Optimizer,
        sched: str = 'cosine',
        num_epochs: int = 300,
        decay_epochs: int = 90,
        decay_milestones: List[int] = (90, 180, 270),
        cooldown_epochs: int = 0,
        patience_epochs: int = 10,
        decay_rate: float = 0.1,
        min_lr: float = 0,
        warmup_lr: float = 1e-5,
        warmup_epochs: int = 0,
        warmup_prefix: bool = False,
        noise: Optional[Union[float, List[float]]] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.,
        noise_seed: int = 42,
        cycle_mul: float = 1.,
        cycle_decay: float = 0.1,
        cycle_limit: int = 1,
        k_decay: float = 1.0,
        plateau_mode: str = 'max',
        step_on_epochs: bool = True,
        updates_per_epoch: int = 0,
):
    t_initial = num_epochs
    warmup_t = warmup_epochs
    decay_t = decay_epochs
    cooldown_t = cooldown_epochs

    if not step_on_epochs:
        assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches'
        t_initial = t_initial * updates_per_epoch
        warmup_t = warmup_t * updates_per_epoch
        decay_t = decay_t * updates_per_epoch
        decay_milestones = [d * updates_per_epoch for d in decay_milestones]
        cooldown_t = cooldown_t * updates_per_epoch

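    # From this point on, every schedule length above (t_initial, warmup_t, decay_t,
    # decay_milestones, cooldown_t) is in the same unit: epochs when step_on_epochs
    # is True, optimizer updates otherwise.
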
    # warmup args
    warmup_args = dict(
        warmup_lr_init=warmup_lr,
        warmup_t=warmup_t,
        warmup_prefix=warmup_prefix,
    )

    # setup noise args for supporting schedulers
    if noise is not None:
        if isinstance(noise, (list, tuple)):
            noise_range = [n * t_initial for n in noise]
            if len(noise_range) == 1:
                noise_range = noise_range[0]
        else:
            noise_range = noise * t_initial
    else:
        noise_range = None
    noise_args = dict(
        noise_range_t=noise_range,
        noise_pct=noise_pct,
        noise_std=noise_std,
        noise_seed=noise_seed,
    )

    # setup cycle args for supporting schedulers
    cycle_args = dict(
        cycle_mul=cycle_mul,
        cycle_decay=cycle_decay,
        cycle_limit=cycle_limit,
    )

    lr_scheduler = None
    if sched == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            **cycle_args,
            **warmup_args,
            **noise_args,
            k_decay=k_decay,
        )
    elif sched == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            **cycle_args,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=decay_t,
            decay_rate=decay_rate,
            t_in_epochs=step_on_epochs,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'multistep':
        lr_scheduler = MultiStepLRScheduler(
            optimizer,
            decay_t=decay_milestones,
            decay_rate=decay_rate,
            t_in_epochs=step_on_epochs,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'plateau':
        assert step_on_epochs, 'Plateau LR only supports step per epoch.'
        warmup_args.pop('warmup_prefix', False)
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=decay_rate,
            patience_t=patience_epochs,
            cooldown_t=0,
            **warmup_args,
            lr_min=min_lr,
            mode=plateau_mode,
            **noise_args,
        )
    elif sched == 'poly':
        lr_scheduler = PolyLRScheduler(
            optimizer,
            power=decay_rate,  # overloading 'decay_rate' as polynomial power
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            k_decay=k_decay,
            **cycle_args,
            **warmup_args,
            **noise_args,
        )

    if hasattr(lr_scheduler, 'get_cycle_length'):
        # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
        t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
        if step_on_epochs:
            num_epochs = t_with_cycles_and_cooldown
        else:
            num_epochs = t_with_cycles_and_cooldown // updates_per_epoch

    return lr_scheduler, num_epochs
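
# Illustrative usage sketch (assumptions: a torch optimizer and the step()/step_update()
# interface of timm schedulers; the numbers are made up, not from the original file):
#
#   import torch
#   model = torch.nn.Linear(10, 2)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   updates_per_epoch = 100  # i.e. len(train_loader)
#   lr_scheduler, num_epochs = create_scheduler_v2(
#       optimizer,
#       sched='cosine',
#       num_epochs=90,
#       warmup_epochs=5,
#       min_lr=1e-5,
#       step_on_epochs=False,                  # step per optimizer update,
#       updates_per_epoch=updates_per_epoch,   # so this must be provided
#   )
#   for epoch in range(num_epochs):            # num_epochs may have been recalculated
#       for i in range(updates_per_epoch):
#           # ... forward / backward / optimizer.step() ...
#           lr_scheduler.step_update(num_updates=epoch * updates_per_epoch + i + 1)
#       lr_scheduler.step(epoch + 1)           # pass a metric here for 'plateau'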