# Repository: pytorch · forks: 0 · file: test_optim.py · 2071 lines · 88.2 KB
1
# Owner(s): ["module: optimizer"]
2
import functools
3
import math
4
import tempfile
5
import unittest
6
from copy import deepcopy
7
from typing import Any, Dict, Tuple
8
from unittest.mock import patch
9

10
from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
11
from optim.test_optim import TestDifferentiableOptimizer  # noqa: F401
12
from optim.test_swa_utils import TestSWAUtils  # noqa: F401
13

14
import torch
15
from torch.nn import Parameter
16
from torch.optim import Optimizer, SGD
17
from torch.optim.lr_scheduler import ReduceLROnPlateau
18
from torch.optim.optimizer import (
19
    register_optimizer_step_post_hook,
20
    register_optimizer_step_pre_hook,
21
)
22
from torch.testing._internal.common_cuda import TEST_MULTIGPU
23
from torch.testing._internal.common_device_type import (
24
    instantiate_device_type_tests,
25
    largeTensorTest,
26
    onlyCPU,
27
    onlyCUDA,
28
    onlyNativeDeviceTypes,
29
    skipMPS,
30
    TEST_WITH_ROCM,
31
)
32
from torch.testing._internal.common_dtype import floating_types_and
33
from torch.testing._internal.common_optimizers import (
34
    _get_device_type,
35
    _get_optim_inputs_including_global_cliquey_kwargs,
36
    optim_db,
37
    OptimizerErrorEnum,
38
    optims,
39
    TensorTracker,
40
)
41
from torch.testing._internal.common_utils import (
42
    markDynamoStrictTest,
43
    parametrize,
44
    run_tests,
45
    TEST_WITH_TORCHDYNAMO,
46
    TestCase,
47
)
48

49

50
# Absolute/relative tolerance pair keyed as ``atol``/``rtol``.
# NOTE(review): presumably splatted into assertEqual for reduced-precision (fp16)
# comparisons -- no use is visible in this part of the file, confirm at call sites.
FP16_REDUCED_PRECISION = {"atol": 1e-5, "rtol": 1e-4}
51

52

53
def rosenbrock(tensor):
    """Evaluate the Rosenbrock function at a 2-element point.

    Args:
        tensor: tensor of size ``[2]`` holding the ``(x, y)`` coordinates.

    Returns:
        ``(1 - x) ** 2 + 100 * (y - x**2) ** 2`` as a 0-dim tensor.

    Raises:
        AssertionError: if ``tensor`` does not have size ``[2]``.
    """
    assert tensor.size() == torch.Size(
        [2]
    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2
59

60

61
def drosenbrock(tensor):
    """Evaluate the analytic gradient of the Rosenbrock function at a 2-element point.

    Args:
        tensor: tensor of size ``[2]`` holding the ``(x, y)`` coordinates.

    Returns:
        A stacked 2-element tensor ``(d/dx, d/dy)``, i.e.
        ``(-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2))``.

    Raises:
        AssertionError: if ``tensor`` does not have size ``[2]``.
    """
    assert tensor.size() == torch.Size(
        [2]
    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
67

68

69
@markDynamoStrictTest
70
class TestOptimRenewed(TestCase):
71
    """
72
    This test class validates the core optimizers and is structured as the correctness of:
73
    - The update algorithms (forloop implementation)
74
        * Every optimizer's algorithm is most readably implemented through a big for-loop
75
          over all the parameters, which is what we refer to as the forloop or single tensor
76
          implementation. These algorithms are manually validated by comparing to the paper
77
          and systematically validated by assuring that the loss goes the right direction
78
          when the optimizer has been applied.
79
        * This implementation should compose with optimizer hyperparameters well, such as
80
          supporting Tensor LRs, the capturable API, and sparse and complex parameters.
81
    - Each varying implementation
82
        * We then have implementations that improve upon the performance of the forloop
83
          implementation by leveraging fusion, namely our foreach (multi_tensor) and fused
84
          implementations.
85
        * These variations are validated numerically by comparing with the forloop version
86
          of the optimizer. In fact, we test most variations this way--we see the forloop
87
          implementation as the ground truth and expect that improvements to it in any way
88
          should be just as correct.
89
        * Both params and optimizer states should be validated numerically.
90
    - state_dict APIs
91
        * The optimizer instance should be serializable
92
        * Calling save and load should be deterministic
93
        * Moving between devices should be seamless
94
        * BC - load_state_dict should be able to handle older optimizer states
95
    - Hook APIs (everything should fire in the right order)
96
    - LR Scheduler integration (composing should not error + should go the right direction)
97
    - Parameter groups (should be equivalent to having multiple optimizers)
98
    - Erroring (what should error should error)
99

100
    We also cover different ways of generating parameters and grads:
101
    - With parameters, we either generate them randomly given specific shapes or we take
102
      them from a sample NN module.
103
        * Variety is important here because NN modules have type Parameter and randomly
104
          generated tensors have type Tensor.
105
        * Parameters can be sparse for a subset of the optimizers (check out OptimizerInfo)
106
        * Complex parameters should be handled using view_as_real
107
        * Parameters can be spread across different devices and different dtypes for any
108
          given optimizer
109
        * Parameters can be contiguous and noncontiguous
110
    - With grads, we follow suit from the parameters.
111
        * Grads can also be None, empty, or zero-valued, and this should not disrupt training.
112
    """
113

114
    @onlyCPU
    @optims(optim_db)
    def test_optim_infos_do_not_specify_global_cliquey_kwargs(
        self, device, dtype, optim_info
    ):
        """Assert that no optimizer input sets ``foreach``, ``fused`` or ``differentiable``.

        Iterates the inputs produced by ``optim_info.optim_inputs_func`` and fails
        if any of those three flags appears among an input's kwargs.
        """
        global_cliquey_flags = ("foreach", "fused", "differentiable")
        for optim_input in optim_info.optim_inputs_func(device=device):
            self.assertFalse(
                any(flag in optim_input.kwargs for flag in global_cliquey_flags)
            )
124

125
    @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
    def test_errors(self, device, dtype, optim_info):
        """Check that each declared error input fails at the declared stage.

        Every error input names an error type, a regex for its message and
        whether it should trigger at construction or on ``step()``. Error types
        that subclass ``Warning`` are checked with ``assertWarnsRegex``; all
        others with ``assertRaisesRegex``.

        Raises:
            NotImplementedError: if an error input names an unknown stage.
        """
        optim_cls = optim_info.optim_cls
        error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype)
        known_stages = (
            OptimizerErrorEnum.CONSTRUCTION_ERROR,
            OptimizerErrorEnum.STEP_ERROR,
        )

        for error_input in error_inputs:
            optim_input = error_input.optimizer_error_input
            params, kwargs = optim_input.params, optim_input.kwargs
            if error_input.error_on not in known_stages:
                raise NotImplementedError(f"Unknown error type {error_input.error_on}")

            # Warnings and exceptions need different unittest context managers.
            if issubclass(error_input.error_type, Warning):
                expect_failure = self.assertWarnsRegex
            else:
                expect_failure = self.assertRaisesRegex

            if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
                with expect_failure(error_input.error_type, error_input.error_regex):
                    optim_cls(params, **kwargs)
            else:
                # Construction must succeed; only the step is expected to fail.
                optimizer = optim_cls(params, **kwargs)
                with expect_failure(error_input.error_type, error_input.error_regex):
                    optimizer.step()
158

159
    @parametrize("contiguous", [True, False])
    @parametrize("with_lrsched", [True, False])
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction(
        self, device, dtype, optim_info, contiguous, with_lrsched
    ):
        """Check that 20 forloop steps move the loss in the expected direction.

        For every optimizer input (and, when ``with_lrsched`` is set, every
        scheduler combination in ``optim_info.scheduler_inputs``) the loss
        ``(weight.mv(inpt) + bias).pow(2).sum()`` is optimized. The final loss
        must be lower than the initial one, or higher when the input sets
        ``maximize``. With ``contiguous=False`` the parameters are built from
        ``[..., 0]`` slices of larger tensors.
        """
        optim_cls = optim_info.optim_cls
        schedulers_constructors = (
            optim_info.scheduler_inputs if with_lrsched else [None]
        )

        for schedulers_constructor in schedulers_constructors:
            # with tensor LR we need fresh inputs for each scheduler
            # or mutating it will carry across iters
            optim_inputs = optim_info.optim_inputs_func(device=device)
            for optim_input in optim_inputs:
                if "foreach" in optim_info.supported_impls:
                    optim_input.kwargs["foreach"] = False  # force forloop
                if contiguous:
                    weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
                    bias = Parameter(torch.randn((10), device=device, dtype=dtype))
                else:
                    weight = Parameter(
                        torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0]
                    )
                    bias = Parameter(
                        torch.randn((10, 2), device=device, dtype=dtype)[..., 0]
                    )
                # Named ``inpt`` (as in the sibling tests) to avoid shadowing the builtin.
                inpt = torch.randn(5, device=device, dtype=dtype)

                optimizer = optim_cls([weight, bias], **optim_input.kwargs)
                schedulers = [
                    s(optimizer)
                    for s in (schedulers_constructor if schedulers_constructor else [])
                ]

                def closure():
                    optimizer.zero_grad()
                    loss = (weight.mv(inpt) + bias).pow(2).sum()
                    loss.backward()
                    if optim_info.only_supports_sparse_grads:
                        # For this test, we naively convert the Tensor layout, which we know does
                        # NOT represent the expected use case for optims like SparseAdam!
                        weight.grad = weight.grad.to_sparse()
                        bias.grad = bias.grad.to_sparse()
                    return loss

                initial_value = closure().item()
                for _ in range(20):
                    if optim_info.step_requires_closure:
                        loss = optimizer.step(closure)
                    else:
                        loss = closure()
                        optimizer.step()

                    for scheduler in schedulers:
                        # ReduceLROnPlateau is stepped with the loss; the others take no argument.
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(loss)
                        else:
                            scheduler.step()

                if optim_input.kwargs.get("maximize", False):
                    self.assertGreater(closure().item(), initial_value)
                else:
                    self.assertLess(closure().item(), initial_value)
