import itertools
import math
import warnings
from copy import deepcopy

import torch
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices

__all__ = [
    'AveragedModel',
    'SWALR',
    'update_bn',
    'get_ema_multi_avg_fn',
    'get_swa_multi_avg_fn',
    'get_ema_avg_fn',
    'get_swa_avg_fn',
]

from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype


def get_ema_multi_avg_fn(decay=0.999):
    @torch.no_grad()
    def ema_update(ema_param_list, current_param_list, _):
        # foreach lerp only handles float and complex
        if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(ema_param_list[0]):
            torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
        else:
            # Dtypes that lerp does not support fall back to an explicit per-tensor update.
            for p_ema, p_model in zip(ema_param_list, current_param_list):
                p_ema.copy_(p_ema * decay + p_model * (1 - decay))

    return ema_update


def get_swa_multi_avg_fn():
    @torch.no_grad()
    def swa_update(averaged_param_list, current_param_list, num_averaged):
        # foreach lerp only handles float and complex
        if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex(averaged_param_list[0]):
            torch._foreach_lerp_(averaged_param_list, current_param_list, 1 / (num_averaged + 1))
        else:
            # Equivalent running-mean update for dtypes that lerp does not support.
            diffs = torch._foreach_sub(current_param_list, averaged_param_list)
            torch._foreach_addcdiv_(averaged_param_list, diffs, [num_averaged + 1] * len(averaged_param_list))

    return swa_update


def get_ema_avg_fn(decay=0.999):
    @torch.no_grad()
    def ema_update(ema_param, current_param, num_averaged):
        return decay * ema_param + (1 - decay) * current_param

    return ema_update


def get_swa_avg_fn():
    @torch.no_grad()
    def swa_update(averaged_param, current_param, num_averaged):
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)

    return swa_update
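
# Both averaging rules above can be read as simple recursions; as a rough
# worked example (illustration only, not used by the code): with the SWA rule
# avg_{n+1} = avg_n + (x_{n+1} - avg_n) / (n + 1), averaging the values
# 1.0, 2.0, 3.0 in turn gives 1.0 -> 1.5 -> 2.0, i.e. the ordinary mean, while
# the EMA rule ema_{n+1} = decay * ema_n + (1 - decay) * x_{n+1} with
# decay=0.999 moves only 0.1% of the way toward each new value.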


class AveragedModel(Module):
    r"""Implements averaged model for Stochastic Weight Averaging (SWA) and
    Exponential Moving Average (EMA).

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
    (UAI 2018).

    Exponential Moving Average is a variation of `Polyak averaging`_,
    but using exponential weights instead of equal weights across iterations.

    The AveragedModel class creates a copy of the provided module :attr:`model`
    on the device :attr:`device` and allows computing running averages of the
    parameters of the :attr:`model`.

    Args:
        model (torch.nn.Module): model to use with SWA/EMA
        device (torch.device, optional): if provided, the averaged model will be
            stored on :attr:`device`
        avg_fn (function, optional): the averaging function used to update
            parameters; the function must take in the current value of the
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        multi_avg_fn (function, optional): the averaging function used to update
            parameters in place; the function must take in the current values of the
            :class:`AveragedModel` parameters as a list, the current values of :attr:`model`
            parameters as a list, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        use_buffers (bool): if ``True``, it will compute running averages for
            both the parameters and the buffers of the model (default: ``False``)

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> loader, optimizer, model, loss_fn = ...
        >>> swa_model = torch.optim.swa_utils.AveragedModel(model)
        >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
        >>>                                     T_max=300)
        >>> swa_start = 160
        >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
        >>> for i in range(300):
        >>>     for input, target in loader:
        >>>         optimizer.zero_grad()
        >>>         loss_fn(model(input), target).backward()
        >>>         optimizer.step()
        >>>     if i > swa_start:
        >>>         swa_model.update_parameters(model)
        >>>         swa_scheduler.step()
        >>>     else:
        >>>         scheduler.step()
        >>>
        >>> # Update bn statistics for the swa_model at the end
        >>> torch.optim.swa_utils.update_bn(loader, swa_model)

    You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
    If no averaging function is provided, the default is to compute an
    equally weighted average of the weights (SWA).

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # Compute exponential moving averages of the weights and buffers
        >>> ema_model = torch.optim.swa_utils.AveragedModel(model,
        >>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9),
        >>>             use_buffers=True)
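
    As an illustration only (not part of the library API), a hand-written
    ``avg_fn`` that reproduces the default equally weighted SWA behaviour
    can be passed in the same way:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> def swa_avg_fn(averaged_param, current_param, num_averaged):
        >>>     # running mean: new_avg = avg + (x - avg) / (n + 1)
        >>>     return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
        >>> swa_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=swa_avg_fn)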

    When using SWA/EMA with models containing Batch Normalization you may
    need to update the activation statistics for Batch Normalization.
    This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
    utility or by setting :attr:`use_buffers` to `True`. The first approach updates the
    statistics in a post-training step by passing data through the model. The
    second does it during the parameter update phase by averaging all buffers.
    Empirical evidence has shown that updating the statistics in normalization
    layers increases accuracy, but you may wish to empirically test which
    approach yields the best results for your problem.
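
    For instance, the second approach amounts to constructing the averaged
    model with the buffers included in the running average (a minimal sketch
    of the option described above):

        >>> # xdoctest: +SKIP("undefined variables")
        >>> swa_model = torch.optim.swa_utils.AveragedModel(model, use_buffers=True)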

    .. note::
        :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model.

    .. note::
        When :meth:`update_parameters` is called for the first time (i.e.
        :attr:`n_averaged` is `0`) the parameters of `model` are copied
        to the parameters of :class:`AveragedModel`. For every subsequent
        call of :meth:`update_parameters` the function `avg_fn` is used
        to update the parameters.

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should
        Average:
        https://arxiv.org/abs/1806.05594
    .. _SWALP: Stochastic Weight Averaging in Low-Precision Training:
        https://arxiv.org/abs/1904.11943
    .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
        Generalizes Well:
        https://arxiv.org/abs/2001.02312
    .. _Polyak averaging:
        https://paperswithcode.com/method/polyak-averaging
    """

    def __init__(self, model, device=None, avg_fn=None, multi_avg_fn=None, use_buffers=False):
        super().__init__()
        assert avg_fn is None or multi_avg_fn is None, 'Only one of avg_fn and multi_avg_fn should be provided'
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
        self.avg_fn = avg_fn
        self.multi_avg_fn = multi_avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def update_parameters(self, model):
        self_param = (
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers else self.parameters()
        )
        model_param = (
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers else model.parameters()
        )
        self_param_detached = []
        model_param_detached = []
        for p_averaged, p_model in zip(self_param, model_param):
            p_model_ = p_model.detach().to(p_averaged.device)
            self_param_detached.append(p_averaged.detach())
            model_param_detached.append(p_model_)
            if self.n_averaged == 0:
                p_averaged.detach().copy_(p_model_)

        if self.n_averaged > 0:
            if self.multi_avg_fn is not None or self.avg_fn is None:
                grouped_tensors = _group_tensors_by_device_and_dtype([self_param_detached, model_param_detached])
                for ((device, _), ([self_params, model_params], _)) in grouped_tensors.items():
                    if self.multi_avg_fn:
                        self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
                    elif device.type in _get_foreach_kernels_supported_devices():
                        multi_avg_fn = get_swa_multi_avg_fn()
                        multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
                    else:
                        avg_fn = get_swa_avg_fn()
                        n_averaged = self.n_averaged.to(device)
                        for p_averaged, p_model in zip(self_params, model_params):
                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
            else:
                for p_averaged, p_model in zip(self_param_detached, model_param_detached):
                    n_averaged = self.n_averaged.to(p_averaged.device)
                    p_averaged.detach().copy_(self.avg_fn(p_averaged.detach(), p_model, n_averaged))

        if not self.use_buffers:
            # If not applying running averages to the buffers,
            # keep the buffers in sync with the source model.
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
                b_swa.detach().copy_(b_model.detach().to(b_swa.device))
        self.n_averaged += 1


@torch.no_grad()
def update_bn(loader, model, device=None):
    r"""Updates BatchNorm running_mean, running_var buffers in the model.

    It performs one pass over data in `loader` to estimate the activation
    statistics for BatchNorm layers in the model.

    Args:
        loader (torch.utils.data.DataLoader): dataset loader to compute the
            activation statistics on. Each data batch should be either a
            tensor, or a list/tuple whose first element is a tensor
            containing data.
        model (torch.nn.Module): model for which we seek to update BatchNorm
            statistics.
        device (torch.device, optional): If set, data will be transferred to
            :attr:`device` before being passed into :attr:`model`.

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> loader, model = ...
        >>> torch.optim.swa_utils.update_bn(loader, model)

    .. note::
        The `update_bn` utility assumes that each data batch in :attr:`loader`
        is either a tensor or a list or tuple of tensors; in the latter case it
        is assumed that :meth:`model.forward()` should be called on the first
        element of the list or tuple corresponding to the data batch.
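
    For example (illustrative only), statistics for an averaged model can be
    recomputed on a specific device by passing :attr:`device` explicitly:

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> torch.optim.swa_utils.update_bn(loader, swa_model, device=torch.device("cuda"))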
    """
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
    # momentum=None makes BatchNorm track a cumulative moving average over the pass.
    for module in momenta.keys():
        module.momentum = None

    for input in loader:
        if isinstance(input, (list, tuple)):
            input = input[0]
        if device is not None:
            input = input.to(device)

        model(input)

    # Restore the original momenta and training mode.
    for bn_module in momenta.keys():
        bn_module.momentum = momenta[bn_module]
    model.train(was_training)


class SWALR(LRScheduler):
    r"""Anneals the learning rate in each parameter group to a fixed value.

    This learning rate scheduler is meant to be used with the Stochastic Weight
    Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).

    Args:
        optimizer (torch.optim.Optimizer): wrapped optimizer
        swa_lr (float or list): the learning rate value for all param groups
            together or separately for each group.
        anneal_epochs (int): number of epochs in the annealing phase
            (default: 10)
        anneal_strategy (str): "cos" or "linear"; specifies the annealing
            strategy: "cos" for cosine annealing, "linear" for linear annealing
            (default: "cos")
        last_epoch (int): the index of the last epoch (default: -1)

    The :class:`SWALR` scheduler can be used together with other
    schedulers to switch to a constant learning rate late in the training
    as in the example below.

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> loader, optimizer, model, loss_fn = ...
        >>> lr_lambda = lambda epoch: 0.9
        >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
        >>>        lr_lambda=lr_lambda)
        >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
        >>>        anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
        >>> swa_start = 160
        >>> for i in range(300):
        >>>     for input, target in loader:
        >>>         optimizer.zero_grad()
        >>>         loss_fn(model(input), target).backward()
        >>>         optimizer.step()
        >>>     if i > swa_start:
        >>>         swa_scheduler.step()
        >>>     else:
        >>>         scheduler.step()
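
    As a purely illustrative addition, the most recently computed learning
    rates can be read back from the scheduler through the standard
    ``get_last_lr()`` accessor:

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> swa_scheduler.get_last_lr()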

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    """

    def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):
        swa_lrs = self._format_param(optimizer, swa_lr)
        for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
            group['swa_lr'] = swa_lr
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError("anneal_strategy must be one of 'cos' or 'linear', "
                             f"instead got {anneal_strategy}")
        elif anneal_strategy == 'cos':
            self.anneal_func = self._cosine_anneal
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_anneal
        if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
            raise ValueError(f"anneal_epochs must be an int greater than or equal to 0, got {anneal_epochs}")
        self.anneal_epochs = anneal_epochs
        super().__init__(optimizer, last_epoch)

    @staticmethod
    def _format_param(optimizer, swa_lrs):
        if isinstance(swa_lrs, (list, tuple)):
            if len(swa_lrs) != len(optimizer.param_groups):
                raise ValueError("swa_lr must have the same length as "
                                 f"optimizer.param_groups: swa_lr has {len(swa_lrs)}, "
                                 f"optimizer.param_groups has {len(optimizer.param_groups)}")
            return swa_lrs
        else:
            return [swa_lrs] * len(optimizer.param_groups)

    @staticmethod
    def _linear_anneal(t):
        return t

    @staticmethod
    def _cosine_anneal(t):
        return (1 - math.cos(math.pi * t)) / 2

    @staticmethod
    def _get_initial_lr(lr, swa_lr, alpha):
        # Invert lr = alpha * swa_lr + (1 - alpha) * initial_lr; alpha == 1 means
        # annealing has finished and the group lr already equals swa_lr.
        if alpha == 1:
            return swa_lr
        return (lr - alpha * swa_lr) / (1 - alpha)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        step = self._step_count - 1
        if self.anneal_epochs == 0:
            step = max(1, step)
        # Recover each group's pre-annealing lr from its current lr, then apply
        # the annealing factor for the current step.
        prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
        prev_alpha = self.anneal_func(prev_t)
        prev_lrs = [self._get_initial_lr(group['lr'], group['swa_lr'], prev_alpha)
                    for group in self.optimizer.param_groups]
        t = max(0, min(1, step / max(1, self.anneal_epochs)))
        alpha = self.anneal_func(t)
        return [group['swa_lr'] * alpha + lr * (1 - alpha)
                for group, lr in zip(self.optimizer.param_groups, prev_lrs)]
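
# A rough worked example of the annealing above (not part of the public API):
# with anneal_strategy='cos' and anneal_epochs=10, at step 5 we have
# t = 5 / 10 = 0.5 and alpha = (1 - cos(pi * 0.5)) / 2 = 0.5, so each group's
# lr sits halfway between its pre-annealing lr and its target swa_lr; for
# step >= 10, alpha = 1 and the lr stays fixed at swa_lr.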