import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
                        _differentiable_doc, _foreach_doc, _maximize_doc, _fused_doc)
from typing import List, Optional

__all__ = ['SGD', 'sgd']

class SGD(Optimizer):
    def __init__(self, params, lr=1e-3, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None,
                 differentiable: bool = False, fused: Optional[bool] = None):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize, foreach=foreach,
                        differentiable=differentiable, fused=fused)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

        if fused:
            # The fused step can consume the GradScaler's scale/found_inf tensors directly.
            self._step_supports_amp_scaling = True
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
    def __setstate__(self, state):
        super().__setstate__(state)
        # Fill in defaults for options that may be absent from older checkpoints.
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('differentiable', False)
            group.setdefault('fused', False)
    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
        has_sparse_grad = False

        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                state = self.state[p]
                momentum_buffer_list.append(state.get('momentum_buffer'))

        return has_sparse_grad
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []

            has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)

            sgd(params_with_grad,
                d_p_list,
                momentum_buffer_list,
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                lr=group['lr'],
                dampening=group['dampening'],
                nesterov=group['nesterov'],
                maximize=group['maximize'],
                has_sparse_grad=has_sparse_grad,
                foreach=group['foreach'],
                fused=group['fused'],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None))

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss

SGD.__doc__ = r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
            \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
            &\hspace{10mm}\textbf{if} \: t > 1 \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
            &\hspace{10mm}\textbf{else} \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
            &\hspace{10mm}\textbf{else} \\[-1.ex]
            &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
            &\hspace{5mm}\textbf{else} \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_differentiable_doc}
        {_fused_doc}

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.
    """
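

# Illustrative sketch, not part of the original module: a tiny numeric contrast of the
# two momentum parameterizations described in the docstring above. The helper name
# `_momentum_formulation_example` and its toy values are hypothetical and exist only
# for exposition; nothing in the optimizer calls it.
def _momentum_formulation_example(lr: float = 0.1, mu: float = 0.9) -> None:
    g = torch.tensor([1.0])   # gradient g_{t+1}
    v = torch.tensor([0.5])   # velocity v_t
    p = torch.tensor([2.0])   # parameter p_t

    # Form used by this module: v_{t+1} = mu * v_t + g_{t+1};  p_{t+1} = p_t - lr * v_{t+1}
    v_torch = mu * v + g
    p_torch = p - lr * v_torch

    # Form from Sutskever et al.: v_{t+1} = mu * v_t + lr * g_{t+1};  p_{t+1} = p_t - v_{t+1}
    v_sutskever = mu * v + lr * g
    p_sutskever = p - v_sutskever

    # The two parameterizations generally trace different trajectories unless the
    # learning rate is rescaled, which is why the docstring calls out the distinction.
    print(p_torch, p_sutskever)
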
def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
        # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
        has_sparse_grad: Optional[bool] = None,
        foreach: Optional[bool] = None,
        fused: Optional[bool] = None,
        grad_scale: Optional[Tensor] = None,
        found_inf: Optional[Tensor] = None,
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
    """

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if foreach is None and fused is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            fused, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
        else:
            foreach = False
            fused = False
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
    if fused and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with fused optimizers')

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    elif fused and not torch.jit.is_scripting():
        func = _fused_sgd
    else:
        func = _single_tensor_sgd

    func(params,
         d_p_list,
         momentum_buffer_list,
         weight_decay=weight_decay,
         momentum=momentum,
         lr=lr,
         dampening=dampening,
         nesterov=nesterov,
         maximize=maximize,
         has_sparse_grad=has_sparse_grad,
         grad_scale=grad_scale,
         found_inf=found_inf)
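

# Minimal usage sketch, not part of the original module: how the functional `sgd`
# entry point above could be called directly on plain tensors. The helper name
# `_functional_sgd_example` and the toy tensors are hypothetical and purely
# illustrative; nothing in the optimizer calls it.
def _functional_sgd_example() -> None:
    w = torch.zeros(3)                      # parameter, updated in place
    g = torch.ones(3)                       # its gradient
    bufs: List[Optional[Tensor]] = [None]   # momentum buffer, created on the first step

    sgd([w], [g], bufs,
        has_sparse_grad=False,
        foreach=False,
        fused=False,
        weight_decay=0.0,
        momentum=0.9,
        lr=0.1,
        dampening=0.0,
        nesterov=False,
        maximize=False)
    # w has now taken one SGD step and bufs[0] holds the new momentum buffer.
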
def _single_tensor_sgd(params: List[Tensor],
                       d_p_list: List[Tensor],
                       momentum_buffer_list: List[Optional[Tensor]],
                       grad_scale: Optional[Tensor],
                       found_inf: Optional[Tensor],
                       *,
                       weight_decay: float, momentum: float, lr: float,
                       dampening: float, nesterov: bool, maximize: bool,
                       has_sparse_grad: bool):
    assert grad_scale is None and found_inf is None

    for i, param in enumerate(params):
        d_p = d_p_list[i] if not maximize else -d_p_list[i]

        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                # first step: initialize the buffer with the current (decayed) gradient
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)
def _multi_tensor_sgd(params: List[Tensor],
                      grads: List[Tensor],
                      momentum_buffer_list: List[Optional[Tensor]],
                      grad_scale: Optional[Tensor],
                      found_inf: Optional[Tensor],
                      *,
                      weight_decay: float, momentum: float, lr: float,
                      dampening: float, nesterov: bool, maximize: bool,
                      has_sparse_grad: bool):
    assert grad_scale is None and found_inf is None

    if len(params) == 0:
        return

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
    for ((device_params, device_grads, device_momentum_buffer_list), indices) in grouped_tensors.values():
        device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)

        if momentum != 0:
            bufs = []

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    bufs.append(device_momentum_buffer_list[i])

            if all_states_with_momentum_buffer:
                torch._foreach_mul_(bufs, momentum)
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                for i in range(len(device_momentum_buffer_list)):
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \
                            torch.clone(device_grads[i]).detach()
                    else:
                        buf = device_momentum_buffer_list[i]
                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

                    bufs.append(buf)

            if nesterov:
                torch._foreach_add_(device_grads, bufs, alpha=momentum)
            else:
                device_grads = bufs

        if not device_has_sparse_grad:
            torch._foreach_add_(device_params, device_grads, alpha=-lr)
        else:
            # foreach APIs don't support sparse
            for i in range(len(device_params)):
                device_params[i].add_(device_grads[i], alpha=-lr)
def _fused_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float, momentum: float, lr: float,
    dampening: float, nesterov: bool, maximize: bool,
    has_sparse_grad: bool,
) -> None:
    if not params:
        return
    if has_sparse_grad:
        raise RuntimeError("`_fused_sgd` does not support sparse gradients")
    grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None

    no_momentum_buffer = momentum == 0
    is_first_step = all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
    if is_first_step:
        # Allocate (uninitialized) momentum buffers on the first step; the fused kernel
        # fills them, guided by is_first_step.
        for i, g in enumerate(grads):
            momentum_buffer_list[i] = torch.empty_like(g)
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=False)
    for (device, dtype), ((device_params, device_grads, device_momentum_buffer_list), _) in grouped_tensors.items():
        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device)
            device_grad_scale = grad_scale_dict[device]
        if found_inf is not None:
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device)
            device_found_inf = found_inf_dict[device]
        torch._fused_sgd_(
            device_params,
            device_grads,
            [] if no_momentum_buffer else device_momentum_buffer_list,
            weight_decay=weight_decay,
            momentum=momentum,
            lr=lr,
            dampening=dampening,
            nesterov=nesterov,
            maximize=maximize,
            is_first_step=is_first_step,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )