# Source: pytorch (fork), file swa_utils.py — 377 lines, 16.2 KB.
1
import itertools
2
import math
3
from copy import deepcopy
4
import warnings
5

6
import torch
7
from torch.nn import Module
8
from torch.optim.lr_scheduler import LRScheduler
9
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
10

11
# Public API exported by ``from torch.optim.swa_utils import *``.
__all__ = [
    "AveragedModel",
    "update_bn",
    "SWALR",
    "get_ema_multi_avg_fn",
    "get_swa_multi_avg_fn",
    "get_ema_avg_fn",
    "get_swa_avg_fn",
]
20

21
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
22

23

24
def get_ema_multi_avg_fn(decay=0.999):
    """Build an in-place, list-based exponential-moving-average update.

    The returned callable has the ``multi_avg_fn`` signature accepted by
    :class:`AveragedModel`: ``(ema_param_list, current_param_list, num_averaged)``.
    Every tensor in ``ema_param_list`` is overwritten in place with
    ``decay * ema + (1 - decay) * current``; the third argument is ignored
    and the callable returns ``None``.

    Args:
        decay (float): weight kept on the running average (default: 0.999).
    """
    weight = 1 - decay  # share given to the newest parameter values

    @torch.no_grad()
    def ema_update(ema_param_list, current_param_list, _):
        head = ema_param_list[0]
        # torch._foreach_lerp_ only handles floating-point and complex tensors.
        # Only the first tensor is inspected, so the list is assumed to share
        # one dtype (AveragedModel groups tensors by device and dtype).
        if not (torch.is_floating_point(head) or torch.is_complex(head)):
            for ema_t, cur_t in zip(ema_param_list, current_param_list):
                ema_t.copy_(ema_t * decay + cur_t * weight)
            return
        torch._foreach_lerp_(ema_param_list, current_param_list, weight)

    return ema_update
35

36

37
def get_swa_multi_avg_fn():
    """Build an in-place, list-based equal-weight (SWA) running-average update.

    The returned callable has the ``multi_avg_fn`` signature accepted by
    :class:`AveragedModel`: ``(averaged_param_list, current_param_list,
    num_averaged)``. Each averaged tensor is overwritten in place with
    ``avg + (current - avg) / (num_averaged + 1)``; it returns ``None``.
    ``num_averaged`` may be a Python number or a 0-dim tensor.
    """
    @torch.no_grad()
    def swa_update(averaged_param_list, current_param_list, num_averaged):
        head = averaged_param_list[0]
        # foreach lerp only handles float and complex
        if torch.is_floating_point(head) or torch.is_complex(head):
            torch._foreach_lerp_(averaged_param_list, current_param_list, 1 / (num_averaged + 1))
        else:
            # Integer tensors (e.g. BatchNorm's num_batches_tracked when
            # use_buffers=True): torch.addcdiv refuses integer-by-integer
            # division, so compute the mean with true division and let copy_
            # cast back to the tensor's dtype (truncating), exactly like the
            # per-tensor get_swa_avg_fn path does.
            for p_avg, p_cur in zip(averaged_param_list, current_param_list):
                p_avg.copy_(p_avg + (p_cur - p_avg) / (num_averaged + 1))

    return swa_update
48

49

50
def get_ema_avg_fn(decay=0.999):
    """Build a per-tensor exponential-moving-average function.

    The returned callable has the ``avg_fn`` signature accepted by
    :class:`AveragedModel` and *returns* (does not write) the new value
    ``decay * ema_param + (1 - decay) * current_param``. ``num_averaged`` is
    accepted for interface compatibility but unused.

    Args:
        decay (float): weight kept on the running average (default: 0.999).
    """
    keep = decay

    @torch.no_grad()
    def ema_update(ema_param, current_param, num_averaged):
        old_part = keep * ema_param
        new_part = (1 - keep) * current_param
        return old_part + new_part

    return ema_update
56

57

58
def get_swa_avg_fn():
    """Build a per-tensor equal-weight (SWA) running-average function.

    The returned callable has the ``avg_fn`` signature accepted by
    :class:`AveragedModel` and *returns* (does not write) the updated mean
    ``averaged_param + (current_param - averaged_param) / (num_averaged + 1)``.
    """
    @torch.no_grad()
    def swa_update(averaged_param, current_param, num_averaged):
        count = num_averaged + 1
        increment = (current_param - averaged_param) / count
        return averaged_param + increment

    return swa_update
64

65