224

225
    @onlyCUDA
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @parametrize("with_lrsched", [True, False])
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction_multigpu(
        self, device, dtype, optim_info, with_lrsched
    ):
        """Check the loss direction when the parameters live on two GPUs.

        ``weight`` and the input are placed on ``cuda:0`` and ``bias`` on
        ``cuda:1``; ``device`` is only forwarded to ``optim_inputs_func``. After
        20 steps the loss must be lower than the initial one, or higher when
        the input sets ``maximize``.
        """
        optim_cls = optim_info.optim_cls
        schedulers_constructors = (
            optim_info.scheduler_inputs if with_lrsched else [None]
        )
        for schedulers_constructor in schedulers_constructors:
            # We need a fresh set of inputs if we have a tensor LR
            # to not carry mutations across iterations.
            optim_inputs = optim_info.optim_inputs_func(device=device)
            for optim_input in optim_inputs:
                if "foreach" in optim_info.supported_impls:
                    optim_input.kwargs["foreach"] = False  # force forloop

                weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
                bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
                inpt = torch.randn(5, device="cuda:0", dtype=dtype)

                optimizer = optim_cls([weight, bias], **optim_input.kwargs)
                schedulers = [
                    s(optimizer)
                    for s in (schedulers_constructor if schedulers_constructor else [])
                ]

                def closure():
                    optimizer.zero_grad()
                    # Move the matmul result to the second GPU before adding the bias.
                    loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum()
                    loss.backward()
                    if optim_info.only_supports_sparse_grads:
                        # For this test, we naively convert the Tensor layout, which we know does
                        # NOT represent the expected use case for optims like SparseAdam!
                        weight.grad = weight.grad.to_sparse()
                        bias.grad = bias.grad.to_sparse()
                    return loss

                initial_value = closure().item()
                for _ in range(20):
                    loss = optimizer.step(closure)
                    for scheduler in schedulers:
                        # ReduceLROnPlateau is stepped with the loss; the others take no argument.
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(loss)
                        else:
                            scheduler.step()

                if optim_input.kwargs.get("maximize", False):
                    self.assertGreater(closure().item(), initial_value)
                else:
                    self.assertLess(closure().item(), initial_value)
278

279
    @optims(optim_db, dtypes=[torch.float32])
    def test_param_group_with_lrscheduler_goes_right_direction(
        self, device, dtype, optim_info
    ):
        """Check the loss decreases with two param groups and LR schedulers attached.

        ``weight`` goes in a group using the optimizer defaults and ``bias`` in a
        group with an explicit ``lr`` of 0.01. For each scheduler combination in
        ``optim_info.scheduler_inputs``, 20 closure-driven steps must lower the
        loss.
        """
        optim_cls = optim_info.optim_cls

        for schedulers_c in optim_info.scheduler_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            inpt = torch.randn(5, device=device, dtype=dtype)

            # avoid endless recompiles by wrapping LR in a tensor if we're compiling
            lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
            optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
            schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]

            def closure():
                optimizer.zero_grad()
                loss = (weight.mv(inpt) + bias).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                return loss

            initial_value = closure().item()
            for _ in range(20):
                loss = optimizer.step(closure)
                for scheduler in schedulers:
                    # ReduceLROnPlateau is stepped with the loss; the others take no argument.
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss)
                    else:
                        scheduler.step()

            self.assertLess(closure().item(), initial_value)
316

317
    @optims(optim_db, dtypes=[torch.float32])
    def test_tensor_lr(self, device, dtype, optim_info):
        """Check that a Tensor ``lr`` yields the same parameters as a float ``lr``.

        Two optimizers are built over identical copies of ``weight`` and
        ``bias``: a reference with a Python-float LR and one with the same value
        wrapped in a tensor. After each of 5 steps both parameter sets must be
        equal. Inputs whose construction raises ``ValueError`` for a Tensor LR
        are skipped after checking the error message.
        """
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        for optim_input in all_optim_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            weight_c = weight.clone().detach().requires_grad_(True)
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            bias_c = bias.clone().detach().requires_grad_(True)
            inpt = torch.randn(5, device=device, dtype=dtype)

            kwargs = optim_input.kwargs
            # Override whatever LR the input carried with a known float value.
            kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3
            optimizer_r = optim_cls([weight, bias], **kwargs)

            try:
                kwargs["lr"] = torch.tensor(kwargs["lr"])
                optimizer = optim_cls([weight_c, bias_c], **kwargs)
            except ValueError as e:
                self.assertRegex(str(e), ".*lr as a Tensor is not supported.*")
                continue

            def closure(optim, w, b, i):
                optim.zero_grad()
                loss = (w.mv(i) + b).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    w.grad = w.grad.to_sparse()
                    b.grad = b.grad.to_sparse()
                return loss

            for _ in range(5):
                if optim_info.step_requires_closure:
                    optimizer_r.step(
                        functools.partial(closure, optimizer_r, weight, bias, inpt)
                    )
                    optimizer.step(
                        functools.partial(closure, optimizer, weight_c, bias_c, inpt)
                    )
                else:
                    # The closure only populates grads; the update itself has to be
                    # applied explicitly, otherwise the comparison below would be
                    # between two untouched copies and pass vacuously.
                    closure(optimizer_r, weight, bias, inpt)
                    optimizer_r.step()
                    closure(optimizer, weight_c, bias_c, inpt)
                    optimizer.step()

                self.assertEqual(weight, weight_c)
                self.assertEqual(bias, bias_c)
371

372
    @parametrize("with_lrsched", [True, False])
    @optims(
        [o for o in optim_db if o.supports_sparse or o.only_supports_sparse_grads],
        dtypes=[torch.float64],
    )
    def test_rosenbrock_sparse(self, device, dtype, optim_info, with_lrsched):
        """Run cyclic coordinate descent on Rosenbrock with uncoalesced sparse grads.

        For each supported input the optimizer is stepped 1800 times, alternating
        between the x and y gradient. Unless the optimizer only supports sparse
        grads, a second optimizer fed the equivalent dense grads must track the
        sparse one within tolerance after every step. At the end the params
        must not be further from the solution ``(1, 1)`` than they started, or,
        when ``maximize`` is set, the loss must not be below its initial value.
        """
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Fused impls do not support sparse gradients
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        kwarg_updates, schedulers_constructors = optim_info.metadata_for_sparse

        if with_lrsched and len(schedulers_constructors) == 0:
            return

        supported_inputs = []
        if len(kwarg_updates) != 0:
            seen = set()
            for candidate in all_optim_inputs:
                # Drop every kwarg named in kwarg_updates, then keep one
                # representative per distinct non-empty remaining kwarg set.
                for k in kwarg_updates:
                    candidate.kwargs.pop(k, None)
                hashable_kwargs = tuple(sorted(candidate.kwargs.items()))
                if len(candidate.kwargs) > 0 and hashable_kwargs not in seen:
                    supported_inputs.append(candidate)
                    seen.add(hashable_kwargs)
                    if "lr" in kwarg_updates:
                        candidate.kwargs["lr"] = kwarg_updates["lr"]
        else:
            supported_inputs = all_optim_inputs

        for optim_input in supported_inputs:
            kwargs = optim_input.kwargs
            multi_tensor = kwargs.get("foreach", False)

            # For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
            if multi_tensor:
                params_t = [
                    torch.tensor([1.5, 1.5]),
                    torch.tensor([1.5, 1.5], dtype=dtype),
                ]
            else:
                params_t = [torch.tensor([1.5, 1.5])]

            params = [Parameter(param_t) for param_t in params_t]
            optimizer = optim_cls(params, **kwargs)
            schedulers = [
                s(optimizer) for s in (schedulers_constructors if with_lrsched else [])
            ]

            if not optim_info.only_supports_sparse_grads:
                params_c = [Parameter(param_t.clone()) for param_t in params_t]
                optimizer_c = optim_cls(params_c, **kwargs)
                schedulers_c = [
                    s(optimizer_c)
                    for s in (schedulers_constructors if with_lrsched else [])
                ]

            solution = torch.tensor([1, 1])
            with torch.no_grad():
                initial_dist = sum(param.dist(solution) for param in params)
                # Snapshot the starting loss now: each Parameter above wraps its
                # params_t tensor without copying, so params_t follows the
                # in-place updates and cannot serve as the "initial" reference.
                initial_loss = sum(rosenbrock(param) for param in params)

            def get_grad(param, sparse_grad, w):
                grad = drosenbrock(param)
                # NB: We torture test the optimizer by returning an
                # uncoalesced sparse tensor

                # Depending on w, provide only the x or y gradient
                if sparse_grad:
                    if w:
                        i = torch.tensor([[0, 0]], dtype=torch.int64)
                        x = grad[0]
                        v = torch.tensor([x / 4.0, x - x / 4.0])
                    else:
                        i = torch.tensor([[1, 1]], dtype=torch.int64)
                        y = grad[1]
                        v = torch.tensor([y - y / 4.0, y / 4.0])
                    grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
                else:
                    if w:
                        grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
                    else:
                        grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
                return grad_out

            def evaluate(optim, params, sparse_grad, w):
                # Zero the grads of the optimizer that owns ``params`` (sparse or
                # dense run), not unconditionally those of the sparse one.
                optim.zero_grad()
                if multi_tensor:
                    loss = sum(rosenbrock(param) for param in params)
                else:
                    loss = rosenbrock(params[0])
                loss.backward()

                # Replace the autograd grads with the hand-built one-coordinate grads.
                grads_out = [get_grad(param, sparse_grad, w) for param in params]
                with torch.no_grad():
                    params[0].grad = grads_out[0]
                    if multi_tensor:
                        params[1].grad = grads_out[1].to(dtype=dtype)
                return loss

            for step_idx in range(1800):
                # Do cyclic coordinate descent
                w = step_idx % 2
                optimizer.step(functools.partial(evaluate, optimizer, params, True, w))
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(rosenbrock(params[0]))
                    else:
                        scheduler.step()
                if not optim_info.only_supports_sparse_grads:
                    optimizer_c.step(
                        functools.partial(evaluate, optimizer_c, params_c, False, w)
                    )
                    for scheduler in schedulers_c:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(rosenbrock(params_c[0]))
                        else:
                            scheduler.step()
                    # Tolerance is increased due to floating point error from different
                    # code path for dense case: x v.s. x - x / 4.0 + x / 4.0
                    self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)

            if not kwargs.get("maximize", False):
                self.assertLessEqual(
                    sum(param.dist(solution) for param in params), initial_dist
                )
            else:
                self.assertGreaterEqual(
                    sum(rosenbrock(param) for param in params),
                    initial_loss,
                )
