# pytorch / adamw.py (688 lines, 27.9 KB)

import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
                        _stack_if_compiling, _get_scalar_dtype, _capturable_doc, _differentiable_doc,
                        _foreach_doc, _fused_doc, _maximize_doc, _default_to_fused_or_foreach,
                        ParamsT, _view_as_real)
from typing import List, Optional, Tuple, Union
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices

__all__ = ["AdamW", "adamw"]


class AdamW(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if isinstance(lr, Tensor) and foreach and not capturable:
            raise ValueError("lr as a Tensor is not supported for capturable=False and foreach=True")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            foreach=foreach,
            maximize=maximize,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
            fused_supported_devices = _get_fused_kernels_supported_devices()
            if not all(
                p.device.type in fused_supported_devices and
                torch.is_floating_point(p)
                for pg in self.param_groups for p in pg['params']
            ):
                raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
                                   f"supported devices: {fused_supported_devices}.")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                    step_val = float(p_state["step"])
                    p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(is_fused=fused), device=p.device)
                                       if group['capturable'] or group['fused']
                                       else torch.tensor(step_val, dtype=_get_scalar_dtype()))

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        amsgrad,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamW does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
                # This is because kernel launches are costly on CUDA and XLA.
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(is_fused=group["fused"]), device=p.device)
                    if group["capturable"] or group["fused"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state["max_exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])

            if group['amsgrad']:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])
            if group['differentiable'] and state['step'].requires_grad:
                raise RuntimeError('`requires_grad` is not supported for `step` in differentiable mode')

            # Foreach without capturable does not support a tensor lr
            if group['foreach'] and isinstance(group['lr'], Tensor) and not group['capturable']:
                raise RuntimeError('lr as a Tensor is not supported for capturable=False and foreach=True')

            state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                amsgrad,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
                has_complex=has_complex,
            )

        return loss


AdamW.__doc__ = r"""Implements AdamW algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{(lr)}, \: \beta_1, \beta_2
                \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
                \: \epsilon \text{ (epsilon)}                                                    \\
            &\hspace{13mm}      \lambda \text{(weight decay)},  \: \textit{amsgrad},
                \: \textit{maximize}                                                             \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
                \text{ (second moment)}, \: \widehat{v_0}^{max}\leftarrow 0              \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\

            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}         \\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{5mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                   \\
            &\hspace{5mm}\textbf{if} \: amsgrad                                                  \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
                \widehat{v_t})                                                                   \\
            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big)                                 \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
    """ + fr"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
            is not yet supported for all our implementations. Please use a float
            LR if you are not also specifying fused=True or capturable=True.
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_capturable_doc}
        {_differentiable_doc}
        {_fused_doc}
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """


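# Minimal usage sketch (illustrative only; assumes a toy linear model and synthetic
# data). The standard zero_grad / backward / step loop drives the update documented
# above, with weight decay applied directly to the parameters ("decoupled") rather
# than folded into the gradient.
def _example_basic_usage():
    model = torch.nn.Linear(10, 1)
    optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    inputs, targets = torch.randn(4, 10), torch.randn(4, 1)
    for _ in range(3):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
    return loss

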
def adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.
    """
    if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adamw
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamw
    else:
        func = _single_tensor_adamw

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
        has_complex=has_complex,
    )


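# Sketch of driving the functional API directly (illustrative; the names and values
# here are made up). The caller owns the optimizer state and, as checked above, must
# pass `state_steps` as a list of singleton tensors. The call runs under no_grad
# because it mutates a leaf tensor in place.
def _example_functional_call():
    param = torch.randn(3, requires_grad=True)
    param.grad = torch.randn(3)
    exp_avg = torch.zeros(3)
    exp_avg_sq = torch.zeros(3)
    step = torch.tensor(0.0)
    with torch.no_grad():
        adamw(
            [param],
            [param.grad],
            [exp_avg],
            [exp_avg_sq],
            [],  # max_exp_avg_sqs stays empty when amsgrad=False
            [step],
            amsgrad=False,
            beta1=0.9,
            beta2=0.999,
            lr=1e-3,
            weight_decay=1e-2,
            eps=1e-8,
            maximize=False,
        )
    return param, step

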
def _single_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):

    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            assert (
                (param.is_cuda and step_t.is_cuda) or (param.is_xla and step_t.is_xla)
            ), "If capturable=True, params and state_steps must be CUDA or XLA tensors."

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            if amsgrad:
                max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
            param = torch.view_as_real(param)

        # update step
        step_t += 1

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if capturable or differentiable:
            step = step_t

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sq = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sq = max_exp_avg_sqs[i]

                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))

                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                denom = (
                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            else:
                denom = (
                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)

            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)

        # Lastly, switch back to complex view
        if amsgrad and torch.is_complex(params[i]):
            max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])


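# Worked sketch of one non-capturable update on a single tensor (illustrative values):
# it replays the sequence implemented above -- decoupled weight decay, exponential
# moment updates, bias correction, and the final addcdiv parameter update.
def _example_single_tensor_math(lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2):
    param = torch.tensor([1.0, -2.0, 3.0])
    grad = torch.tensor([0.1, -0.2, 0.3])
    exp_avg = torch.zeros(3)
    exp_avg_sq = torch.zeros(3)
    step = 1

    param.mul_(1 - lr * weight_decay)                             # decoupled weight decay
    exp_avg.lerp_(grad, 1 - beta1)                                # first moment estimate
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment estimate

    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    step_size = lr / bias_correction1
    denom = (exp_avg_sq.sqrt() / bias_correction2 ** 0.5).add_(eps)
    param.addcdiv_(exp_avg, denom, value=-step_size)
    return param

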
def _multi_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        assert all(
            p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)
        ), "If capturable=True, params and state_steps must be CUDA tensors."

    assert not differentiable, "_foreach ops don't support autograd"

    assert grad_scale is None and found_inf is None

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([
        params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for ((
        device_params,
        device_grads,
        device_exp_avgs,
        device_exp_avg_sqs,
        device_max_exp_avg_sqs,
        device_state_steps,
    ), _) in grouped_tensors.values():
        if has_complex:
            if amsgrad:
                _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs, device_max_exp_avg_sqs)
            else:
                _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if device_state_steps[0].is_cpu:
            torch._foreach_add_(device_state_steps, torch.tensor(1.0, device='cpu'), alpha=1.0)
        else:
            torch._foreach_add_(device_state_steps, 1)

        # Perform stepweight decay
        if weight_decay != 0:
            torch._foreach_mul_(device_params, 1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)

        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del device_grads

        if capturable:
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)

            torch._foreach_sqrt_(bias_correction2)

            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)

            # at this point, exp_avg_sq_sqrt = - (1 - beta1^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
        else:
            bias_correction1 = [1 - beta1 ** _get_value(step) for step in device_state_steps]
            bias_correction2 = [1 - beta2 ** _get_value(step) for step in device_state_steps]

            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size)


def _fused_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
    has_complex: bool,
) -> None:
    if not params:
        return
    if differentiable:
        raise RuntimeError("AdamW with fused=True does not support differentiable=True")

    grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None

    # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
    # treating it as a scalar.
    lr_dict = {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for (device, _), ((device_params,
                       device_grads,
                       device_exp_avgs,
                       device_exp_avg_sqs,
                       device_max_exp_avg_sqs,
                       device_state_steps,), _) in grouped_tensors.items():
        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
            device_grad_scale = grad_scale_dict[device]
        if found_inf is not None:
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device, non_blocking=True)
            device_found_inf = found_inf_dict[device]
        if lr_dict is not None and device not in lr_dict:
            lr_dict[device] = lr.to(device=device, non_blocking=True)
            lr = lr_dict[device]
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adamw_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps))
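

# Sketch of the fused path together with AMP gradient scaling (illustrative; assumes a
# CUDA device is available). Because the constructor sets `_step_supports_amp_scaling`
# when fused=True, GradScaler.step() can hand `grad_scale` and `found_inf` to the fused
# kernel instead of unscaling gradients beforehand.
def _example_fused_amp_step():
    device = torch.device("cuda")
    model = torch.nn.Linear(10, 1).to(device)
    optimizer = AdamW(model.parameters(), lr=1e-3, fused=True)
    scaler = torch.cuda.amp.GradScaler()
    inputs, targets = torch.randn(4, 10, device=device), torch.randn(4, 1, device=device)

    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # the fused kernel consumes grad_scale / found_inf
    scaler.update()
    return loss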