66
class AveragedModel(Module):
67
    r"""Implements averaged model for Stochastic Weight Averaging (SWA) and
68
    Exponential Moving Average (EMA).
69

70
    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
71
    Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
72
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
73
    (UAI 2018).
74

75
    Exponential Moving Average is a variation of `Polyak averaging`_,
76
    but using exponential weights instead of equal weights across iterations.
77

78
    AveragedModel class creates a copy of the provided module :attr:`model`
79
    on the device :attr:`device` and allows to compute running averages of the
80
    parameters of the :attr:`model`.
81

82
    Args:
83
        model (torch.nn.Module): model to use with SWA/EMA
84
        device (torch.device, optional): if provided, the averaged model will be
85
            stored on the :attr:`device`
86
        avg_fn (function, optional): the averaging function used to update
87
            parameters; the function must take in the current value of the
88
            :class:`AveragedModel` parameter, the current value of :attr:`model`
89
            parameter, and the number of models already averaged; if None,
90
            an equally weighted average is used (default: None)
91
        multi_avg_fn (function, optional): the averaging function used to update
92
            parameters inplace; the function must take in the current values of the
93
            :class:`AveragedModel` parameters as a list, the current values of :attr:`model`
94
            parameters as a list, and the number of models already averaged; if None,
95
            an equally weighted average is used (default: None)
96
        use_buffers (bool): if ``True``, it will compute running averages for
97
            both the parameters and the buffers of the model. (default: ``False``)
98

99
    Example:
100
        >>> # xdoctest: +SKIP("undefined variables")
101
        >>> loader, optimizer, model, loss_fn = ...
102
        >>> swa_model = torch.optim.swa_utils.AveragedModel(model)
103
        >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
104
        >>>                                     T_max=300)
105
        >>> swa_start = 160
106
        >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
107
        >>> for i in range(300):
108
        >>>      for input, target in loader:
109
        >>>          optimizer.zero_grad()
110
        >>>          loss_fn(model(input), target).backward()
111
        >>>          optimizer.step()
112
        >>>      if i > swa_start:
113
        >>>          swa_model.update_parameters(model)
114
        >>>          swa_scheduler.step()
115
        >>>      else:
116
        >>>          scheduler.step()
117
        >>>
118
        >>> # Update bn statistics for the swa_model at the end
119
        >>> torch.optim.swa_utils.update_bn(loader, swa_model)
120

121
    You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
122
    If no averaging function is provided, the default is to compute
123
    equally-weighted average of the weights (SWA).
124

125
    Example:
126
        >>> # xdoctest: +SKIP("undefined variables")
127
        >>> # Compute exponential moving averages of the weights and buffers
128
        >>> ema_model = torch.optim.swa_utils.AveragedModel(model,
129
        >>>             torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True)
130

131
    .. note::
132
        When using SWA/EMA with models containing Batch Normalization you may
133
        need to update the activation statistics for Batch Normalization.
134
        This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
135
        or by setting :attr:`use_buffers` to `True`. The first approach updates the
136
        statistics in a post-training step by passing data through the model. The
137
        second does it during the parameter update phase by averaging all buffers.
138
        Empirical evidence has shown that updating the statistics in normalization
139
        layers increases accuracy, but you may wish to empirically test which
140
        approach yields the best results in your problem.
141

142
    .. note::
143
        :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model.
144

145
    .. note::
146
        When :meth:`update_parameters` is called for the first time (i.e.
147
        :attr:`n_averaged` is `0`) the parameters of `model` are copied
148
        to the parameters of :class:`AveragedModel`. For every subsequent
149
        call of :meth:`update_parameters` the function `avg_fn` is used
150
        to update the parameters.
151

152
    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
153
        https://arxiv.org/abs/1803.05407
154
    .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should
155
        Average:
156
        https://arxiv.org/abs/1806.05594
157
    .. _SWALP: Stochastic Weight Averaging in Low-Precision Training:
158
        https://arxiv.org/abs/1904.11943
159
    .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
160
        Generalizes Well:
161
        https://arxiv.org/abs/2001.02312
162
    .. _Polyak averaging:
163
        https://paperswithcode.com/method/polyak-averaging
164
    """
165
    def __init__(self, model, device=None, avg_fn=None, multi_avg_fn=None, use_buffers=False):
166
        super().__init__()
167
        assert avg_fn is None or multi_avg_fn is None, 'Only one of avg_fn and multi_avg_fn should be provided'
168
        self.module = deepcopy(model)