504

505
    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex(self, device, dtype, optim_info):
        """Check that optimizing complex params matches their ``view_as_real`` form.

        One optimizer works on a mix of complex and real params, another on
        detached real-valued copies (``view_as_real`` for the complex ones).
        Both are fed the same random grads and losses for 3 steps. The final
        params and every param snapshot recorded inside the closures must be
        equal.
        """
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Also skip fused, since our fused kernels do not support complex
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        for optim_input in all_optim_inputs:
            # Last param is intentionally real to test that we can mix real and complex
            complex_params = [
                torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
                torch.randn(10, device=device, dtype=dtype, requires_grad=True),
                torch.randn(
                    10, 5, device=device, dtype=torch.float32, requires_grad=True
                ),
            ]
            real_params = [
                (
                    torch.view_as_real(param).detach().clone().requires_grad_()
                    if param.is_complex()
                    else param.detach().clone().requires_grad_()
                )
                for param in complex_params
            ]

            complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
            real_optimizer = optim_cls(real_params, **optim_input.kwargs)
            real_steps = []
            complex_steps = []
            # FIFO of grads (one per param) followed by the loss, written by
            # real_closure and consumed in the same order by complex_closure.
            grads_losses = []

            def real_closure():
                for param in real_params:
                    grad = torch.randn_like(param)
                    param.grad = grad
                    real_steps.append(param.detach().clone())
                    grads_losses.append(grad.clone())
                loss = torch.randn(1)
                grads_losses.append(loss.clone())
                return loss

            def complex_closure():
                for param in complex_params:
                    if torch.is_complex(param):
                        grad = torch.view_as_complex(grads_losses.pop(0))
                        complex_steps.append(torch.view_as_real_copy(param.detach()))
                    else:
                        grad = grads_losses.pop(0)
                        complex_steps.append(param.detach().clone())
                    param.grad = grad
                return grads_losses.pop(0)

            for _ in range(3):
                if optim_info.step_requires_closure:
                    # LBFGS, for example, requires closure and calls it internally
                    real_optimizer.step(real_closure)
                    complex_optimizer.step(complex_closure)
                else:
                    # For other optimizers, we call closure explicitly to set the gradients
                    real_closure()
                    complex_closure()
                    real_optimizer.step()
                    complex_optimizer.step()

            # Final Parameters should be the same
            complex_params_asreal = [
                torch.view_as_real(param) if param.is_complex() else param
                for param in complex_params
            ]
            self.assertEqual(real_params, complex_params_asreal)

            # All intermediate steps should also be the same
            # also checks steps taken within for example a line search
            self.assertEqual(complex_steps, real_steps)
581

582
    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex_2d(self, device, dtype, optim_info):
        """Check a complex param against separately optimized real and imaginary parts.

        ``optim1`` optimizes a complex 2-element tensor and ``optim2`` optimizes
        detached copies of its real and imaginary parts. ``closure1`` records
        params, grads and losses in ``TensorTracker`` instances and ``closure2``
        pops and checks them in the same order. After each of 3 steps the two
        representations must agree, and every tracker must be fully consumed.
        """
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Also skip fused, since our fused kernels do not support complex
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        for optim_input in all_optim_inputs:
            if optim_info.step_requires_closure:
                # Why? The way we implement complex is by turning complex params into view_as_real
                # alternatives. For example, a size (M,N) tensor will become (M,N,2). In this test,
                # we break apart a tensor into its real and imaginary parts, which would be 2x(M,N).
                # For other pointwise optimizers, this distinction is trivial, but for LBFGS where
                # there are reductions across all parameters (and all the grads get flattened into
                # one long Tensor), this ordering matters. Why? Reductions are not deterministic
                # because addition between floating point numbers is not associative, i.e.,
                # a + b + c != a + c + b. Thus, we add a seed here to control the discrepancy that
                # will happen with LBFGS. Note that in test_complex above, there is no need for a seed
                # nor for increased tolerance, because results should be bitwise equivalent.
                torch.manual_seed(2024)

            a1 = torch.randn(2, device=device, dtype=dtype, requires_grad=True)
            a1_real = a1.real.clone().detach()
            a1_imag = a1.imag.clone().detach()
            a1_real.requires_grad_()
            a1_imag.requires_grad_()
            optim1 = optim_cls([a1], **optim_input.kwargs)
            optim2 = optim_cls([a1_real, a1_imag], **optim_input.kwargs)

            a1_reals = TensorTracker()
            a1_imags = TensorTracker()
            a1_grad_reals = TensorTracker()
            a1_grad_imags = TensorTracker()
            losses = TensorTracker()

            def closure1():
                optim1.zero_grad()
                loss = rosenbrock(a1).abs()
                loss.backward()

                # Track clones to best test accuracy
                a1_reals.add(a1.real)
                a1_imags.add(a1.imag)
                a1_grad_reals.add(a1.grad.real)
                a1_grad_imags.add(a1.grad.imag)

                losses.add(loss)

                return loss

            def closure2():
                optim2.zero_grad()
                a1_reals.pop_check_set(a1_real, self)
                a1_imags.pop_check_set(a1_imag, self)
                # Rebuild the complex tensor from the two real params for the loss.
                a2 = torch.complex(a1_real, a1_imag)
                loss = rosenbrock(a2).abs()
                losses.pop_check_set(loss, self)
                loss.backward()
                a1_grad_reals.pop_check_set(a1_real.grad, self)
                a1_grad_imags.pop_check_set(a1_imag.grad, self)
                return loss

            for _ in range(3):
                if optim_info.step_requires_closure:
                    # LBFGS, for example, requires closure and calls it internally
                    optim1.step(closure1)
                    optim2.step(closure2)
                else:
                    closure1()
                    closure2()
                    optim1.step()
                    optim2.step()

                self.assertEqual(a1.real, a1_real)
                self.assertEqual(a1.imag, a1_imag)

            self.assertTrue(a1_reals.all_popped())
            self.assertTrue(a1_imags.all_popped())
            self.assertTrue(a1_grad_reals.all_popped())
            self.assertTrue(a1_grad_imags.all_popped())
            self.assertTrue(losses.all_popped())
665

666
    def _compare_between(
        self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None
    ):
        """
        Run two model/optimizer pairs in lockstep and check, after every
        iteration, that their parameters and optimizer states agree.

        Args:
            inputs: either a list with one input per model, or a single input
                that is then fed to both models.
            models: the two models to run; ``model[2]`` is frozen and unfrozen
                during the run, so each model must be indexable.
            optimizers: the optimizer paired with each model.
            assert_eq_kwargs: options handed to ``TensorTracker``; defaults to
                an empty dict.
            assert_step_dtype: when given, every tensor-valued ``"step"`` state
                entry must have this dtype.
        """
        if assert_eq_kwargs is None:
            assert_eq_kwargs = {}
        # A single (non-list) input is shared by both models.
        if not isinstance(inputs, list):
            inputs = [inputs, inputs]

        # why 7? iteration 7 is where we start to see differences for RAdam
        # params interacting with the small eps value, because that's right
        # after rho_t becomes greater than 5 in step 6.
        num_iterations = 7
        tracker = TensorTracker(assert_eq_kwargs)
        for iteration in range(num_iterations):
            states = []
            param_iters = []
            for batch, model, optimizer in zip(inputs, models, optimizers):
                optimizer.zero_grad()

                if iteration == 3:
                    # Freeze a layer to test if the step of this layer in 'fused' or 'foreach'
                    # is same as the step in 'forloop'.
                    model[2].requires_grad_(False)
                elif iteration == 5:
                    # Unfreeze the layer after 2 iters.
                    model[2].requires_grad_(True)

                # Iteration 2 skips the backward pass so that step is exercised
                # with grads set to None (expected to be a no-op).
                if iteration != 2:
                    model(batch).sum().backward()

                optimizer.step()
                states.append(optimizer.state)
                param_iters.append(model.parameters())

            # Exactly two model/optimizer pairs are expected here.
            og_state, new_state = states
            og_params, new_params = param_iters
            for og_p, new_p in zip(og_params, new_params):
                tracker.add(og_p)
                tracker.pop_check_set(new_p, self)

                # check that optimizer states are the same
                og_p_state = og_state[og_p]
                new_p_state = new_state[new_p]
                if assert_step_dtype is not None:
                    for p_state in (og_p_state, new_p_state):
                        step = p_state.get("step", None)
                        if torch.is_tensor(step):
                            self.assertEqual(step.dtype, assert_step_dtype)
                for key, og_value in og_p_state.items():
                    tracker.add(og_value)
                    tracker.pop_check_set(new_p_state[key], self)

            self.assertTrue(tracker.all_popped())
719

