import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
                        _differentiable_doc, _foreach_doc, _maximize_doc, _fused_doc)
from typing import List, Optional

__all__ = ['SGD', 'sgd']


class SGD(Optimizer):
    def __init__(self, params, lr=1e-3, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None,
                 differentiable: bool = False, fused: Optional[bool] = None):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize, foreach=foreach,
                        differentiable=differentiable, fused=fused)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

        if fused:
            self._step_supports_amp_scaling = True
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('differentiable', False)
            group.setdefault('fused', False)

    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
        has_sparse_grad = False

        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                state = self.state[p]
                momentum_buffer_list.append(state.get('momentum_buffer'))

        return has_sparse_grad

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []

            has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)

            sgd(params_with_grad,
                d_p_list,
                momentum_buffer_list,
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                lr=group['lr'],
                dampening=group['dampening'],
                nesterov=group['nesterov'],
                maximize=group['maximize'],
                has_sparse_grad=has_sparse_grad,
                foreach=group['foreach'],
                fused=group['fused'],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None))

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss


SGD.__doc__ = r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)},                          \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
            \:\textit{ nesterov,}\:\textit{ maximize}                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0                                               \\
            &\hspace{10mm}\textbf{if} \: t > 1                                                   \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t           \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t                                           \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov}                                       \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t                             \\
            &\hspace{10mm}\textbf{else}                                                   \\[-1.ex]
            &\hspace{15mm} g_t  \leftarrow  \textbf{b}_t                                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}                                          \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t                   \\[-1.ex]
            &\hspace{5mm}\textbf{else}                                                    \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t                   \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    """ + fr"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_differentiable_doc}
        {_fused_doc}
    """ + r"""

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.

    """


def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
        # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
        has_sparse_grad: bool = None,
        foreach: Optional[bool] = None,
        fused: Optional[bool] = None,
        grad_scale: Optional[Tensor] = None,
        found_inf: Optional[Tensor] = None,
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
    """

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if foreach is None and fused is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            fused, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
        else:
            foreach = False
            fused = False
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
    if fused and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with fused optimizers')

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    elif fused and not torch.jit.is_scripting():
        func = _fused_sgd
    else:
        func = _single_tensor_sgd

    func(params,
         d_p_list,
         momentum_buffer_list,
         weight_decay=weight_decay,
         momentum=momentum,
         lr=lr,
         dampening=dampening,
         nesterov=nesterov,
         has_sparse_grad=has_sparse_grad,
         maximize=maximize,
         grad_scale=grad_scale,
         found_inf=found_inf)
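
# Hedged usage sketch for the functional `sgd` entry point above, kept as a
# comment so importing this module has no side effects. `w`, `g` and `bufs`
# are made-up example objects, not part of this module:
#
#     w = torch.zeros(3)
#     g = torch.ones(3)
#     bufs = [None]
#     sgd([w], [g], bufs,
#         weight_decay=0.0, momentum=0.9, lr=0.1,
#         dampening=0.0, nesterov=False, maximize=False)
#
# After this first call `w` is -0.1 everywhere and `bufs[0]` holds a clone of
# `g`; passing the same list back in on later calls lets the momentum buffer
# carry over, which is what `SGD.step()` does via the per-parameter state dict.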

def _single_tensor_sgd(params: List[Tensor],
                       d_p_list: List[Tensor],
                       momentum_buffer_list: List[Optional[Tensor]],
                       grad_scale: Optional[Tensor],
                       found_inf: Optional[Tensor],
                       *,
                       weight_decay: float,
                       momentum: float,
                       lr: float,
                       dampening: float,
                       nesterov: bool,
                       maximize: bool,
                       has_sparse_grad: bool):
    assert grad_scale is None and found_inf is None

    for i, param in enumerate(params):
        d_p = d_p_list[i] if not maximize else -d_p_list[i]

        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)


def _multi_tensor_sgd(params: List[Tensor],
                      grads: List[Tensor],
                      momentum_buffer_list: List[Optional[Tensor]],
                      grad_scale: Optional[Tensor],
                      found_inf: Optional[Tensor],
                      *,
                      weight_decay: float,
                      momentum: float,
                      lr: float,
                      dampening: float,
                      nesterov: bool,
                      maximize: bool,
                      has_sparse_grad: bool):
    assert grad_scale is None and found_inf is None

    if len(params) == 0:
        return

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
    for ((device_params, device_grads, device_momentum_buffer_list), indices) in grouped_tensors.values():
        device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)

        if momentum != 0:
            bufs = []

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    bufs.append(device_momentum_buffer_list[i])

            if all_states_with_momentum_buffer:
                torch._foreach_mul_(bufs, momentum)
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                for i in range(len(device_momentum_buffer_list)):
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \
                            torch.clone(device_grads[i]).detach()
                    else:
                        buf = device_momentum_buffer_list[i]
                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

                    bufs.append(buf)

            if nesterov:
                torch._foreach_add_(device_grads, bufs, alpha=momentum)
            else:
                device_grads = bufs

        if not device_has_sparse_grad:
            torch._foreach_add_(device_params, device_grads, alpha=-lr)
        else:
            # foreach APIs don't support sparse
            for i in range(len(device_params)):
                device_params[i].add_(device_grads[i], alpha=-lr)
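
# Note (illustrative, not upstream): the multi-tensor path above computes the
# same update as `_single_tensor_sgd`; only the batching differs. For the
# dense, plain-momentum case the final
#
#     torch._foreach_add_(device_params, device_grads, alpha=-lr)
#
# stands in for the Python-level loop
#
#     for p, g in zip(device_params, device_grads):
#         p.add_(g, alpha=-lr)
#
# so per-tensor dispatch overhead is paid once per device/dtype group rather
# than once per parameter.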


def _fused_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
) -> None:
    if not params:
        return
    if has_sparse_grad:
        raise RuntimeError("`_fused_sgd` does not support sparse gradients")
    grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None

    no_momentum_buffer = momentum == 0
    is_first_step = all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
    if is_first_step:
        for i, g in enumerate(grads):
            momentum_buffer_list[i] = torch.empty_like(g)
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=False)
    for (device, dtype), ((device_params, device_grads, device_momentum_buffer_list), _) in grouped_tensors.items():
        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device)
            device_grad_scale = grad_scale_dict[device]
        if found_inf is not None:
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device)
            device_found_inf = found_inf_dict[device]
        torch._fused_sgd_(
            device_params,
            device_grads,
            [] if no_momentum_buffer else device_momentum_buffer_list,
            weight_decay=weight_decay,
            momentum=momentum,
            lr=lr,
            dampening=dampening,
            nesterov=nesterov,
            maximize=maximize,
            is_first_step=is_first_step,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
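

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the upstream file): checks that the
    # single-tensor and multi-tensor paths above produce the same update.
    # Assumes the module is executed in-package (e.g. `python -m torch.optim.sgd`)
    # so the relative import at the top resolves.
    torch.manual_seed(0)
    params_a = [torch.randn(4, 3) for _ in range(2)]
    params_b = [p.clone() for p in params_a]
    grads = [torch.randn_like(p) for p in params_a]
    bufs_a: List[Optional[Tensor]] = [None, None]
    bufs_b: List[Optional[Tensor]] = [None, None]

    kwargs = dict(weight_decay=0.01, momentum=0.9, lr=0.1,
                  dampening=0.0, nesterov=True, maximize=False,
                  has_sparse_grad=False)
    _single_tensor_sgd(params_a, [g.clone() for g in grads], bufs_a,
                       None, None, **kwargs)
    _multi_tensor_sgd(params_b, [g.clone() for g in grads], bufs_b,
                      None, None, **kwargs)

    for pa, pb in zip(params_a, params_b):
        assert torch.allclose(pa, pb), "single- and multi-tensor SGD paths disagree"
    print("single-tensor and multi-tensor SGD updates match")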