169
        if device is not None:
170
            self.module = self.module.to(device)
171
        self.register_buffer('n_averaged',
172
                             torch.tensor(0, dtype=torch.long, device=device))
173
        self.avg_fn = avg_fn
174
        self.multi_avg_fn = multi_avg_fn
175
        self.use_buffers = use_buffers
176

177
    def forward(self, *args, **kwargs):
178
        return self.module(*args, **kwargs)
179

180
    def update_parameters(self, model):
181
        self_param = (
182
            itertools.chain(self.module.parameters(), self.module.buffers())
183
            if self.use_buffers else self.parameters()
184
        )
185
        model_param = (
186
            itertools.chain(model.parameters(), model.buffers())
187
            if self.use_buffers else model.parameters()
188
        )
189
        self_param_detached = []
190
        model_param_detached = []
191
        for p_averaged, p_model in zip(self_param, model_param):
192
            p_model_ = p_model.detach().to(p_averaged.device)
193
            self_param_detached.append(p_averaged.detach())
194
            model_param_detached.append(p_model_)
195
            if self.n_averaged == 0:
196
                p_averaged.detach().copy_(p_model_)
197

198
        if self.n_averaged > 0:
199
            if self.multi_avg_fn is not None or self.avg_fn is None:
200
                grouped_tensors = _group_tensors_by_device_and_dtype([self_param_detached, model_param_detached])
201
                for ((device, _), ([self_params, model_params], _)) in grouped_tensors.items():
202
                    if self.multi_avg_fn:
203
                        self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
204
                    elif device.type in _get_foreach_kernels_supported_devices():
205
                        multi_avg_fn = get_swa_multi_avg_fn()
206
                        multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
207
                    else:
208
                        avg_fn = get_swa_avg_fn()
209
                        n_averaged = self.n_averaged.to(device)
210
                        for p_averaged, p_model in zip(self_params, model_params):
211
                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
212
            else:
213
                for p_averaged, p_model in zip(self_param_detached, model_param_detached):
214
                    n_averaged = self.n_averaged.to(p_averaged.device)
215
                    p_averaged.detach().copy_(self.avg_fn(p_averaged.detach(), p_model, n_averaged))
216

217
        if not self.use_buffers:
218
            # If not apply running averages to the buffers,
219
            # keep the buffers in sync with the source model.
220
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
221
                b_swa.detach().copy_(b_model.detach().to(b_swa.device))
222
        self.n_averaged += 1
223

224

225
@torch.no_grad()
def update_bn(loader, model, device=None):
    r"""Updates BatchNorm running_mean, running_var buffers in the model.

    It performs one pass over data in `loader` to estimate the activation
    statistics for BatchNorm layers in the model.

    Args:
        loader (torch.utils.data.DataLoader): dataset loader to compute the
            activation statistics on. Each data batch should be either a
            tensor, or a list/tuple whose first element is a tensor
            containing data.
        model (torch.nn.Module): model for which we seek to update BatchNorm
            statistics.
        device (torch.device, optional): If set, data will be transferred to
            :attr:`device` before being passed into :attr:`model`.

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> loader, model = ...
        >>> torch.optim.swa_utils.update_bn(loader, model)

    .. note::
        The `update_bn` utility assumes that each data batch in :attr:`loader`
        is either a tensor or a list or tuple of tensors; in the latter case it
        is assumed that :meth:`model.forward()` should be called on the first
        element of the list or tuple corresponding to the data batch.

    .. note::
        The BatchNorm momenta and the train/eval mode of :attr:`model` are
        restored when this function returns, including when iterating
        :attr:`loader` or the forward pass raises.
    """
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    try:
        model.train()
        for module in momenta:
            # momentum=None makes BatchNorm accumulate a cumulative (simple)
            # average over all batches instead of an exponential one.
            module.momentum = None

        for batch in loader:
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            if device is not None:
                batch = batch.to(device)

            model(batch)
    finally:
        # Put back the caller's configuration even if the pass fails midway.
        for bn_module, momentum in momenta.items():
            bn_module.momentum = momentum
        model.train(was_training)
277

278