720
    def _test_derived_optimizers(
        self,
        device,
        dtype,
        optim_info,
        flag,
        reduced_precision=False,
        assert_step_dtype=None,
    ):
        """
        For each optimizer configuration, build one optimizer with ``flag``
        ('foreach' or 'fused') set to False and one with it set to True, then
        check that updated parameters and optimizer state match between them.

        Args:
            device: device on which the models and inputs are created.
            dtype: dtype of the models and inputs.
            optim_info: supplies the optimizer class and its configurations.
            flag: the implementation switch to toggle, 'foreach' or 'fused'.
            reduced_precision: when true, compare with FP16_REDUCED_PRECISION
                tolerances instead of the defaults.
            assert_step_dtype: forwarded to ``_compare_between``.
        """
        assert flag in ("foreach", "fused")
        assert_eq_kwargs = FP16_REDUCED_PRECISION if reduced_precision else {}

        optim_cls = optim_info.optim_cls
        for optim_input in optim_info.optim_inputs_func(device=device, dtype=dtype):
            kwargs = deepcopy(optim_input.kwargs)
            if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
                # capturable is not supported on CPU
                continue

            model_input = torch.tensor(
                [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
            ).reshape(3, 2)

            models = []
            optimizers = []
            for flag_value in (False, True):
                kwargs[flag] = flag_value

                # Same seed for both variants so they start from identical weights.
                torch.manual_seed(1)
                layers = [
                    torch.nn.Linear(2, 3),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(3, 1),
                    torch.nn.Sigmoid(),
                ]
                model = torch.nn.Sequential(*layers)
                model.to(dtype=dtype, device=device)

                # foreach/fused optimizers should be tested with a
                # zero_size tensor as its last param.
                # ref: https://github.com/pytorch/pytorch/issues/100701
                empty_param = torch.empty(
                    (), device=device, dtype=dtype, requires_grad=True
                )
                empty_param.grad = torch.rand_like(empty_param)

                optimizers.append(
                    optim_cls([*model.parameters(), empty_param], **kwargs)
                )
                models.append(model)

            self._compare_between(
                model_input, models, optimizers, assert_eq_kwargs, assert_step_dtype
            )
776

777
    @skipMPS  # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
778
    @optims(
779
        [optim for optim in optim_db if "foreach" in optim.supported_impls],
780
        dtypes=[torch.float64],
781
    )
782
    def test_foreach_matches_forloop(self, device, dtype, optim_info):
        """Check parity between foreach=True and foreach=False for this optimizer."""
        self._test_derived_optimizers(
            device=device, dtype=dtype, optim_info=optim_info, flag="foreach"
        )
784

785
    @onlyCUDA
786
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
787
    @parametrize("impl", ["foreach", "fused"])
788
    @optims(
789
        [
790
            optim
791
            for optim in optim_db
792
            if "foreach" in optim.supported_impls or "fused" in optim.supported_impls
793
        ]
794
    )
795
    def test_mixed_device_dtype(self, device, dtype, optim_info, impl):
        """
        Similar in essence to _test_derived_optimizers above. The main difference is that
        _test_derived_optimizers uses model parameters whereas we randomly pass in
        parameters of different dtypes and devices here. We need multiple GPUs (vs just a
        CPU and GPU) because fused adam only works on GPUs. (Thus we only run the tests
        that call into this helper when TEST_MULTIGPU.)

        ``impl`` is the implementation switch under test, 'foreach' or 'fused'.
        Unsupported optimizer/impl combinations are reported as skipped.
        """
        assert impl in ("foreach", "fused")
        optim_cls = optim_info.optim_cls
        # skipTest raises SkipTest so the case is reported as skipped. Returning
        # unittest.skip(...) from inside a test body only builds an unused
        # decorator and lets the test pass without having checked anything.
        if impl == "foreach" and "foreach" not in optim_info.supported_impls:
            self.skipTest(f"foreach not supported for {optim_cls.__name__}")
        elif impl == "fused" and "cuda" not in optim_info.supports_fused_on:
            self.skipTest(f"fused not supported for {optim_cls.__name__} on cuda")

        params = [
            torch.rand(2, 3, dtype=torch.float64, device="cuda:0", requires_grad=True),
            torch.rand(2, 3, dtype=torch.float32, device="cuda:0", requires_grad=True),
            torch.rand(2, 3, dtype=torch.float16, device="cuda:0", requires_grad=True),
            torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:0", requires_grad=True),
            torch.rand(2, 3, dtype=torch.float64, device="cuda:1", requires_grad=True),
            torch.rand(2, 3, dtype=torch.float32, device="cuda:1", requires_grad=True),
            torch.rand(2, 3, dtype=torch.float16, device="cuda:1", requires_grad=True),
            torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:1", requires_grad=True),
            torch.randint(
                1024, (2, 3), dtype=torch.int64, device="cuda:1", requires_grad=False
            ),
        ]

        for p in params:
            if p.requires_grad:
                p.grad = torch.rand_like(p, device=p.device, dtype=p.dtype)

        kIterations = 7 if impl == "foreach" else 1
        optim_inputs = optim_info.optim_inputs_func(device=device)
        for optim_input in optim_inputs:
            updated_params, state = [], []
            kwargs = deepcopy(optim_input.kwargs)
            if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
                # capturable is not supported on CPU
                continue
            for use_impl in (False, True):
                kwargs[impl] = use_impl
                # Each variant gets its own copies so both start from the same
                # values and grads. Only params that require grad are handed to
                # the optimizer (the int64 tensor is left out).
                params_clone = []
                for p in params:
                    p_clone = p.clone().detach()
                    if p.requires_grad:
                        p_clone.requires_grad = True
                        p_clone.grad = p.grad.clone().detach()
                        params_clone.append(p_clone)

                optimizer = optim_cls(params_clone, **kwargs)
                for _ in range(kIterations):
                    optimizer.step()

                state.append(optimizer.state)
                updated_params.append(params_clone)

            og_state, new_state = state
            for og_p, new_p in zip(updated_params[0], updated_params[1]):
                # Increasing the tolerance as we are collating lots of ops together for optimizers and
                # the designated tolerances are for single op only.
                single_rtol, single_atol = torch.testing._comparison.get_tolerances(
                    new_p.dtype, rtol=None, atol=None
                )
                rtol = 5 * single_rtol
                atol = 5 * single_atol

                self.assertEqual(og_p, new_p, rtol=rtol, atol=atol)

                # check that optimizer states are the same
                og_p_state = og_state[og_p]
                new_p_state = new_state[new_p]

                for k in og_p_state:
                    actual = new_p_state[k]
                    self.assertEqual(og_p_state[k], actual, rtol=rtol, atol=atol)
876

877
    @onlyCUDA
878
    @optims(
879
        [optim for optim in optim_db if "foreach" in optim.supported_impls],
880
        dtypes=[torch.float64],
881
    )
882
    def test_set_default_dtype_works_with_foreach(self, device, dtype, optim_info):
        """
        Run the foreach parity check under different global default dtypes and
        verify the dtype of the optimizer's ``step`` state.
        """
        # https://github.com/pytorch/pytorch/issues/110940
        # We coerce step to always be float32 unless the
        # default dtype is higher prec float64
        original_default = torch.get_default_dtype()
        # (default dtype to install, dtype expected for the "step" state)
        cases = (
            (torch.float64, torch.float64),
            (torch.float16, torch.float32),
        )
        for default_dtype, expected_step_dtype in cases:
            try:
                torch.set_default_dtype(default_dtype)
                self._test_derived_optimizers(
                    device,
                    dtype,
                    optim_info,
                    "foreach",
                    reduced_precision=default_dtype == torch.float16,
                    assert_step_dtype=expected_step_dtype,
                )
            finally:
                # Always restore the global default, even if the check failed.
                torch.set_default_dtype(original_default)
904

905
    @onlyCUDA
906
    @largeTensorTest("72GB", "cuda")
907
    @optims(
908
        [optim for optim in optim_db if "foreach" in optim.supported_impls],
909
        dtypes=[torch.float16],
910
    )
911
    def test_foreach_large_tensor(self, device, dtype, optim_info):
        """
        Take one foreach step on a single parameter with 2**32 elements for
        every optimizer configuration; the test passes if no step raises.
        """
        optim_cls = optim_info.optim_cls
        for optim_input in optim_info.optim_inputs_func(device=device):
            big_param = torch.ones(2**32, device=device, dtype=dtype)
            big_param.grad = torch.zeros_like(big_param)
            optimizer = optim_cls([big_param], foreach=True, **optim_input.kwargs)
            optimizer.step()
919

920
    @onlyCUDA
921
    @optims(
922
        [optim for optim in optim_db if "foreach" in optim.supported_impls],
923
        dtypes=[torch.float32],
924
    )
925
    def test_peak_memory_foreach(self, device, dtype, optim_info):
        """
        Check that the peak CUDA memory of a foreach step stays within the
        peak of the for-loop step plus a per-optimizer budget of intermediates,
        each the size of the whole parameter list.
        """
        import gc

        nparams = 10
        optim_cls = optim_info.optim_cls
        opt_name = optim_cls.__name__
        for optim_input in optim_info.optim_inputs_func(device=device):
            kwargs = deepcopy(optim_input.kwargs)
            peak_mems = []
            for use_foreach in (False, True):
                kwargs["foreach"] = use_foreach
                # The 16 * 8 = 128 is critical here! Our CUDACachingAllocator allocates in blocks
                # of 512, meaning any tensor that occupies <512 bytes of memory will allocate a
                # whole 512 bytes anyway. We use 128 (cuz datasize would be 4 bytes) so that param
                # is size 512 exactly, making our later calculations for intermediate_size easy.
                param = torch.rand(16, 8, device=device, dtype=dtype)
                params = [torch.rand_like(param) for _ in range(nparams)]

                optimizer = optim_cls(params, **kwargs)

                for p in params:
                    p.grad = torch.rand_like(p)

                # Take one step before resetting the peak statistics so that
                # only the second step is measured.
                optimizer.step()
                gc.collect()
                torch.cuda.reset_peak_memory_stats()
                optimizer.step()
                gc.collect()
                peak_mems.append(torch.cuda.max_memory_allocated())

            st_max_mem, mt_max_mem = peak_mems
            intermediate_size = nparams * param.nelement() * param.element_size()

            # Check the param group directly to handle if the compiler set capturable
            capturable = optimizer.param_groups[0].get("capturable", False)
            if capturable or opt_name in ("Adadelta", "ASGD", "RAdam"):
                if opt_name == "NAdam":
                    # with capturable in NAdam, we have 3 extra intermediates for the
                    # bias_correction, mus, and mu_nexts
                    # With dynamo, the eager/FX backend appears to hold memory longer than
                    # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
                    nintermediates = 8 if TEST_WITH_TORCHDYNAMO else 5
                elif opt_name == "RAdam":
                    # RAdam has four intermediates with capturable
                    # num, unrect_step_size, buffer, grouped_grads
                    # With dynamo, the eager/FX backend appears to hold memory longer than
                    # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
                    nintermediates = 6 if TEST_WITH_TORCHDYNAMO else 4
                else:
                    # with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
                    # with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
                    # ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
                    nintermediates = 3
            elif opt_name in ("NAdam", "Adagrad", "RMSprop", "Adafactor"):
                if opt_name == "Adafactor" and kwargs.get("maximize", False):
                    # When maximize is True, Adafactor also tracks device_grad
                    nintermediates = 3
                else:
                    # NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
                    # Adagrad uses std and grads at the same time
                    # RMSprop uses avg and grads
                    # Adafactor uses row/col var and its mean
                    nintermediates = 2
            else:
                # we expect a budget of 1 intermediate most of the time
                nintermediates = 1

            # Dynamo ST uses less mem than eager in the case of Adam/Adagrad/Nadam/RAdam
            # which makes the foreach memory check fail
            if TEST_WITH_TORCHDYNAMO:
                st_max_mem += 6000

            expected_max_mem = st_max_mem + intermediate_size * nintermediates
            # hipcc currently can't generate efficient code for the small buffer optimization
            # code path (see Note [small buffer optimization] for details), thus we always
            # dynamically allocate the tensor metadata for ROCM. Adjusting the expected max
            # memory usage to account for this.
            if TEST_WITH_ROCM:
                expected_max_mem *= 1.02

            self.assertLessEqual(mt_max_mem, expected_max_mem)
1012

1013
    @optims(
1014
        [optim for optim in optim_db if "fused" in optim.supported_impls],
1015
        dtypes=floating_types_and(
1016
            torch.bfloat16,
1017
            torch.float16,
1018
        ),
1019
    )
1020
    def test_fused_matches_forloop(self, device, dtype, optim_info):
        """
        Check parity between fused=True and fused=False, skipping device types
        without fused support and dtypes MPS cannot handle.
        """
        device_type = _get_device_type(device)
        if device_type not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )
        mps_dtypes = (torch.float16, torch.float32)
        if device_type == "mps" and dtype not in mps_dtypes:
            self.skipTest("MPS supports only torch.float16 and torch.float32")
        self._test_derived_optimizers(
            device=device, dtype=dtype, optim_info=optim_info, flag="fused"
        )
