import torch
from torch.optim import Optimizer


class SM3(Optimizer):
    """Implements the SM3 algorithm.

    It has been proposed in `Memory-Efficient Adaptive Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): coefficient that scales delta before it is applied
            to the parameters (default: 0.1)
        momentum (float, optional): coefficient used to scale prior updates
            before adding. This drastically increases memory usage if
            `momentum > 0.0`. This is ignored if the parameter's gradient
            is sparse. (default: 0.0)
        beta (float, optional): coefficient used for exponential moving
            averages (default: 0.0)
        eps (float, optional): term added to the square root in the denominator
            to improve numerical stability (default: 1e-30)

    .. _Memory-Efficient Adaptive Optimization:
        https://arxiv.org/abs/1901.11150
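
    Example:
        >>> # minimal usage sketch; `model`, `loss_fn`, `data` and `target`
        >>> # are placeholders, not part of this module
        >>> optimizer = SM3(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(data), target).backward()
        >>> optimizer.step()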
    """

    def __init__(self, params, lr=0.1, momentum=0.0, beta=0.0, eps=1e-30):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {0}".format(lr))
        if not 0.0 <= momentum < 1.0:
            raise ValueError("Invalid momentum: {0}".format(momentum))
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta: {0}".format(beta))
        if not 0.0 <= eps:
            raise ValueError("Invalid eps: {0}".format(eps))

        defaults = {"lr": lr, "momentum": momentum, "beta": beta, "eps": eps}
        super(SM3, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group["momentum"]
            beta = group["beta"]
            eps = group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]
                rank = len(grad.shape)

                if len(state) == 0:
                    state["momentum_buffer"] = 0.0
                    _add_initial_accumulators(state, grad)

                if grad.is_sparse:
                    # The update is non-linear, so indices must be unique.
                    grad = grad.coalesce()
                    grad_indices = grad._indices()
                    grad_values = grad._values()

                    # Transform update values back into a sparse tensor.
                    def make_sparse(values):
                        constructor = grad.new
                        if grad_indices.dim() == 0 or values.dim() == 0:
                            return constructor().resize_as_(grad)
                        return constructor(grad_indices, values, grad.size())

                    acc = state[_key(0)]
                    update_values = _compute_sparse_update(
                        beta, acc, grad_values, grad_indices
                    )

                    self._update_sparse_accumulator(
                        beta, acc, make_sparse(update_values)
                    )

                    # Add a small amount for numerical stability.
                    update_values.add_(eps).rsqrt_().mul_(grad_values)

                    update = make_sparse(update_values)
                else:
                    # Get the previous accumulators mu_{t-1}.
                    if rank > 1:
                        acc_list = [state[_key(i)] for i in range(rank)]
                    else:
                        acc_list = [state[_key(0)]]

                    # Get the update from the accumulators and the gradient.
                    update = _compute_update(beta, acc_list, grad)

                    # Update the accumulators.
                    self._update_accumulator(beta, acc_list, update)

                    # Add a small amount for numerical stability.
                    update.add_(eps).rsqrt_().mul_(grad)

                if momentum > 0.0:
                    m = state["momentum_buffer"]
                    update.mul_(1.0 - momentum).add_(m, alpha=momentum)
                    state["momentum_buffer"] = update.detach()

                p.sub_(update, alpha=group["lr"])

        return loss

    @staticmethod
    def _update_accumulator(beta, acc_list, update):
        for i, acc in enumerate(acc_list):
            nu_max = _max_reduce_except_dim(update, i)
            if beta > 0.0:
                torch.max(acc, nu_max, out=acc)
            else:
                # No comparison needed: nu_max is at least as large as acc
                # because the update already includes grad ** 2.
                acc.copy_(nu_max)

    @staticmethod
    def _update_sparse_accumulator(beta, acc, update):
        nu_max = _max_reduce_except_dim(update.to_dense(), 0).squeeze()
        if beta > 0.0:
            torch.max(acc, nu_max, out=acc)
        else:
            acc.copy_(nu_max)


def _compute_sparse_update(beta, acc, grad_values, grad_indices):
    # In the sparse case, a single accumulator is used.
    update_values = torch.gather(acc, 0, grad_indices[0])
    if beta > 0.0:
        update_values.mul_(beta)
    update_values.addcmul_(grad_values, grad_values, value=1.0 - beta)
    return update_values


def _compute_update(beta, acc_list, grad):
    rank = len(acc_list)
    update = acc_list[0].clone()
    for i in range(1, rank):
        # Rely on broadcasting to get the proper end shape.
        update = torch.min(update, acc_list[i])
    if beta > 0.0:
        update.mul_(beta)
    update.addcmul_(grad, grad, value=1.0 - beta)
    return update


def _key(i):
    # Returns the state key used for the i-th accumulator.
    return "accumulator_" + str(i)


def _add_initial_accumulators(state, grad):
    # For a dense gradient of rank r, create r broadcastable accumulators,
    # one per dimension; sparse and scalar gradients get a single accumulator.
    shape = grad.shape
    rank = len(shape)
    defaults = {"device": grad.device, "dtype": grad.dtype}
    acc = {}
    if grad.is_sparse:
        acc[_key(0)] = torch.zeros(shape[0], **defaults)
    elif rank == 0:
        # The scalar case is handled separately.
        acc[_key(0)] = torch.zeros(shape, **defaults)
    else:
        for i in range(rank):
            acc_shape = [1] * i + [shape[i]] + [1] * (rank - 1 - i)
            acc[_key(i)] = torch.zeros(acc_shape, **defaults)
    state.update(acc)
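
# For example, a dense gradient of shape (n, m) gets two accumulators of
# shapes (n, 1) and (1, m); SM3 reconstructs a full (n, m) second-moment
# estimate by broadcasting their element-wise minimum, using O(n + m)
# memory instead of O(n * m).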


def _max_reduce_except_dim(tensor, dim):
    # Max-reduce over every dimension except `dim`, keeping dims for broadcasting.
    rank = len(tensor.shape)
    result = tensor
    if rank > 0:
        assert dim < rank
        for d in range(rank):
            if d != dim:
                result = result.max(dim=d, keepdim=True).values
    return result


import math
from typing import Collection, TYPE_CHECKING, Any, Callable, Optional, Tuple

import torch
import torch.optim

if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any


class madgrad_wd(torch.optim.Optimizer):
    """
    MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    .. _MADGRAD: https://arxiv.org/abs/2101.11075

    MADGRAD is a general purpose optimizer that can be used in place of SGD or
    Adam; it may converge faster and generalize better. Currently GPU-only.
    Typically, the same learning rate schedule that is used for SGD or Adam may
    be used. The overall learning rate is not comparable to either method and
    should be determined by a hyper-parameter sweep.

    MADGRAD requires less weight decay than other methods, often as little as
    zero. Momentum values used for SGD or Adam's beta1 should work here also.
    This `madgrad_wd` variant applies weight decay in a decoupled (AdamW) style.

    On sparse problems both weight_decay and momentum should be set to 0.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2).
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Weight decay, i.e. an L2 penalty (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to
            improve numerical stability (default: 1e-6).
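
    Example:
        >>> # minimal usage sketch; `model`, `loss_fn`, `x` and `y` are
        >>> # placeholders, not part of this module
        >>> opt = madgrad_wd(model.parameters(), lr=1e-2, weight_decay=0)
        >>> opt.zero_grad()
        >>> loss_fn(model(x), y).backward()
        >>> opt.step()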
    """

    def __init__(self, params: _params_t, lr: float = 1e-2, momentum: float = 0.9,
                 weight_decay: float = 0, eps: float = 1e-6):
        if momentum < 0 or momentum >= 1:
            raise ValueError(f"Momentum {momentum} must be in the range [0,1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError(f"Eps {eps} must be non-negative")

        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None and callable(closure):
            loss = closure()

        # The step counter k is shared across all parameter groups.
        if "k" not in self.state:
            self.state["k"] = torch.tensor([0], dtype=torch.long)
        k = self.state["k"].item()

        for group in self.param_groups:
            eps = group["eps"]
            lr = group["lr"] + eps
            decay = group["weight_decay"]
            momentum = group["momentum"]

            ck = 1 - momentum
            lamb = lr * math.pow(k + 1, 0.5)
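
            # For reference, the loop below implements the MADGRAD recursion
            # (Defazio & Jelassi, 2021), with g_k the gradient at step k:
            #   lamb_k  = lr * sqrt(k + 1)
            #   nu_k    = nu_{k-1} + lamb_k * g_k^2
            #   s_k     = s_{k-1} + lamb_k * g_k
            #   z_k     = x_0 - s_k / (nu_k^(1/3) + eps)
            #   x_{k+1} = (1 - c) * x_k + c * z_k,   c = 1 - momentum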
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                if "grad_sum_sq" not in state:
                    state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
                    state["s"] = torch.zeros_like(p.data).detach()
                    if momentum != 0:
                        state["x0"] = torch.clone(p.data).detach()

                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError(
                        "momentum != 0 is not compatible with sparse gradients"
                    )

                grad_sum_sq = state["grad_sum_sq"]
                s = state["s"]

                # Apply weight decay - decoupled / AdamW style.
                if decay:
                    p.data.mul_(1 - lr * decay)

                # Original coupled (L2) implementation, kept for reference:
                # if decay != 0:
                #     if grad.is_sparse:
                #         raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                #     grad.add_(p.data, alpha=decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from the other known quantities.
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(
                        s_masked._values(), rms_masked_vals, value=1
                    )

                    # Dense + sparse op.
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # Update the masked copy of p.
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(
                        s_masked._values(), rms_masked_vals, value=-1
                    )
                    # Copy the updated masked p back to the dense p via add.
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.data.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from the other known quantities.
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.data.addcdiv(s, rms, value=1)
                    else:
                        x0 = state["x0"]

                    # Accumulate second moments.
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s.
                    s.data.add_(grad, alpha=lamb)

                    # Step.
                    if momentum == 0:
                        p.data.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z.
                        p.data.mul_(1 - ck).add_(z, alpha=ck)

        self.state["k"] += 1
        return loss


class Lion(Optimizer):
    """
    Implements the Lion algorithm.

    .. _Lion: https://arxiv.org/abs/2302.06675

    Compared to AdamW and various adaptive optimizers that need to save both
    first and second moments, Lion only needs the momentum, halving the
    additional memory footprint. This is beneficial when training large models
    and / or with a large batch size.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2).
        betas (Tuple[float, float]):
            Coefficients for interpolating the update direction and for the
            momentum update; Lion keeps no running average of the squared
            gradient (default: (0.9, 0.99)).
        weight_decay (float):
            Weight decay, applied in a decoupled (AdamW-style) fashion
            (default: 0).
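
    Example:
        >>> # minimal usage sketch; `model`, `loss_fn`, `x` and `y` are
        >>> # placeholders, not part of this module; the paper suggests
        >>> # a smaller lr than AdamW (roughly 3-10x lower)
        >>> opt = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
        >>> opt.zero_grad()
        >>> loss_fn(model(x), y).backward()
        >>> opt.step()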
    """

    def __init__(self, params: _params_t, lr: float = 1e-2,
                 betas: Tuple[float, float] = (0.9, 0.99),
                 weight_decay: float = 0.0):
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if not (0 <= betas[0] <= 1 and 0 <= betas[1] <= 1):
            raise ValueError(f"Betas {betas} must be in the range [0, 1]")

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def update(self, p, grad, exp_avg, lr, wd, beta1, beta2):
        """Per-parameter update rule; see https://arxiv.org/pdf/2302.06675.pdf#appendix.A"""
        # Apply decoupled weight decay.
        p.data.mul_(1 - lr * wd)

        # Interpolate between momentum and gradient, then take the sign.
        sign = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_()
        p.add_(sign, alpha=-lr)

        # Decay the momentum running average toward the new gradient.
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
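
    # For reference, the update above as given in the Lion paper
    # (Chen et al., 2023), with g_t the gradient and m the momentum buffer:
    #   c_t     = beta1 * m_{t-1} + (1 - beta1) * g_t          (direction)
    #   theta_t = theta_{t-1} - lr * (sign(c_t) + wd * theta_{t-1})
    #   m_t     = beta2 * m_{t-1} + (1 - beta2) * g_t          (momentum)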

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                # Initialize the momentum buffer on the first step.
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p.data).detach()

                beta1, beta2 = group["betas"]
                self.update(
                    p,
                    p.grad,
                    state["exp_avg"],
                    group["lr"],
                    group["weight_decay"],
                    beta1,
                    beta2,
                )

        return loss