279
class SWALR(LRScheduler):
    r"""Anneals the learning rate in each parameter group to a fixed value.

    This learning rate scheduler is meant to be used with Stochastic Weight
    Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).

    Args:
        optimizer (torch.optim.Optimizer): wrapped optimizer
        swa_lr (float or list): the learning rate value for all param groups
            together or separately for each group.
        anneal_epochs (int): number of epochs in the annealing phase
            (default: 10)
        anneal_strategy (str): "cos" or "linear"; specifies the annealing
            strategy: "cos" for cosine annealing, "linear" for linear annealing
            (default: "cos")
        last_epoch (int): the index of the last epoch (default: -1)

    Raises:
        ValueError: if ``swa_lr`` is a list/tuple whose length differs from the
            number of param groups, if ``anneal_strategy`` is not "cos" or
            "linear", or if ``anneal_epochs`` is not a non-negative int. The
            optimizer is left unmodified in these cases.

    The :class:`SWALR` scheduler can be used together with other
    schedulers to switch to a constant learning rate late in the training
    as in the example below.

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> loader, optimizer, model, loss_fn = ...
        >>> lr_lambda = lambda epoch: 0.9
        >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
        >>>        lr_lambda=lr_lambda)
        >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
        >>>        anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
        >>> swa_start = 160
        >>> for i in range(300):
        >>>      for input, target in loader:
        >>>          optimizer.zero_grad()
        >>>          loss_fn(model(input), target).backward()
        >>>          optimizer.step()
        >>>      if i > swa_start:
        >>>          swa_scheduler.step()
        >>>      else:
        >>>          scheduler.step()

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    """
    def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):
        # Validate every argument before writing into optimizer.param_groups so
        # that a rejected configuration leaves the optimizer untouched.
        swa_lrs = self._format_param(optimizer, swa_lr)
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError("anneal_strategy must by one of 'cos' or 'linear', "
                             f"instead got {anneal_strategy}")
        elif anneal_strategy == 'cos':
            self.anneal_func = self._cosine_anneal
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_anneal
        if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
            raise ValueError(f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}")
        for group_swa_lr, group in zip(swa_lrs, optimizer.param_groups):
            group['swa_lr'] = group_swa_lr
        self.anneal_epochs = anneal_epochs
        super().__init__(optimizer, last_epoch)

    @staticmethod
    def _format_param(optimizer, swa_lrs):
        """Return one SWA learning rate per param group.

        A scalar is broadcast to every group; a list/tuple must have exactly
        one entry per group, otherwise ``ValueError`` is raised.
        """
        if isinstance(swa_lrs, (list, tuple)):
            if len(swa_lrs) != len(optimizer.param_groups):
                raise ValueError("swa_lr must have the same length as "
                                 f"optimizer.param_groups: swa_lr has {len(swa_lrs)}, "
                                 f"optimizer.param_groups has {len(optimizer.param_groups)}")
            return swa_lrs
        else:
            return [swa_lrs] * len(optimizer.param_groups)

    @staticmethod
    def _linear_anneal(t):
        """Linear schedule: maps progress ``t`` in [0, 1] to itself."""
        return t

    @staticmethod
    def _cosine_anneal(t):
        """Cosine schedule: 0 at ``t=0``, 1 at ``t=1``, smooth in between."""
        return (1 - math.cos(math.pi * t)) / 2

    @staticmethod
    def _get_initial_lr(lr, swa_lr, alpha):
        """Invert ``lr = alpha * swa_lr + (1 - alpha) * initial_lr`` for ``initial_lr``.

        The inversion is undefined at ``alpha == 1``; ``swa_lr`` is returned then.
        """
        if alpha == 1:
            return swa_lr
        return (lr - alpha * swa_lr) / (1 - alpha)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        step = self._step_count - 1
        if self.anneal_epochs == 0:
            # No annealing phase: t below clamps to 1, so lr becomes swa_lr.
            step = max(1, step)
        # Recover each group's pre-annealing lr by undoing the interpolation
        # that produced the current group['lr'] at the previous step ...
        prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
        prev_alpha = self.anneal_func(prev_t)
        prev_lrs = [self._get_initial_lr(group['lr'], group['swa_lr'], prev_alpha)
                    for group in self.optimizer.param_groups]
        # ... then interpolate between it and swa_lr with this step's weight.
        t = max(0, min(1, step / max(1, self.anneal_epochs)))
        alpha = self.anneal_func(t)
        return [group['swa_lr'] * alpha + lr * (1 - alpha)
                for group, lr in zip(self.optimizer.param_groups, prev_lrs)]
378

# NOTE: web-page footer (GitVerse cookie notice) captured together with this file; not part of the module.
#
# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Usage Policy.
#
# By clicking "Accept", you consent to JSC "SberTech" processing your personal data in order to
# improve our website and the GitVerse service and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.