1031

1032
    @optims(
1033
        [optim for optim in optim_db if "fused" in optim.supported_impls],
1034
        dtypes=(torch.float32,),
1035
    )
1036
    def test_fused_error_on_params_on_meta(self, device, dtype, optim_info):
        """
        Check that a fused optimizer raises on step while its params live on
        the meta device, and that the same optimizer can step once the model
        has been materialized on ``device``.
        """
        if _get_device_type(device) not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )

        def assign_random_grads():
            for p in model.parameters():
                p.grad = torch.rand_like(p)

        with torch.device("meta"):
            layers = [
                torch.nn.Linear(2, 3),
                torch.nn.Sigmoid(),
                torch.nn.Linear(3, 1),
                torch.nn.Sigmoid(),
            ]
            model = torch.nn.Sequential(*layers).to(dtype)

        optimizer = optim_info.optim_cls(model.parameters(), fused=True)
        with torch.device("meta"):
            assign_random_grads()

        expected_error = (
            "`fused=True` requires all the params to be floating point Tensors"
        )
        with self.assertRaisesRegex(RuntimeError, expected_error):
            optimizer.step()

        # Drop the meta grads, move the model to the real device, and retry.
        optimizer.zero_grad(set_to_none=True)
        model.to_empty(device=device)
        assign_random_grads()
        optimizer.step()
1066

1067
    @onlyNativeDeviceTypes
1068
    @largeTensorTest("64GB")
1069
    @optims(
1070
        [optim for optim in optim_db if "fused" in optim.supported_impls],
1071
        dtypes=[torch.float16],
1072
    )
1073
    def test_fused_large_tensor(self, device, dtype, optim_info):
        """
        Take one fused step on a single parameter with 2**32 elements for every
        optimizer configuration; the test passes if no step raises.
        """
        # Compare the device *type*, as the sibling fused tests do: a device
        # string carrying an index (e.g. "cuda:0") is never an element of
        # supports_fused_on, so checking `device` itself would skip those runs.
        if _get_device_type(device) not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device=device)
        for optim_input in optim_inputs:
            params = [torch.ones(2**32, device=device, dtype=dtype)]
            params[0].grad = torch.zeros_like(params[0])
            optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
            optimizer.step()
1085

1086
    @onlyCUDA
1087
    @optims(
1088
        [optim for optim in optim_db if "fused" in optim.supported_impls],
1089
        dtypes=[torch.float32],
1090
    )
1091
    def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info):
        """
        Check that a fused optimizer with ``found_inf`` set to one leaves both
        the parameters and any ``step`` state untouched, with and without a
        ``grad_scale`` attached.
        """
        # Compare the device *type*, as the sibling fused tests do: a device
        # string carrying an index (e.g. "cuda:0") is never an element of
        # supports_fused_on, so checking `device` itself would skip those runs.
        if _get_device_type(device) not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device=device)
        num_params = 5
        for optim_input in optim_inputs:
            for no_grad_scale in (False, True):
                params = [
                    torch.ones((1,), device=device, dtype=dtype)
                    for _ in range(num_params)
                ]
                # Snapshot of the initial values to compare against after step.
                params_c = [param.clone().detach() for param in params]
                for p in params:
                    p.grad = torch.ones_like(p)
                optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
                optimizer.grad_scale = (
                    None
                    if no_grad_scale
                    else torch.ones((1,), dtype=dtype, device=device)
                )
                optimizer.found_inf = torch.ones((), dtype=dtype, device=device)
                optimizer.step()
                for p in params:
                    # Not every optimizer keeps a "step" entry in its state.
                    if "step" in optimizer.state[p]:
                        self.assertEqual(
                            torch.zeros((), dtype=dtype, device=device),
                            optimizer.state[p]["step"],
                        )
                self.assertEqual(params, params_c)
1123

1124
    @parametrize("impl", ["fused", "capturable"])
1125
    @optims(
1126
        [optim for optim in optim_db if "fused" in optim.supported_impls],
1127
        dtypes=[torch.float32],
1128
    )
1129
    def test_cpu_load_state_dict(self, device, dtype, impl, optim_info):
        """
        Load a CPU state dict whose ``impl`` flag ('fused' or 'capturable') is
        True into an optimizer on ``device`` and check that it can still step.
        """
        # NOTE: This SIMULATES a fused/capturable optimizer with state moved to CPU, issue 103256
        # How do we get there? Users typically create CUDA models on fused optimizers and then
        # store checkpoints on CPU as CUDA memory is limited with torch.load(...map_location="cpu").
        # Since this is a unit test, it is more expedient to simulate what the state_dict
        # would look like, which is basically CPU tensors with fused/capturable flag = True.
        optim_cls = optim_info.optim_cls
        opt_name = optim_cls.__name__
        if opt_name in ("SGD", "Adagrad") and impl == "capturable":
            # Capturable SGD/Adagrad does not exist; name the optimizer that
            # is actually being skipped.
            self.skipTest(f"{opt_name} does not currently support capturable")
        if _get_device_type(device) == "cpu":
            self.skipTest("Test is only for non-cpu devices")
        elif (
            impl == "fused"
            and _get_device_type(device) not in optim_info.supports_fused_on
        ):
            self.skipTest(f"{device} is not supported for fused on {opt_name}")
        elif impl == "capturable" and _get_device_type(device) == "mps":
            self.skipTest("MPS does not support capturable")

        cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu")
        for optim_input in cpu_optim_inputs:
            # Build a CPU optimizer, step once so it has state, and mark its
            # saved param group as if it had come from a fused/capturable run.
            param = torch.tensor([0.1, 0.2], dtype=dtype, device="cpu")
            optimizer = optim_cls([param], **optim_input.kwargs)
            param.grad = torch.rand_like(param)
            optimizer.step()
            optim_state_dict_cpu = deepcopy(optimizer.state_dict())
            optim_state_dict_cpu["param_groups"][0][impl] = True

            # load
            # Set the flag on a copy so the shared input kwargs stay untouched.
            device_kwargs = {**optim_input.kwargs, impl: True}
            param_device = param.clone().detach().to(device=device)
            optimizer_device = optim_cls([param_device], **device_kwargs)
            optimizer_device.load_state_dict(optim_state_dict_cpu)
            optimizer_device.zero_grad()
            param_device.grad = torch.rand_like(param_device)
            optimizer_device.step()
1167

1168
    @optims(optim_db, dtypes=[torch.float32])
1169
    def test_param_groups_weight_decay(self, device, dtype, optim_info):
        """
        Train with two param groups, the weight group using the configured
        kwargs and the bias group using weight_decay=0.0, and check that the
        loss moved in the expected direction after 20 steps.
        """
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        for optim_input in all_optim_inputs:
            weight_group = optim_input.kwargs
            bias_group = {**deepcopy(optim_input.kwargs), "weight_decay": 0.0}

            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            sample = torch.randn(5, device=device, dtype=dtype)

            optimizer = optim_cls(
                [
                    dict(params=[weight], **weight_group),
                    dict(params=[bias], **bias_group),
                ]
            )

            def compute_loss():
                return (weight.mv(sample) + bias).pow(2).sum()

            initial_value = compute_loss().item()
            for _ in range(20):
                optimizer.zero_grad()
                loss = compute_loss()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                optimizer.step()

            # Test that the direction of loss moved appropriately
            maximizing = optim_input.kwargs.get("maximize", False)
            check_direction = self.assertGreater if maximizing else self.assertLess
            check_direction(loss.item(), initial_value)
