2
from torch import Tensor
3
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
4
_stack_if_compiling, _get_scalar_dtype, _default_to_fused_or_foreach,
5
_view_as_real, _capturable_doc, _differentiable_doc, _foreach_doc,)
6
from typing import List, Optional
8
# Public names re-exported by `from ... import *`: the optimizer class and
# its functional counterpart.
__all__ = ["NAdam", "nadam"]
10
# NOTE(review): this listing is a lossy, non-contiguous dump of the module.
# Indentation is stripped, the bare integers are the original line numbers,
# and some original lines are missing. Examples: the guard conditions of the
# lr and eps checks, the sparse-gradient check and state binding in
# _init_group, and the end of the `step` docstring. The code is left
# untouched; the comments describe only what is visible here.
#
# NAdam optimizer (see "Incorporating Nesterov Momentum into Adam" in
# NAdam.__doc__). __init__ validates hyperparameters and registers them as
# per-group defaults. Per-parameter state holds 'step', 'mu_product',
# 'exp_avg' and 'exp_avg_sq'. step() gathers those tensors per param group
# via _init_group and hands the arithmetic to the functional nadam(...).
class NAdam(Optimizer):
11
# Validate hyperparameters and pass them to Optimizer as per-group defaults.
# Visible checks: betas in [0, 1), weight_decay >= 0, momentum_decay >= 0.
# The lr/eps checks are also present, but their conditions are not visible.
def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
12
weight_decay=0, momentum_decay=4e-3, decoupled_weight_decay: bool = False,
13
*, foreach: Optional[bool] = None, capturable: bool = False,
14
differentiable: bool = False):
16
# NOTE(review): the `if` guarding this raise is not visible in this view.
raise ValueError(f"Invalid learning rate: {lr}")
18
# NOTE(review): the `if` guarding this raise is not visible in this view.
raise ValueError(f"Invalid epsilon value: {eps}")
19
if not 0.0 <= betas[0] < 1.0:
20
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
21
if not 0.0 <= betas[1] < 1.0:
22
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
23
if not 0.0 <= weight_decay:
24
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
25
if not 0.0 <= momentum_decay:
26
raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
27
defaults = dict(lr=lr, betas=betas, eps=eps,
28
weight_decay=weight_decay, momentum_decay=momentum_decay,
29
decoupled_weight_decay=decoupled_weight_decay,
30
foreach=foreach, capturable=capturable, differentiable=differentiable)
31
super().__init__(params, defaults)
33
# Restore pickled state and upgrade it in place.
# - Group options missing from the state get defaults: foreach=None, and
#   capturable / differentiable / decoupled_weight_decay = False.
# - Non-tensor 'step' / 'mu_product' entries become scalar tensors of
#   _get_scalar_dtype(), placed on p.device only when the group is
#   capturable.
def __setstate__(self, state):
34
super().__setstate__(state)
35
for group in self.param_groups:
36
group.setdefault('foreach', None)
37
group.setdefault('capturable', False)
38
group.setdefault('differentiable', False)
39
group.setdefault('decoupled_weight_decay', False)
40
for p in group["params"]:
41
p_state = self.state.get(p, [])
43
# NOTE(review): `self.state.get(p, [])` falls back to a list, which cannot
# be indexed with 'step'. A guard for empty state presumably sits on the
# missing original line 42; confirm against the full file.
if not torch.is_tensor(p_state['step']):
44
step_val = float(p_state["step"])
45
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
46
if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype()))
47
if not torch.is_tensor(p_state['mu_product']):
48
mu_prod_val = p_state["mu_product"]
49
p_state["mu_product"] = (torch.tensor(mu_prod_val, dtype=_get_scalar_dtype(), device=p.device)
50
if group['capturable'] else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype()))
53
# For every param in `group` that has a gradient, append the param and its
# state tensors (exp_avg, exp_avg_sq, mu_product, step) to the
# caller-provided lists.
# - State initialisation: mu_product starts at 1 and the moment buffers at
#   zeros_like(p). The zero scalar built just before mu_product is presumably
#   the initial 'step'; its assignment line and the initialisation condition
#   are not visible.
# - A RuntimeError is raised for sparse gradients (guard line not visible).
# - `has_complex` is OR-accumulated over those params. Its initialisation,
#   the appends to `grads` and the return statement are not visible here,
#   but step() assigns the return value to has_complex.
def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps):
55
for p in group['params']:
56
if p.grad is not None:
57
has_complex |= torch.is_complex(p)
58
params_with_grad.append(p)
60
raise RuntimeError('NAdam does not support sparse gradients')
70
torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
71
if group['capturable'] else torch.tensor(0.0, dtype=_get_scalar_dtype())
73
state['mu_product'] = (
74
torch.ones((), dtype=_get_scalar_dtype(), device=p.device)
75
if group['capturable'] else torch.tensor(1.0, dtype=_get_scalar_dtype())
78
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
80
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
82
exp_avgs.append(state['exp_avg'])
83
exp_avg_sqs.append(state['exp_avg_sq'])
84
mu_products.append(state['mu_product'])
85
state_steps.append(state['step'])
88
# Perform one NAdam update:
# 1. Call self._cuda_graph_capture_health_check().
# 2. Evaluate the closure (if given) under torch.enable_grad().
# 3. For each param group, collect tensors via _init_group.
# 4. Call the functional nadam(...) with the group's hyperparameters and the
#    has_complex flag.
# Several lines of this method (the remaining list initialisations and some
# nadam arguments) are not visible in this view.
@_use_grad_for_differentiable
89
def step(self, closure=None):
90
"""Performs a single optimization step.
93
closure (Callable, optional): A closure that reevaluates the model
96
self._cuda_graph_capture_health_check()
99
if closure is not None:
100
with torch.enable_grad():
103
for group in self.param_groups:
104
params_with_grad = []
110
beta1, beta2 = group['betas']
112
has_complex = self._init_group(group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps)
114
nadam(params_with_grad,
123
weight_decay=group['weight_decay'],
124
momentum_decay=group['momentum_decay'],
126
decoupled_weight_decay=group['decoupled_weight_decay'],
127
foreach=group['foreach'],
128
capturable=group['capturable'],
129
differentiable=group['differentiable'],
130
has_complex=has_complex)
134
NAdam.__doc__ = r"""Implements NAdam algorithm.
138
&\rule{110mm}{0.4pt} \\
139
&\textbf{input} : \gamma \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)},
140
\: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
141
&\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\
142
&\hspace{13mm} \: \textit{decoupled\_weight\_decay} \\
143
&\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)},
144
v_0 \leftarrow 0 \text{ (second moment)} \\[-1.ex]
145
&\rule{110mm}{0.4pt} \\
146
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
147
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
148
&\hspace{5mm} \theta_t \leftarrow \theta_{t-1} \\
149
&\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
150
&\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\
151
&\hspace{15mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
152
&\hspace{10mm}\textbf{else} \\
153
&\hspace{15mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
154
&\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\
155
&\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\
156
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
157
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
158
&\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex]
159
& \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\
160
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
161
&\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
162
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
163
&\rule{110mm}{0.4pt} \\[-1.ex]
164
&\bf{return} \: \theta_t \\[-1.ex]
165
&\rule{110mm}{0.4pt} \\[-1.ex]
168
For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_.
171
params (iterable): iterable of parameters to optimize or dicts defining
173
lr (float, optional): learning rate (default: 2e-3)
174
betas (Tuple[float, float], optional): coefficients used for computing
175
running averages of gradient and its square (default: (0.9, 0.999))
176
eps (float, optional): term added to the denominator to improve
177
numerical stability (default: 1e-8)
178
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
179
momentum_decay (float, optional): momentum decay rate, psi in the algorithm above (default: 4e-3)
180
decoupled_weight_decay (bool, optional): whether to use decoupled weight
181
decay as in AdamW (`Decoupled Weight Decay Regularization`_) to obtain NAdamW (default: False)
184
{_differentiable_doc}
186
.. _Incorporating Nesterov Momentum into Adam:
187
https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
188
.. _Decoupled Weight Decay Regularization:
189
https://arxiv.org/abs/1711.05101
194
# Functional form of NAdam. Visible behaviour:
# - state_steps and mu_products must contain tensors; otherwise a
#   RuntimeError is raised.
# - `foreach` is resolved through _default_to_fused_or_foreach with
#   use_fused=False.
# - foreach is refused under torch.jit.script.
# - The work is dispatched to _multi_tensor_nadam (foreach) or
#   _single_tensor_nadam.
# NOTE(review): this view omits parts of the signature and the opening of the
# dispatch call; nothing is reconstructed here. Missing signature parts
# presumably include the `grads` parameter and a keyword-only marker before
# momentum_decay, which has no default yet follows defaulted parameters.
def nadam(params: List[Tensor],
196
exp_avgs: List[Tensor],
197
exp_avg_sqs: List[Tensor],
198
mu_products: List[Tensor],
199
state_steps: List[Tensor],
202
decoupled_weight_decay: bool = False,
203
foreach: Optional[bool] = None,
204
capturable: bool = False,
205
differentiable: bool = False,
206
has_complex: bool = False,
212
momentum_decay: float,
214
r"""Functional API that performs NAdam algorithm computation.
216
See :class:`~torch.optim.NAdam` for details.
220
# Reject the older calling convention in which these entries were not tensors.
if not all(isinstance(t, torch.Tensor) for t in state_steps):
221
raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
223
if not all(isinstance(t, torch.Tensor) for t in mu_products):
224
raise RuntimeError("API has changed, `mu_products` argument must contain a list of singleton tensors")
227
# NOTE(review): presumably only reached when foreach is None; the guarding
# line is not visible in this view. The fused path is never requested.
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
229
if foreach and torch.jit.is_scripting():
230
raise RuntimeError('torch.jit.script not supported with foreach optimizers')
232
# Pick the foreach implementation when enabled and not scripting. The
# `else:` preceding the single-tensor assignment is not visible here.
if foreach and not torch.jit.is_scripting():
233
func = _multi_tensor_nadam
235
func = _single_tensor_nadam
# NOTE(review): the opening of the func(...) call and its leading arguments
# are not visible; only the trailing keyword arguments appear below.
246
weight_decay=weight_decay,
247
momentum_decay=momentum_decay,
248
decoupled_weight_decay=decoupled_weight_decay,
250
capturable=capturable,
251
differentiable=differentiable,
252
has_complex=has_complex)
255
# Per-tensor NAdam update (loops over params one at a time). For each param:
# - Complex tensors are viewed as real.
# - Weight decay is applied either decoupled (param scaled by
#   1 - lr * weight_decay) or as an L2 term added to grad.
# - mu_t and mu_{t+1} come from the momentum-decay schedule.
# - The moment buffers are updated in place.
# - The param receives two addcdiv_ updates (grad term and exp_avg term)
#   over denom = sqrt(v_t / (1 - beta2 ** step)) + eps.
# NOTE(review): this view omits several lines; confirm against the full file.
# - Parts of the signature (presumably `grads`, beta1, beta2, lr,
#   weight_decay, eps, capturable, has_complex).
# - The binding of `grad`.
# - The step increment, and presumably the in-place mu_product *= mu.
# - The `assert (` opener.
# - The if/else lines separating the tensor-arithmetic step sizes from the
#   _get_value scalar step sizes.
def _single_tensor_nadam(params: List[Tensor],
257
exp_avgs: List[Tensor],
258
exp_avg_sqs: List[Tensor],
259
mu_products: List[Tensor],
260
state_steps: List[Tensor],
266
momentum_decay: float,
268
decoupled_weight_decay: bool,
270
differentiable: bool,
273
for i, param in enumerate(params):
275
exp_avg = exp_avgs[i]
276
exp_avg_sq = exp_avg_sqs[i]
277
mu_product = mu_products[i]
278
step_t = state_steps[i]
280
# Complex tensors are updated through their real views.
if torch.is_complex(param):
281
param = torch.view_as_real(param)
282
grad = torch.view_as_real(grad)
283
exp_avg = torch.view_as_real(exp_avg)
284
exp_avg_sq = torch.view_as_real(exp_avg_sq)
287
# Outside compilation, capturable requires param, mu_product and step_t to
# be all-CUDA or all-XLA.
if not torch._utils.is_compiling() and capturable:
289
(param.is_cuda and mu_product.is_cuda and step_t.is_cuda) or (param.is_xla and mu_product.is_xla and step_t.is_xla)
290
), "If capturable=True, params, mu_products, and state_steps must be CUDA or XLA tensors."
298
step = _get_value(step_t)
300
bias_correction2 = 1 - beta2 ** step
302
# Weight decay: decoupled scales the param itself; otherwise
# weight_decay * param is added to grad (out of place).
if weight_decay != 0:
303
if decoupled_weight_decay:
305
param.mul_(1 - lr * weight_decay)
307
grad = grad.add(param, alpha=weight_decay)
310
# mu for the current step and for step + 1:
#   beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
mu = beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))
311
mu_next = beta1 * (1. - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))
317
# m_t via lerp_ (exp_avg += (1 - beta1) * (grad - exp_avg)).
# v_t via mul_/addcmul_.
# denom = sqrt(v_t / bias_correction2).
exp_avg.lerp_(grad, 1 - beta1)
318
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
319
denom = exp_avg_sq.div(bias_correction2).sqrt()
321
# eps is added out of place here; the other branch of this condition is
# not visible in this view.
if differentiable or capturable:
322
denom = denom.add(eps)
326
# Step sizes as tensor arithmetic on mu_product (presumably the capturable
# branch). exp_avg is rebound to a scaled copy, so the stored buffer is not
# modified by these two lines.
mu_product_next = mu_product * mu_next
327
grad = grad * (-lr * (1. - mu) / (1. - mu_product))
328
exp_avg = exp_avg * (-lr * mu_next / (1. - mu_product_next))
329
param.addcdiv_(grad, denom)
330
param.addcdiv_(exp_avg, denom)
332
# Same update with Python-scalar step sizes obtained via _get_value and
# passed as `value=` (presumably the non-capturable branch).
mu_product_next = _get_value(mu_product) * mu_next
334
param.addcdiv_(grad, denom, value=(-lr * (1. - mu) / (1. - _get_value(mu_product))))
335
param.addcdiv_(exp_avg, denom, value=(-lr * mu_next) / (1. - mu_product_next))
338
# Foreach (multi-tensor) NAdam update. Tensors are bucketed by
# Optimizer._group_tensors_by_device_and_dtype, then each bucket is updated
# with torch._foreach_* ops, in this order:
# 1. step += 1
# 2. weight decay (decoupled or L2)
# 3. moment updates
# 4. mu, mu_next and the sqrt bias correction (as tensors in one branch, as
#    Python scalars in the other)
# 5. mu_products *= mus
# 6. the params, via addcdiv
# Autograd is not supported (asserted below).
# NOTE(review): unlike _single_tensor_nadam, the capturable device assert
# here accepts only CUDA tensors (no XLA alternative). Confirm whether that
# difference is intended.
# NOTE(review): this view omits lines, including parts of the signature and
# the `if`/`else` lines delimiting the branches, so the branch boundaries
# noted below are inferred rather than shown.
def _multi_tensor_nadam(params: List[Tensor],
340
exp_avgs: List[Tensor],
341
exp_avg_sqs: List[Tensor],
342
mu_products: List[Tensor],
343
state_steps: List[Tensor],
349
momentum_decay: float,
351
decoupled_weight_decay: bool,
353
differentiable: bool,
359
assert not differentiable, "_foreach ops don't support autograd"
362
if not torch._utils.is_compiling() and capturable:
363
assert all(p.is_cuda and mp.is_cuda and step.is_cuda
364
for p, mp, step in zip(params, mu_products, state_steps)), \
365
"If capturable=True, params, mu_products, and state_steps must be CUDA tensors."
368
# Each bucket below is updated independently (presumably one bucket per
# device/dtype combination).
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps])
369
for ((grouped_params, grouped_grads, grouped_exp_avgs,
370
grouped_exp_avg_sqs, grouped_mu_products, grouped_state_steps), _) in grouped_tensors.values():
374
# View complex tensors as real; the `if has_complex:` guard is presumably
# on a line not visible here.
_view_as_real(grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs)
380
# Increment steps. CPU steps add a pre-built CPU scalar tensor with
# alpha=1.0; otherwise the Python int 1 is added.
if grouped_state_steps[0].is_cpu:
381
torch._foreach_add_(grouped_state_steps, torch.tensor(1.0, device='cpu'), alpha=1.0)
383
torch._foreach_add_(grouped_state_steps, 1)
385
# Weight decay: decoupled scales the params in place; otherwise
# weight_decay * param is added to fresh grad tensors.
if weight_decay != 0:
386
if decoupled_weight_decay:
388
torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
390
grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)
393
# m_t = lerp(m_{t-1}, g_t, 1 - beta1); v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)
395
torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
396
torch._foreach_addcmul_(grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2)
398
# sqrt(v_t); the bias correction and eps are applied further below.
exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)
402
# Tensor branch:
#   mus      = beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
#   mu_nexts = same with step + 1 (exponent bumped by momentum_decay)
exponent = torch._foreach_mul(grouped_state_steps, momentum_decay)
403
mus = torch._foreach_pow(0.96, exponent)
404
torch._foreach_mul_(mus, -0.5)
405
torch._foreach_add_(mus, 1.0)
406
torch._foreach_mul_(mus, beta1)
409
torch._foreach_add_(exponent, momentum_decay)
410
mu_nexts = torch._foreach_pow(0.96, exponent)
411
torch._foreach_mul_(mu_nexts, -0.5)
412
torch._foreach_add_(mu_nexts, 1.0)
413
torch._foreach_mul_(mu_nexts, beta1)
418
# sqrt(1 - beta2 ** step), built in place as: beta2 ** step - 1, negated,
# then sqrt.
bias_correction_sqrt = torch._foreach_pow(beta2, grouped_state_steps)
420
torch._foreach_sub_(bias_correction_sqrt, 1.0)
421
torch._foreach_neg_(bias_correction_sqrt)
422
torch._foreach_sqrt_(bias_correction_sqrt)
424
# Scalar branch: the same three quantities as Python floats.
bias_correction_sqrt = [_dispatch_sqrt(1 - beta2 ** _get_value(step)) for step in grouped_state_steps]
425
mus = [beta1 * (1. - 0.5 * (0.96 ** (_get_value(step) * momentum_decay))) for step in grouped_state_steps]
426
mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
427
for step in grouped_state_steps]
430
# Accumulate the running product of mus into the mu_product state.
torch._foreach_mul_(grouped_mu_products, mus)
432
# denominator = sqrt(v_t) / sqrt(1 - beta2 ** step) + eps
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
433
torch._foreach_add_(exp_avg_sq_sqrt, eps)
436
del bias_correction_sqrt
440
# Tensor branch, reusing the mus buffers:
#   step_size_grads = lr * (mu - 1) / (1 - mu_product)
#                   = -lr * (1 - mu) / (1 - mu_product)
torch._foreach_sub_(mus, 1.0)
441
torch._foreach_mul_(mus, lr)
443
denom = torch._foreach_sub(grouped_mu_products, 1.0)
444
torch._foreach_neg_(denom)
445
torch._foreach_div_(mus, denom)
447
step_size_grads = mus
452
# Reusing the mu_nexts buffers:
#   step_size_expavg = lr * mu_next / (mu_product * mu_next - 1)
#                    = -lr * mu_next / (1 - mu_product * mu_next)
denom = torch._foreach_mul(grouped_mu_products, mu_nexts)
453
torch._foreach_mul_(mu_nexts, lr)
456
torch._foreach_sub_(denom, 1.0)
457
torch._foreach_div_(mu_nexts, denom)
459
step_size_expavg = mu_nexts
465
# params += (step_size_grads * grad + step_size_expavg * exp_avg) / denominator
numerator = torch._foreach_mul(step_size_grads, grouped_grads)
466
torch._foreach_addcmul_(numerator, step_size_expavg, grouped_exp_avgs)
469
torch._foreach_addcdiv_(grouped_params, numerator, exp_avg_sq_sqrt)
471
# Scalar branch: the same two step sizes computed from _get_value(...) and
# wrapped by _stack_if_compiling, then applied with two addcdiv_ calls.
step_size_grads = _stack_if_compiling([(lr * (1. - mu) / (1. - _get_value(mu_product))) * -1
472
for mu_product, mu in zip(grouped_mu_products, mus)])
473
step_size_expavg = _stack_if_compiling([(lr * mu_next / (1. - _get_value(mu_product) * mu_next)) * -1
474
for mu_product, mu_next in zip(grouped_mu_products, mu_nexts)])
476
torch._foreach_addcdiv_(grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads)
477
torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg)