pytorch-image-models

scheduler_factory.py
""" Scheduler Factory
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import List, Optional, Union

from torch.optim import Optimizer

from .cosine_lr import CosineLRScheduler
from .multistep_lr import MultiStepLRScheduler
from .plateau_lr import PlateauLRScheduler
from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler


def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
    """ cfg/argparse to kwargs helper
    Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
    """
    eval_metric = getattr(cfg, 'eval_metric', 'top1')
    if decreasing_metric is not None:
        plateau_mode = 'min' if decreasing_metric else 'max'
    else:
        plateau_mode = 'min' if 'loss' in eval_metric else 'max'
    kwargs = dict(
        sched=cfg.sched,
        num_epochs=getattr(cfg, 'epochs', 100),
        decay_epochs=getattr(cfg, 'decay_epochs', 30),
        decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]),
        warmup_epochs=getattr(cfg, 'warmup_epochs', 5),
        cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0),
        patience_epochs=getattr(cfg, 'patience_epochs', 10),
        decay_rate=getattr(cfg, 'decay_rate', 0.1),
        min_lr=getattr(cfg, 'min_lr', 0.),
        warmup_lr=getattr(cfg, 'warmup_lr', 1e-5),
        warmup_prefix=getattr(cfg, 'warmup_prefix', False),
        noise=getattr(cfg, 'lr_noise', None),
        noise_pct=getattr(cfg, 'lr_noise_pct', 0.67),
        noise_std=getattr(cfg, 'lr_noise_std', 1.),
        noise_seed=getattr(cfg, 'seed', 42),
        cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.),
        cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1),
        cycle_limit=getattr(cfg, 'lr_cycle_limit', 1),
        k_decay=getattr(cfg, 'lr_k_decay', 1.0),
        plateau_mode=plateau_mode,
        step_on_epochs=not getattr(cfg, 'sched_on_updates', False),
    )
    return kwargs
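
# Usage sketch (not part of the original file): any attribute-style object
# works as `cfg`; only `sched` is required, everything else falls back to the
# getattr defaults above.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(sched='cosine', epochs=200, min_lr=1e-6)
#   kwargs = scheduler_kwargs(cfg)
#   assert kwargs['num_epochs'] == 200      # explicit value passes through
#   assert kwargs['warmup_epochs'] == 5     # absent attr -> default
#   assert kwargs['plateau_mode'] == 'max'  # 'top1' metric is increasing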


def create_scheduler(
        args,
        optimizer: Optimizer,
        updates_per_epoch: int = 0,
):
    return create_scheduler_v2(
        optimizer=optimizer,
        **scheduler_kwargs(args),
        updates_per_epoch=updates_per_epoch,
    )
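
# Usage note (not in the original file): the legacy `create_scheduler` simply
# unpacks argparse-style args into the v2 factory, so for an argparse
# Namespace `args`:
#
#   lr_scheduler, num_epochs = create_scheduler(args, optimizer)
#
# is equivalent to:
#
#   lr_scheduler, num_epochs = create_scheduler_v2(
#       optimizer=optimizer,
#       **scheduler_kwargs(args),
#   )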


def create_scheduler_v2(
        optimizer: Optimizer,
        sched: str = 'cosine',
        num_epochs: int = 300,
        decay_epochs: int = 90,
        decay_milestones: List[int] = (90, 180, 270),
        cooldown_epochs: int = 0,
        patience_epochs: int = 10,
        decay_rate: float = 0.1,
        min_lr: float = 0,
        warmup_lr: float = 1e-5,
        warmup_epochs: int = 0,
        warmup_prefix: bool = False,
        noise: Optional[Union[float, List[float]]] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.,
        noise_seed: int = 42,
        cycle_mul: float = 1.,
        cycle_decay: float = 0.1,
        cycle_limit: int = 1,
        k_decay: float = 1.0,
        plateau_mode: str = 'max',
        step_on_epochs: bool = True,
        updates_per_epoch: int = 0,
):
    t_initial = num_epochs
    warmup_t = warmup_epochs
    decay_t = decay_epochs
    cooldown_t = cooldown_epochs

    if not step_on_epochs:
        assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches'
        t_initial = t_initial * updates_per_epoch
        warmup_t = warmup_t * updates_per_epoch
        decay_t = decay_t * updates_per_epoch
        decay_milestones = [d * updates_per_epoch for d in decay_milestones]
        cooldown_t = cooldown_t * updates_per_epoch
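
    # At this point every schedule length is in scheduler "steps": epochs when
    # step_on_epochs=True, individual optimizer updates otherwise (e.g.
    # num_epochs=300 with updates_per_epoch=1000 gives t_initial=300000).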

    # warmup args
    warmup_args = dict(
        warmup_lr_init=warmup_lr,
        warmup_t=warmup_t,
        warmup_prefix=warmup_prefix,
    )

    # setup noise args for supporting schedulers
    if noise is not None:
        if isinstance(noise, (list, tuple)):
            noise_range = [n * t_initial for n in noise]
            if len(noise_range) == 1:
                noise_range = noise_range[0]
        else:
            noise_range = noise * t_initial
    else:
        noise_range = None
    noise_args = dict(
        noise_range_t=noise_range,
        noise_pct=noise_pct,
        noise_std=noise_std,
        noise_seed=noise_seed,
    )
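
    # Note: `noise` is given as a fraction (or [start, end] pair of fractions)
    # of the schedule length and was converted to absolute steps above, e.g.
    # noise=[0.4, 0.9] with t_initial=300 yields noise_range_t=[120, 270].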

    # setup cycle args for supporting schedulers
    cycle_args = dict(
        cycle_mul=cycle_mul,
        cycle_decay=cycle_decay,
        cycle_limit=cycle_limit,
    )

    lr_scheduler = None
    if sched == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            **cycle_args,
            **warmup_args,
            **noise_args,
            k_decay=k_decay,
        )
    elif sched == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            **cycle_args,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=decay_t,
            decay_rate=decay_rate,
            t_in_epochs=step_on_epochs,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'multistep':
        lr_scheduler = MultiStepLRScheduler(
            optimizer,
            decay_t=decay_milestones,
            decay_rate=decay_rate,
            t_in_epochs=step_on_epochs,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'plateau':
        assert step_on_epochs, 'Plateau LR only supports step per epoch.'
        # PlateauLRScheduler does not accept a warmup_prefix kwarg, drop it
        warmup_args.pop('warmup_prefix', False)
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=decay_rate,
            patience_t=patience_epochs,
            cooldown_t=0,
            **warmup_args,
            lr_min=min_lr,
            mode=plateau_mode,
            **noise_args,
        )
    elif sched == 'poly':
        lr_scheduler = PolyLRScheduler(
            optimizer,
            power=decay_rate,  # overloading 'decay_rate' as polynomial power
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            k_decay=k_decay,
            **cycle_args,
            **warmup_args,
            **noise_args,
        )

    if hasattr(lr_scheduler, 'get_cycle_length'):
        # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
        t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
        if step_on_epochs:
            num_epochs = t_with_cycles_and_cooldown
        else:
            num_epochs = t_with_cycles_and_cooldown // updates_per_epoch

    return lr_scheduler, num_epochs
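
# End-to-end sketch (illustrative, not part of the original file; in recent
# timm versions this factory is importable as
# `timm.scheduler.create_scheduler_v2`). timm schedulers are stepped with an
# explicit epoch (or update) index rather than implicitly like torch's:
#
#   import torch
#   model = torch.nn.Linear(10, 2)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   scheduler, num_epochs = create_scheduler_v2(
#       optimizer,
#       sched='cosine',
#       num_epochs=100,
#       warmup_epochs=5,
#       warmup_lr=1e-6,
#       min_lr=1e-5,
#   )
#   for epoch in range(num_epochs):
#       # ... train for one epoch ...
#       scheduler.step(epoch + 1)  # pass the index of the next epoch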