1209

1210
    @optims(optim_db, dtypes=[torch.float32])
1211
    def test_param_groups_lr(self, device, dtype, optim_info):
        """
        Train with a per-group lr for weight/bias and a near-zero default lr
        for a third, unrelated parameter; check that the loss moved in the
        expected direction and that the unrelated parameter stayed put.
        """
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        for optim_input in all_optim_inputs:
            group_kwargs = optim_input.kwargs
            # optim_input.kwargs will be the param group kwargs, which should have >0 lr
            if group_kwargs.get("lr", 0) == 0:
                group_kwargs["lr"] = 1e-3

            default_kwargs = {"lr": 1e-28}
            if optim_cls.__name__ == "Rprop":
                # Allow min step size to be 0
                default_kwargs["step_sizes"] = (0, 50)

            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype))
            irrelevant_clone = irrelevant.clone()
            sample = torch.randn(5, device=device, dtype=dtype)

            param_groups = [
                dict(params=[weight, bias], **group_kwargs),
                dict(params=[irrelevant]),
            ]
            optimizer = optim_cls(param_groups, **default_kwargs)

            def compute_loss():
                return (weight.mv(sample) + bias).pow(2).sum()

            initial_value = compute_loss().item()
            for _ in range(20):
                optimizer.zero_grad()
                loss = compute_loss()
                loss.backward()
                irrelevant.grad = torch.rand_like(irrelevant)
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    for p in (weight, bias, irrelevant):
                        p.grad = p.grad.to_sparse()
                optimizer.step()

            # Test that the direction of loss moved appropriately
            maximizing = optim_input.kwargs.get("maximize", False)
            check_direction = self.assertGreater if maximizing else self.assertLess
            check_direction(loss.item(), initial_value)

            # Test that irrelevant parameters were not updated since lr was almost 0
            self.assertEqual(irrelevant, irrelevant_clone)
1262

1263
    @optims(optim_db, dtypes=[torch.float32])
1264
    def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info):
        """
        Check that stepping an optimizer whose params have no ``.grad`` neither
        raises nor changes the params, for every optimizer configuration.
        """
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        params = [
            torch.randn(2, 3, requires_grad=False, device=device, dtype=dtype)
            for _ in range(2)
        ]
        # Snapshot of the initial values to verify the no-op against.
        old_params = [p.clone().detach() for p in params]

        def closure():
            return torch.tensor([1], device=device, dtype=dtype)

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)
            optimizer.step(closure)
            # The snapshot was previously never compared, so the test only
            # checked that step did not raise.
            self.assertEqual(old_params, params)
1281

1282
    @optims(optim_db, dtypes=[torch.float32])
1283
    def test_step_is_noop_for_zero_grads(self, device, dtype, optim_info):
        """
        Check that stepping with an all-zero (or empty sparse) gradient leaves
        the parameter unchanged, for every configuration without weight decay.
        """
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        param = torch.randn((5, 1), device=device, dtype=dtype, requires_grad=True)
        reference = param.clone().detach()

        def closure():
            return torch.tensor([1], device=device, dtype=dtype)

        is_adamw = optim_cls.__name__ == "AdamW"
        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs

            # params will decay even if grads are empty if weight_decay != 0,
            # so those configurations cannot be a no-op
            if kwargs.get("weight_decay", 0) != 0:
                continue

            # AdamW params will be updated regardless of grads due to lr, so make lr smaller
            # (keeping a Tensor lr a Tensor and a float lr a float)
            if is_adamw:
                lr_is_tensor = isinstance(kwargs.get("lr", 1e-5), torch.Tensor)
                kwargs["lr"] = torch.tensor(1e-5) if lr_is_tensor else 1e-5

            # Differentiable configurations are given a clone of the param.
            target = param.clone() if kwargs.get("differentiable", False) else param

            optimizer = optim_cls([target], **kwargs)
            if optim_info.only_supports_sparse_grads:
                # Intentionally construct a multidimensional empty v for the sparse grad
                # Single dim v passes the test while multidim correctly repros the issue
                # https://github.com/pytorch/pytorch/issues/82486
                indices = torch.empty((1, 0), device=device, dtype=dtype)
                values = torch.empty((0, 1), device=device, dtype=dtype)
                target.grad = torch.sparse_coo_tensor(
                    indices, values, (5, 1), device=device, dtype=dtype
                )
            else:
                target.grad = torch.zeros_like(target)
            optimizer.step(closure)
            self.assertEqual(reference, target)
1329

1330
    @optims(optim_db, dtypes=[torch.float32])
    def test_optimizer_can_be_printed(self, device, dtype, optim_info):
        # Smoke test: building the repr of an optimizer must not raise for any config.
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        params = []
        for _ in range(2):
            data = torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
            params.append(Parameter(data))
        for optim_input in all_optim_inputs:
            # The string itself is discarded; only successful construction matters.
            repr(optim_cls(params, **optim_input.kwargs))
1343

1344
    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_deterministic(self, device, dtype, optim_info):
        # Trains an optimizer, clones its parameters and state_dict into a second
        # optimizer, then checks that both evolve identically and that their
        # state_dicts compare equal.
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        weight = Parameter(
            torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
        )
        bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
        sample = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
        params = [weight, bias]

        def fwd_bwd(optim, w, b, i):
            # One forward/backward pass; converts grads to sparse layout when the
            # optimizer only accepts sparse grads.
            optim.zero_grad()
            loss = (w.mv(i) + b).pow(2).sum()
            loss.backward()
            if optim_info.only_supports_sparse_grads:
                for tensor in (w, b):
                    if tensor.grad is not None:
                        tensor.grad = tensor.grad.to_sparse()
            return loss

        for optim_input in all_optim_inputs:
            config = optim_input.kwargs
            optimizer = optim_cls(params, **config)
            closure = functools.partial(fwd_bwd, optimizer, weight, bias, sample)

            # Prime the optimizer
            for _ in range(10):
                if optim_info.step_requires_closure:
                    optimizer.step(closure)
                else:
                    closure()
                    optimizer.step()

            # Clone the weights and construct a new optimizer for them
            with torch.no_grad():
                weight_c = Parameter(weight.clone())
                bias_c = Parameter(bias.clone())

            optimizer_c = optim_cls([weight_c, bias_c], **config)
            closure_c = functools.partial(
                fwd_bwd, optimizer_c, weight_c, bias_c, sample
            )

            # Load the state dict from the original optimizer into the new one
            snapshot = deepcopy(optimizer.state_dict())
            optimizer_c.load_state_dict(snapshot)

            # Run both optimizers in parallel
            for _ in range(10):
                if optim_info.step_requires_closure:
                    optimizer.step(closure)
                    optimizer_c.step(closure_c)
                else:
                    closure()
                    closure_c()
                    optimizer.step()
                    optimizer_c.step()

                self.assertEqual(weight, weight_c)
                self.assertEqual(bias, bias_c)

            # Make sure state dict is deterministic with equal (not identical) parameters
            self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())

            # Make sure repeated parameters have identical representation (see #36831)
            optimizer_c.param_groups.extend(optimizer_c.param_groups)
            last_group = optimizer.state_dict()["param_groups"][-1]
            last_group_c = optimizer_c.state_dict()["param_groups"][-1]
            self.assertEqual(last_group, last_group_c)
1416

1417
    @optims(optim_db, dtypes=[torch.float32])
    def test_can_load_older_state_dict(self, device, dtype, optim_info):
        # Simulates a state_dict saved before newer flags existed: the flags listed in
        # optim_info.not_og_supported_flags are stripped from each param group, the
        # result is loaded back, and the optimizer must still be able to step.
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        def fwd_bwd(optim, mod, i):
            optim.zero_grad()
            loss = mod(i).sum()
            loss.backward()
            return loss

        for optim_input in all_optim_inputs:
            torch.manual_seed(1)
            model = torch.nn.Sequential(
                torch.nn.Conv2d(4, 2, 1, stride=2),
                torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
            )
            model.to(dtype=dtype, device=device)
            batch = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
            optimizer = optim_cls(model.parameters(), **optim_input.kwargs)
            closure = functools.partial(fwd_bwd, optimizer, model, batch)

            def take_step():
                # Closure-driven optimizers run the pass themselves; the others
                # need the grads populated before step().
                if optim_info.step_requires_closure:
                    optimizer.step(closure)
                else:
                    closure()
                    optimizer.step()

            for _ in range(3):
                take_step()

            # old_state_dict has all new flags del'd
            old_state_dict = deepcopy(optimizer.state_dict())
            for group in old_state_dict["param_groups"]:
                for flag in optim_info.not_og_supported_flags:
                    group.pop(flag, None)

            optimizer.load_state_dict(old_state_dict)

            # Make sure we can still step
            take_step()
1464

