gpt-neox

Форк
0
/
optimizers.py 
497 строк · 17.7 Кб
1
# Copyright (c) 2024, EleutherAI
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15

16
import torch
17
from torch.optim import Optimizer
18

19

20
class SM3(Optimizer):
    """Implements the SM3 algorithm.

    It has been proposed in `Memory-Efficient Adaptive Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): coefficient that scales the update before it is
            applied to the parameters (default: 0.1)
        momentum (float, optional): coefficient used to scale prior updates
            before adding. This drastically increases memory usage if
            `momentum > 0.0`. This is ignored if the parameter's gradient
            is sparse. (default: 0.0)
        beta (float, optional): coefficient used for exponential moving
            averages (default: 0.0)
        eps (float, optional): term added to the second-moment estimate before
            taking its inverse square root, to improve numerical stability
            (default: 1e-30)

    .. _Memory-Efficient Adaptive Optimization:
        https://arxiv.org/abs/1901.11150
    """

    def __init__(self, params, lr=0.1, momentum=0.0, beta=0.0, eps=1e-30):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {0}".format(lr))
        if not 0.0 <= momentum < 1.0:
            raise ValueError("Invalid momentum: {0}".format(momentum))
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta: {0}".format(beta))
        if not 0.0 <= eps:
            raise ValueError("Invalid eps: {0}".format(eps))

        defaults = {"lr": lr, "momentum": momentum, "beta": beta, "eps": eps}
        super(SM3, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Parameters whose ``.grad`` is ``None`` are skipped.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd for backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group["momentum"]
            beta = group["beta"]
            eps = group["eps"]
            for p in group["params"]:
                # Parameters that received no gradient (frozen, or unused in this
                # forward pass) have nothing to update. Checking ``p`` itself is
                # not enough: ``grad.shape`` below would fail on ``None``.
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                shape = grad.shape
                rank = len(shape)

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["momentum_buffer"] = 0.0
                    _add_initial_accumulators(state, grad)

                if grad.is_sparse:
                    # The update is non-linear, so indices must be unique.
                    # ``coalesce`` is out-of-place: the result must be kept.
                    grad = grad.coalesce()
                    grad_indices = grad._indices()
                    grad_values = grad._values()

                    # Wrap per-index values into a sparse tensor shaped like grad.
                    # Defined per parameter but called immediately, so capturing
                    # the loop's ``grad``/``grad_indices`` is safe.
                    def make_sparse(values):
                        constructor = grad.new
                        if grad_indices.dim() == 0 or values.dim() == 0:
                            return constructor().resize_as_(grad)
                        return constructor(grad_indices, values, grad.size())

                    # The sparse case uses a single accumulator of shape (n,).
                    acc = state[_key(0)]
                    update_values = _compute_sparse_update(
                        beta, acc, grad_values, grad_indices
                    )

                    self._update_sparse_accumulator(
                        beta, acc, make_sparse(update_values)
                    )

                    # Add small amount for numerical stability, then scale grad
                    # by the inverse square root of the second-moment estimate.
                    update_values.add_(eps).rsqrt_().mul_(grad_values)

                    update = make_sparse(update_values)
                else:
                    # Get previous accumulators mu_{t-1}
                    if rank > 1:
                        acc_list = [state[_key(i)] for i in range(rank)]
                    else:
                        acc_list = [state[_key(0)]]

                    # Get update from accumulators and gradients
                    update = _compute_update(beta, acc_list, grad)

                    # Update accumulators.
                    self._update_accumulator(beta, acc_list, update)

                    # Add small amount for numerical stability, then scale grad
                    # by the inverse square root of the second-moment estimate.
                    update.add_(eps).rsqrt_().mul_(grad)

                    if momentum > 0.0:
                        m = state["momentum_buffer"]
                        update.mul_(1.0 - momentum).add_(m, alpha=momentum)
                        state["momentum_buffer"] = update.detach()

                p.sub_(update, alpha=group["lr"])
                state["step"] += 1
        return loss

    @staticmethod
    def _update_accumulator(beta, acc_list, update):
        """Store the max-reduction of ``update`` into each per-dimension accumulator."""
        for i, acc in enumerate(acc_list):
            nu_max = _max_reduce_except_dim(update, i)
            if beta > 0.0:
                torch.max(acc, nu_max, out=acc)
            else:
                # No need to compare - nu_max is bigger because of grad ** 2
                acc.copy_(nu_max)

    @staticmethod
    def _update_sparse_accumulator(beta, acc, update):
        """Store the row-wise max of the (sparse) ``update`` into ``acc``."""
        nu_max = _max_reduce_except_dim(update.to_dense(), 0).squeeze()
        if beta > 0.0:
            torch.max(acc, nu_max, out=acc)
        else:
            # No need to compare - nu_max is bigger because of grad ** 2
            acc.copy_(nu_max)
153

154

155
def _compute_sparse_update(beta, acc, grad_values, grad_indices):
156
    # In the sparse case, a single accumulator is used.
157
    update_values = torch.gather(acc, 0, grad_indices[0])
158
    if beta > 0.0:
159
        update_values.mul_(beta)
160
    update_values.addcmul_(grad_values, grad_values, value=1.0 - beta)
161
    return update_values
162

163

164
def _compute_update(beta, acc_list, grad):
165
    rank = len(acc_list)
166
    update = acc_list[0].clone()
167
    for i in range(1, rank):
168
        # We rely on broadcasting to get the proper end shape.
169
        update = torch.min(update, acc_list[i])
170
    if beta > 0.0:
171
        update.mul_(beta)
172
    update.addcmul_(grad, grad, value=1.0 - beta)
173

174
    return update
175

176

177
def _key(i):
178
    # Returns key used for accessing accumulators
179
    return "accumulator_" + str(i)
180

181

182
def _add_initial_accumulators(state, grad):
    """Create zero-filled SM3 accumulators for ``grad`` and add them to ``state``.

    For a dense tensor of shape (n1, n2, n3) the accumulators have shapes
    (n1, 1, 1), (1, n2, 1) and (1, 1, n3). A 0-dim (scalar) gradient gets a
    single scalar accumulator. A sparse tensor of shape (n, *) gets a single
    accumulator of shape (n,). All accumulators match grad's device and dtype.
    """
    dims = grad.shape
    ndim = len(dims)
    tensor_kwargs = {"device": grad.device, "dtype": grad.dtype}

    if grad.is_sparse:
        shapes = [dims[0]]
    elif ndim == 0:
        # The scalar case is handled separately
        shapes = [dims]
    else:
        shapes = [
            [1] * axis + [dims[axis]] + [1] * (ndim - 1 - axis)
            for axis in range(ndim)
        ]

    state.update(
        {
            _key(idx): torch.zeros(acc_shape, **tensor_kwargs)
            for idx, acc_shape in enumerate(shapes)
        }
    )
203

204

205
def _max_reduce_except_dim(tensor, dim):
206
    # Computes max along all dimensions except the given dim.
207
    # If tensor is a scalar, it returns tensor.
208
    rank = len(tensor.shape)
209
    result = tensor
210
    if rank > 0:
211
        assert dim < rank
212
        for d in range(rank):
213
            if d != dim:
214
                result = result.max(dim=d, keepdim=True).values
215
    return result
216

217

218
# Copyright (c) Facebook, Inc. and its affiliates.
219
#
220
# This source code is licensed under the MIT license found in the
221
# LICENSE file in the root directory of this source tree.
222

223
# modifications  - 4/4/2021  @lessw2020  (decay issue spotted by @nestordemeure )
224
# weight decay has been implemented AdamW style (decoupled: parameters are scaled directly) instead of the original MADGRAD Adam style (L2 term added to the gradient).
225
# in initial image classification testing, this outperformed 0 weight decay or original style weight decay.
226

227
# `closure` is only called if it is callable, since some code passes the loss value directly instead of a closure.
228

229
import math
230
from typing import Collection, TYPE_CHECKING, Any, Callable, Optional, Tuple
231

232
import torch
233
import torch.optim
234
import collections
235

236
if TYPE_CHECKING:
237
    from torch.optim.optimizer import _params_t
238
else:
239
    _params_t = Any
240

241

242
class madgrad_wd(torch.optim.Optimizer):
    """
    MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    .. _MADGRAD: https://arxiv.org/abs/2101.11075

    MADGRAD is a general purpose optimizer that can be used in place of SGD or
    Adam and may converge faster and generalize better.
    Typically, the same learning rate schedule that is used for SGD or Adam may
    be used. The overall learning rate is not comparable to either method and
    should be determined by a hyper-parameter sweep.

    MADGRAD requires less weight decay than other methods, often as little as
    zero. Momentum values used for SGD or Adam's beta1 should work here also.

    This variant applies weight decay AdamW style: parameters are multiplied by
    ``1 - lr * weight_decay`` before each update, instead of adding an L2 term
    to the gradient as the original MADGRAD does.

    On sparse problems both weight_decay and momentum should be set to 0
    (non-zero momentum raises a RuntimeError for sparse gradients).

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2).
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Decoupled (AdamW-style) weight decay coefficient (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to
            improve numerical stability (default: 1e-6). It is also added to
            the learning rate.
    """

    def __init__(
        self,
        params: "_params_t",
        lr: float = 1e-2,
        momentum: float = 0.9,
        weight_decay: float = 0,
        eps: float = 1e-6,
    ):
        if momentum < 0 or momentum >= 1:
            raise ValueError(f"Momentum {momentum} must be in the range [0,1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError(f"Eps {eps} must be non-negative")

        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. A non-callable value (e.g. a loss passed
                directly) is accepted and ignored.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        # Some callers pass the loss itself instead of a closure, so only invoke
        # it when it is callable. The builtin ``callable`` replaces
        # ``collections.Callable``, which no longer exists on Python >= 3.10.
        if callable(closure):
            # step() runs under no_grad; the closure needs autograd for backward().
            with torch.enable_grad():
                loss = closure()

        # step counter must be stored in state to ensure correct behavior under
        # optimizer sharding
        if "k" not in self.state:
            self.state["k"] = torch.tensor([0], dtype=torch.long)
        k = self.state["k"].item()

        for group in self.param_groups:
            eps = group["eps"]
            lr = group["lr"] + eps
            decay = group["weight_decay"]
            momentum = group["momentum"]

            ck = 1 - momentum
            # Dual-averaging weight grows with sqrt(k + 1).
            lamb = lr * math.pow(k + 1, 0.5)

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                if "grad_sum_sq" not in state:
                    state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
                    state["s"] = torch.zeros_like(p.data).detach()
                    if momentum != 0:
                        state["x0"] = torch.clone(p.data).detach()

                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError(
                        "momentum != 0 is not compatible with sparse gradients"
                    )

                grad_sum_sq = state["grad_sum_sq"]
                s = state["s"]

                # Apply weight decay AdamW style (decoupled): shrink the parameters
                # directly. The original MADGRAD instead added ``decay * p`` to the
                # gradient and rejected sparse gradients when decay was non-zero.
                if decay:
                    p.data.mul_(1 - lr * decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(
                        s_masked._values(), rms_masked_vals, value=1
                    )

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # update masked copy of p
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(
                        s_masked._values(), rms_masked_vals, value=-1
                    )
                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.data.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.data.addcdiv(s, rms, value=1)
                    else:
                        x0 = state["x0"]

                    # Accumulate second moments
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.data.add_(grad, alpha=lamb)

                    # Step
                    if momentum == 0:
                        p.data.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.data.mul_(1 - ck).add_(z, alpha=ck)

        self.state["k"] += 1
        return loss
416

417

418
class Lion(Optimizer):
    """
    Implements the Lion algorithm.

    .. _Lion: https://arxiv.org/abs/2302.06675

    Compared to AdamW and various adaptive optimizers that need to save both
    first and second moments, Lion only needs the momentum, halving the
    additional memory footprint. This is beneficial when training large models
    and / or with a large batch size.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-4).
        betas (Tuple[float, float]):
            ``beta1`` interpolates between the stored momentum and the current
            gradient to form the sign update; ``beta2`` is the decay rate of the
            momentum (an exponential moving average of gradients).
            Each must lie in [0, 1] (default: (0.9, 0.99)).
        weight_decay (float):
            Decoupled weight decay: parameters are multiplied by
            ``1 - lr * weight_decay`` before each update (default: 0).
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if not (0 <= betas[0] <= 1 and 0 <= betas[1] <= 1):
            # The check is inclusive of 1, so the message states [0, 1].
            raise ValueError(f"Betas {betas} must be in range [0, 1]")

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def update(self, p, grad, exp_avg, lr, wd, beta1, beta2):
        """Apply one Lion step to ``p`` in place and advance ``exp_avg``.

        See https://arxiv.org/pdf/2302.06675.pdf#appendix.A
        """
        # Decoupled weight decay.
        p.mul_(1 - lr * wd)

        # Update direction: sign of the beta1-interpolation of momentum and
        # gradient. ``mul`` is out-of-place, so exp_avg is not modified here.
        direction = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
        p.add_(direction, alpha=-lr)

        # update EMA (with beta2), after the parameter step
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """Performs a single optimization step.

        Parameters whose ``.grad`` is ``None`` are skipped.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd for backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]

                # init state - exponential moving average of gradient values
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p.data).detach()

                self.update(
                    p,
                    p.grad,
                    state["exp_avg"],
                    group["lr"],
                    group["weight_decay"],
                    group["betas"][0],
                    group["betas"][1],
                )

        return loss
498

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.