1465
    @optims(optim_db, dtypes=[torch.float32])
    def test_save_load_equality_with_weights_only(self, device, dtype, optim_info):
        # Serializes a primed optimizer's state_dict with torch.save and checks that
        # it round-trips unchanged through both a full load and a weights_only load.
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        weight = Parameter(
            torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
        )
        bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
        input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
        params = [weight, bias]

        def fwd_bwd(optim, w, b, i):
            optim.zero_grad()
            loss = (w.mv(i) + b).pow(2).sum()
            loss.backward()
            if optim_info.only_supports_sparse_grads:
                # Operate on the tensors that were passed in rather than on the
                # enclosing scope's weight/bias.
                w.grad = w.grad.to_sparse()
                b.grad = b.grad.to_sparse()
            return loss

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)
            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)

            # Prime the optimizer
            for _ in range(3):
                optimizer.step(closure)

            sd = optimizer.state_dict()

            # === Check saved/loaded state_dict are the same (including weights_only load). ===
            with tempfile.TemporaryFile() as f:
                torch.save(sd, f)
                f.seek(0)
                # Request the full (non-restricted) load explicitly so this branch
                # does not depend on the library default for weights_only.
                sd_copy = torch.load(f, weights_only=False)
                self.assertEqual(sd_copy, sd)
                del sd_copy
                f.seek(0)
                sd_copy_wo = torch.load(f, weights_only=True)
                self.assertEqual(sd_copy_wo, sd)
1509

1510
    @optims(optim_db, dtypes=[torch.float32])
    def test_load_nontensor_step(self, device, dtype, optim_info):
        # Converts every tensor-valued "step" entry of a saved state_dict to a plain
        # Python number, loads it back, and checks the optimizer can still step.
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        params = [
            Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
        ]
        for p in params:
            p.grad = torch.rand_like(p)
            if optim_info.only_supports_sparse_grads:
                # For this test, we naively convert the Tensor layout, which we know does
                # NOT represent the expected use case for optims like SparseAdam!
                p.grad = p.grad.to_sparse()

        # Needed for second order optims like LBFGS
        closure_loss = torch.rand(1, device=device, dtype=dtype)

        def closure():
            return closure_loss if optim_info.step_requires_closure else None

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)
            for _ in range(3):
                optimizer.step(closure)
            state_dict = deepcopy(optimizer.state_dict())
            for p_state in state_dict["state"].values():
                if "step" in p_state and torch.is_tensor(p_state["step"]):
                    p_state["step"] = p_state["step"].item()
            optimizer.load_state_dict(state_dict)
            optimizer.step(closure)
1545

1546
    @onlyCUDA
    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_with_cuda_params(self, device, dtype, optim_info):
        # Primes an optimizer on CPU params, loads its state_dict into an optimizer
        # over CUDA copies of those params, and checks that (a) load_state_dict does
        # not modify the dict it is given, (b) state["step"] lands on the expected
        # device, and (c) both optimizers keep producing equal params and state.
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # We limit our configs to CPU only, because we will be moving them to CUDA later
        cpu_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            "cpu", dtype, optim_info, skip=("differentiable",)
        )

        # Needed for second order optims like LBFGS
        closure_loss = torch.rand(1, device=device, dtype=dtype)

        def closure():
            return closure_loss if optim_info.step_requires_closure else None

        for optim_input in cpu_optim_inputs:
            if (
                "fused" in optim_input.kwargs
                and "cuda" not in optim_info.supports_fused_on
            ):
                # Only this config is inapplicable; skipTest() here would abort the
                # whole test and leave every later config unchecked.
                continue
            params = [
                Parameter(torch.randn(2, 3, device="cpu", dtype=dtype))
                for _ in range(2)
            ]
            for p in params:
                p.grad = torch.randn_like(p)
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    p.grad = p.grad.to_sparse()

            optimizer = optim_cls(params, **optim_input.kwargs)

            for _ in range(3):
                optimizer.step(closure)

            with torch.no_grad():
                params_cuda = [p.to(device="cuda") for p in params]
                for i, p in enumerate(params_cuda):
                    p.grad = params[i].grad.to(device="cuda")
            optimizer_cuda = optim_cls(params_cuda, **optim_input.kwargs)

            # Two independent deep copies: one is loaded, the other is the reference.
            state_dict_cpu = deepcopy(optimizer.state_dict())
            state_dict_cuda = deepcopy(optimizer.state_dict())
            optimizer_cuda.load_state_dict(state_dict_cuda)

            # Make sure state_dict_cuda isn't modified by merely calling load_state_dict
            self.assertEqual(state_dict_cpu, state_dict_cuda)

            # Make sure that device of state['step'] is still CPU _unless_ torch.compile() added a capturable!
            capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
            fused = state_dict_cpu["param_groups"][0].get("fused", False)
            new_state_dict = optimizer_cuda.state_dict()
            for state_cpu, state_cuda in zip(
                state_dict_cpu["state"].values(), new_state_dict["state"].values()
            ):
                if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
                    self.assertEqual(
                        state_cuda["step"].device.type,
                        "cuda" if capturable or fused else "cpu",
                    )

            for _ in range(5):
                optimizer.step(closure)
                optimizer_cuda.step(closure)
                self.assertEqual(params, params_cuda)
                self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())
1618

1619
    @staticmethod
1620
    def _state_dict_pre_hook(optimizer: Optimizer) -> None:
1621
        optimizer.state["test"] = 1
1622

1623
    @staticmethod
1624
    def _state_dict_post_hook(
1625
        optimizer: Optimizer, state_dict: Dict[str, Any]
1626
    ) -> Dict[str, Any]:
1627
        if "test" in state_dict["state"]:
1628
            state_dict["state"].pop("test")
1629
            state_dict["ran_state_dict_pre_hook"] = True
1630
        else:
1631
            state_dict["ran_state_dict_pre_hook"] = False
1632
        return state_dict
1633

1634
    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_pre_hook(self, device, dtype, optim_info):
        # With only the pre hook registered, the marker it writes into the
        # optimizer state must show up in the produced state_dict.
        optim_cls = optim_info.optim_cls
        pre_hook = self.__class__._state_dict_pre_hook
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_pre_hook(pre_hook)
            produced = optim.state_dict()
            self.assertEqual(produced["state"]["test"], 1)
1646

1647
    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_post_hook(self, device, dtype, optim_info):
        # With only the post hook registered there is no marker to find, so the
        # hook must report that the pre hook did not run.
        optim_cls = optim_info.optim_cls
        post_hook = self.__class__._state_dict_post_hook
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_post_hook(post_hook)
            produced = optim.state_dict()
            self.assertFalse(produced["ran_state_dict_pre_hook"])
1659

1660
    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
        # With both hooks registered, the post hook must see the pre hook's marker,
        # strip it from the state, and flag that the pre hook ran.
        optim_cls = optim_info.optim_cls
        hooks_owner = self.__class__
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_pre_hook(hooks_owner._state_dict_pre_hook)
            optim.register_state_dict_post_hook(hooks_owner._state_dict_post_hook)
            produced = optim.state_dict()
            self.assertNotIn("test", produced["state"])
            self.assertTrue(produced["ran_state_dict_pre_hook"])
1674

1675
    @staticmethod
1676
    def _load_state_dict_pre_hook1(
1677
        optimizer: Optimizer, state_dict: Dict[str, Any]
1678
    ) -> None:
1679
        state_dict["param_groups"][0]["lr"] = 0.002
1680

1681
    @staticmethod
1682
    def _load_state_dict_pre_hook2(
1683
        optimizer: Optimizer, state_dict: Dict[str, Any]
1684
    ) -> Dict[str, Any]:
1685
        # The typical use case for returning a state dict is to drastically modify the state dict.
1686
        # I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
1687
        my_state_dict = deepcopy(state_dict)
1688
        my_state_dict["param_groups"][0]["lr"] = 0.003
1689
        return my_state_dict
1690

1691
    @staticmethod
1692
    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
1693
        optimizer.state["ran_load_state_dict_pre_hook2"] = (
1694
            optimizer.param_groups[0]["lr"] == 0.003
1695
        )
1696
        optimizer.state["ran_load_state_dict_post_hook"] = True
1697

1698
    @optims(optim_db, dtypes=[torch.float32])
    def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
        # Registers pre hook 1 (sets lr to 0.002), then pre hook 2 with prepend=True
        # (returns a copy with lr 0.003). After each load the lr must be 0.002.
        optim_cls = optim_info.optim_cls
        hooks_owner = self.__class__
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            state_dict = optim.state_dict()

            # usually one would have a new optim instance here, but it's all the same here
            registrations = (
                (hooks_owner._load_state_dict_pre_hook1, {}),
                # If prepend were False would be 0.003 but since prepend is True,
                # the other hook overrides
                (hooks_owner._load_state_dict_pre_hook2, {"prepend": True}),
            )
            for hook, register_kwargs in registrations:
                optim.register_load_state_dict_pre_hook(hook, **register_kwargs)
                optim.load_state_dict(state_dict)
                self.assertEqual(optim.param_groups[0]["lr"], 0.002)
1722

1723
    @optims(optim_db, dtypes=[torch.float32])
    def test_load_state_dict_post_hook(self, device, dtype, optim_info):
        # With only the load post hook registered, it must have run, and it must
        # report that pre hook 2 did not (the lr was never set to 0.003).
        optim_cls = optim_info.optim_cls
        post_hook = self.__class__._load_state_dict_post_hook
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)

            optim.register_load_state_dict_post_hook(post_hook)
            own_state_dict = optim.state_dict()
            optim.load_state_dict(own_state_dict)

            flags = optim.state
            self.assertFalse(flags["ran_load_state_dict_pre_hook2"])
            self.assertTrue(flags["ran_load_state_dict_post_hook"])
1739

1740
    @optims(optim_db, dtypes=[torch.float32])
    def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
        # With pre hook 2 and the post hook both registered, the post hook must
        # observe the lr written by pre hook 2 and record that both ran.
        optim_cls = optim_info.optim_cls
        hooks_owner = self.__class__
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)

            optim.register_load_state_dict_pre_hook(
                hooks_owner._load_state_dict_pre_hook2
            )
            optim.register_load_state_dict_post_hook(
                hooks_owner._load_state_dict_post_hook
            )
            own_state_dict = optim.state_dict()
            optim.load_state_dict(own_state_dict)

            flags = optim.state
            self.assertTrue(flags["ran_load_state_dict_pre_hook2"])
            self.assertTrue(flags["ran_load_state_dict_post_hook"])
1759

1760
    @optims(optim_db, dtypes=[torch.float32])
    def test_step_post_hook(self, device, dtype, optim_info):
        # A per-optimizer post-step hook bumps a counter by 2 on every step; once its
        # handle is removed, further steps must leave the counter alone.
        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal counter
            counter += 2

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        # Only closure-driven optimizers are handed a closure; it returns a constant.
        closure = (lambda: 1) if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            counter = 2
            hook_handle = optim.register_step_post_hook(post_hook)

            for _ in range(2):
                optim.step(closure)
            # check if post hooks were registered
            self.assertEqual(counter, 6)

            # remove handles, take step and verify that hook is no longer registered
            hook_handle.remove()

            optim.step(closure)
            self.assertEqual(counter, 6)
1791

1792
    @optims(optim_db, dtypes=[torch.float32])
    def test_step_pre_hook(self, device, dtype, optim_info):
        # A per-optimizer pre-step hook bumps a counter by 2 on every step; once its
        # handle is removed, further steps must leave the counter alone.
        def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal counter
            counter += 2

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        # Only closure-driven optimizers are handed a closure; it returns a constant.
        closure = (lambda: 1) if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            counter = 5
            hook_handle = optim.register_step_pre_hook(pre_hook)

            for _ in range(2):
                optim.step(closure)
            # check if pre hooks were registered
            self.assertEqual(counter, 9)

            # remove handles, take step and verify that hook is no longer registered
            hook_handle.remove()

            optim.step(closure)
            self.assertEqual(counter, 9)
1823

1824
    @optims(optim_db, dtypes=[torch.float32])
    def test_step_all_hooks(self, device, dtype, optim_info):
        # Registers global and per-optimizer pre/post step hooks on two optimizers
        # and checks the order they fire in. Each hook appends a distinct number to
        # ``data``; the asserted per-step order is [0, 1, 2, 5], i.e. global pre,
        # local pre, local post, global post. After all handles are removed, further
        # steps must not append anything.
        def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            data.append(0)

        def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            data.append(5)

        def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            data.append(1)

        def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            data.append(2)

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        def dummy_closure():
            return 1

        closure = dummy_closure if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            optim2 = SGD(params)
            data = []

            # The global hooks are process-wide, so they must be removed even when an
            # assertion below fails; otherwise they would fire in later tests.
            handles = []
            try:
                # register global hooks to both optimizers
                handles.append(register_optimizer_step_pre_hook(global_pre_hook))
                handles.append(register_optimizer_step_post_hook(global_post_hook))

                # register local hooks
                handles.append(optim.register_step_pre_hook(local_pre_hook))
                handles.append(optim.register_step_post_hook(local_post_hook))
                handles.append(optim2.register_step_pre_hook(local_pre_hook))
                handles.append(optim2.register_step_post_hook(local_post_hook))

                optim.step(closure)
                self.assertListEqual(data, [0, 1, 2, 5])
                optim2.step(closure)
                self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
                optim.step(closure)
                self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
            finally:
                # remove all hooks
                for handle in handles:
                    handle.remove()

            optim.step(closure)
            optim2.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
1885

1886
    @optims(optim_db, dtypes=[torch.float32])
    def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
        # After a few steps, a deepcopy of the optimizer must expose the same set of
        # public (non-underscore) instance attribute names as the original.
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        params = [
            Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
        ]
        for p in params:
            dense_grad = torch.rand_like(p)
            p.grad = dense_grad
            if optim_info.only_supports_sparse_grads:
                # For this test, we naively convert the Tensor layout, which we know does
                # NOT represent the expected use case for optims like SparseAdam!
                p.grad = p.grad.to_sparse()

        # Needed for second order optims like LBFGS
        def closure():
            if optim_info.step_requires_closure:
                return 1
            return None

        def getPublicAttrs(obj):
            # Names in the instance __dict__ that do not begin with an underscore.
            return {name for name in vars(obj) if name[:1] != "_"}

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)

            # Make some state
            for _ in range(3):
                if optim_info.step_requires_closure:
                    optimizer.step(closure)
                else:
                    closure()
                    optimizer.step()

            original_attrs = getPublicAttrs(optimizer)
            copied_attrs = getPublicAttrs(deepcopy(optimizer))
            self.assertEqual(original_attrs, copied_attrs)
1926

1927
    @optims(
        [optim for optim in optim_db if optim.step_requires_closure],
        dtypes=[torch.float32],
    )
    def test_second_order_optims_return_consistent_types(
        self, device, dtype, optim_info
    ):
        # Motivated by #7586
        # step() must return the same type whether tolerance_grad is +inf or -inf.
        optim_cls = optim_info.optim_cls
        params = [
            torch.randn(10, 5, device=device, dtype=dtype),
            torch.randn(10, device=device, dtype=dtype),
        ]

        def closure():
            return torch.tensor([10], device=device, dtype=dtype)

        for optim_input in optim_info.optim_inputs_func(device=device):
            # Currently, the only second order optim is LBFGS, so we just go ahead and modify
            # "tolerance_grad", but this may not scale if we add second order optims in the future
            kwargs = optim_input.kwargs
            built = []
            for tolerance in (math.inf, -math.inf):
                kwargs["tolerance_grad"] = tolerance
                built.append(optim_cls(params, **kwargs))
            optim_inf, optim_neg_inf = built

            # Both optimizers are constructed before either one steps.
            res1 = optim_inf.step(closure)
            res2 = optim_neg_inf.step(closure)
            self.assertEqual(type(res1), type(res2))
1956

1957
    @onlyCUDA
    @optims(
        [
            optim
            for optim in optim_db
            if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on
        ],
        dtypes=floating_types_and(
            torch.bfloat16,
            torch.float16,
        ),
    )
    def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
        # For every CPU optim input, builds the same seeded model once on CPU and
        # once on CUDA, each with a fused optimizer, and hands the pair to
        # self._compare_between.
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device="cpu")
        for optim_input in optim_inputs:
            # The same kwargs (with fused forced on) are used for both devices.
            kwargs = optim_input.kwargs
            kwargs["fused"] = True
            inpts, models, optimizers = [], [], []
            for dev in ("cpu", "cuda"):
                inpt = torch.tensor(
                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
                ).reshape(3, 2)

                # Reseed so both devices start from the same initial weights.
                torch.manual_seed(1)
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(3, 1),
                    torch.nn.Sigmoid(),
                )
                model.to(dtype=dtype, device=dev)

                # foreach/fused optimizers should be tested with a
                # zero_size tensor as its last param.
                # ref: https://github.com/pytorch/pytorch/issues/100701
                # NOTE(review): shape () is a 0-dim tensor holding one element, not
                # a zero-size tensor -- confirm this matches the intent above.
                empty_param = torch.empty(
                    (), device=dev, dtype=dtype, requires_grad=True
                )
                empty_param.grad = torch.rand_like(empty_param)
                params = list(model.parameters()) + [empty_param]

                optimizer = optim_cls(params, **kwargs)
                inpts.append(inpt)
                models.append(model)
                optimizers.append(optimizer)
            # Compare inside the loop so every optim input is checked, not just the
            # last one (and so an empty input list cannot hit undefined names).
            self._compare_between(inpts, models, optimizers)
2004

2005
    @onlyCUDA
    @optims(
        [
            info
            for info in optim_db
            if "foreach" in info.supported_impls
            and info.optim_cls.__name__ != "Adafactor"
        ],
        dtypes=[torch.float32],
    )
    def test_defaults_changed_to_foreach(self, device, dtype, optim_info):
        """Verify that optimizers use their foreach implementation by default.

        Adafactor is excluded by the decorator filter because it defaults to the
        single tensor impl for memory efficiency.
        """
        import inspect

        optim_cls = optim_info.optim_cls
        # The foreach implementation is looked up as `_multi_tensor_<name>` on the
        # module that defines the optimizer class.
        optim_module = inspect.getmodule(optim_cls)
        foreach_fn_name = f"_multi_tensor_{optim_cls.__name__.lower()}"

        net = torch.nn.Linear(5, 5)
        net.to(dtype=dtype, device=device)
        batch = torch.rand(2, 5, dtype=dtype, device=device)

        for optim_input in optim_info.optim_inputs_func(device=device):
            optimizer = optim_cls(net.parameters(), **optim_input.kwargs)
            optimizer.zero_grad()
            net(batch).sum().backward()

            # Replace the foreach implementation with a mock and check that a
            # plain step() call reaches it.
            with patch.object(optim_module, foreach_fn_name) as foreach_mock:
                optimizer.step()
                self.assertTrue(foreach_mock.called)
2037

2038
    @optims(optim_db, dtypes=[torch.float32])
    def test_non_empty_state(self, device, dtype, optim_info):
        """Check that each entry of ``optim.state`` is non-empty after one step.

        There are internal tests that check that the state is not empty.
        """
        optim_cls = optim_info.optim_cls
        net = torch.nn.Linear(5, 5)
        net.to(dtype=dtype, device=device)
        batch = torch.rand(2, 5, dtype=dtype, device=device)

        for optim_input in optim_info.optim_inputs_func(device=device):
            optimizer = optim_cls(net.parameters(), **optim_input.kwargs)
            optimizer.zero_grad()
            net(batch).sum().backward()

            if optim_info.only_supports_sparse_grads:
                # Convert the dense grads produced by backward() to sparse ones.
                for p in net.parameters():
                    if p.grad is None:
                        continue
                    p.grad = p.grad.to_sparse()

            # Optimizers whose step() requires a closure get a constant one.
            step_args = (lambda: 1.0,) if optim_info.step_requires_closure else ()
            optimizer.step(*step_args)

            for per_param_state in optimizer.state.values():
                self.assertGreater(len(per_param_state), 0)
2065

2066

2067
# Instantiate the device-type variants of TestOptimRenewed, passing this module's
# globals() as the target scope and opting in to MPS.
instantiate_device_type_tests(
    TestOptimRenewed,
    globals(),
    allow_mps=True,
)


# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
2072

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.