6
from copy import deepcopy
7
from typing import Any, Dict, Tuple
8
from unittest.mock import patch
10
from optim.test_lrscheduler import TestLRScheduler
11
from optim.test_optim import TestDifferentiableOptimizer
12
from optim.test_swa_utils import TestSWAUtils
15
from torch.nn import Parameter
16
from torch.optim import Optimizer, SGD
17
from torch.optim.lr_scheduler import ReduceLROnPlateau
18
from torch.optim.optimizer import (
19
register_optimizer_step_post_hook,
20
register_optimizer_step_pre_hook,
22
from torch.testing._internal.common_cuda import TEST_MULTIGPU
23
from torch.testing._internal.common_device_type import (
24
instantiate_device_type_tests,
28
onlyNativeDeviceTypes,
32
from torch.testing._internal.common_dtype import floating_types_and
33
from torch.testing._internal.common_optimizers import (
35
_get_optim_inputs_including_global_cliquey_kwargs,
41
from torch.testing._internal.common_utils import (
45
TEST_WITH_TORCHDYNAMO,
50
# Loosened comparison tolerances used when validating float16 runs against a
# reference implementation (passed as assertEqual kwargs when a test opts into
# reduced precision).
FP16_REDUCED_PRECISION = {"atol": 1e-5, "rtol": 1e-4}
53
def rosenbrock(tensor):
    """Return the Rosenbrock function value f(x, y) = (1 - x)^2 + 100 (y - x^2)^2.

    Args:
        tensor: a 1-D tensor holding exactly two scalars (x, y).

    Raises:
        AssertionError: if ``tensor`` does not have shape ``(2,)``.
    """
    # The extraction dropped the size argument and the unpacking line;
    # restored here: the function is only defined for a 2-element tensor.
    assert tensor.size() == torch.Size(
        [2]
    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2
61
def drosenbrock(tensor):
    """Return the analytic gradient of :func:`rosenbrock` at (x, y).

    df/dx = -400 x (y - x^2) - 2 (1 - x)
    df/dy = 200 (y - x^2)

    Args:
        tensor: a 1-D tensor holding exactly two scalars (x, y).

    Raises:
        AssertionError: if ``tensor`` does not have shape ``(2,)``.
    """
    # The extraction dropped the size argument and the unpacking line;
    # restored here: x and y are consumed by the stacked gradient below.
    assert tensor.size() == torch.Size(
        [2]
    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
70
class TestOptimRenewed(TestCase):
72
This test class validates the core optimizers and is structured as the correctness of:
73
- The update algorithms (forloop implementation)
74
* Every optimizer's algorithm is most readably implemented through a big for-loop
75
over all the parameters, which is what we refer to as the forloop or single tensor
76
implementation. These algorithms are manually validated by comparing to the paper
77
and systematically validated by assuring that the loss goes the right direction
78
when the optimizer has been applied.
79
* This implementation should compose with optimizer hyperparameters well, such as
80
supporting Tensor LRs, the capturable API, and sparse and complex parameters.
81
- Each varying implementation
82
* We then have implementations that improve upon the performance of the forloop
83
implementation by leveraging fusion, namely our foreach (mult_tensor) and fused
85
* These variations are validated numerically by comparing with the forloop version
86
of the optimizer. In fact, we test most variations this way--we see the forloop
87
implementation as the ground truth and expect that improvements to it in any way
88
should be just as correct.
89
* Both params and optimizer states should be validated numerically.
91
* The optimizer instance should be serializable
92
* Calling save and load should be deterministic
93
* Moving between devices should be seamless
94
* BC - load_state_dict should be able to handle older optimizer states
95
- Hook APIs (everything should fire in the right order)
96
- LR Scheduler integration (composing should not error + should go the right direction)
97
- Parameter groups (should be equivalent to having multiple optimizers)
98
- Erroring (what should error should error)
100
We also cover different ways of generating parameters and grads:
101
- With parameters, we either generate them randomly given specific shapes or we take
102
them from a sample NN module.
103
* Variety is important here because NN modules have type Parameter and randomly
104
generated tensors have type Tensor.
105
* Parameters can be sparse for a subset of the optimizers (check out OptimizerInfo)
106
* Complex parameters should be handled using view_as_real
107
* Parameters can be spread across different devices and different dtypes for any
109
* Parameters can be contiguous and noncontiguous
110
- With grads, we follow suit from the parameters.
111
* Grads can also be None, empty, or zero-valued, and this should not disrupt training.
116
def test_optim_infos_do_not_specify_global_cliquey_kwargs(
117
self, device, dtype, optim_info
119
global_cliquey_flags = ["foreach", "fused", "differentiable"]
120
for optim_input in optim_info.optim_inputs_func(device=device):
122
any(f for f in global_cliquey_flags if f in optim_input.kwargs)
125
@optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
126
def test_errors(self, device, dtype, optim_info):
127
optim_cls = optim_info.optim_cls
128
error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype)
130
for error_input in error_inputs:
131
optim_input = error_input.optimizer_error_input
132
params, kwargs = optim_input.params, optim_input.kwargs
133
if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
134
if issubclass(error_input.error_type, Warning):
135
with self.assertWarnsRegex(
136
error_input.error_type, error_input.error_regex
138
optim_cls(params, **kwargs)
140
with self.assertRaisesRegex(
141
error_input.error_type, error_input.error_regex
143
optim_cls(params, **kwargs)
144
elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
145
optim = optim_cls(params, **kwargs)
146
if issubclass(error_input.error_type, Warning):
147
with self.assertWarnsRegex(
148
error_input.error_type, error_input.error_regex
152
with self.assertRaisesRegex(
153
error_input.error_type, error_input.error_regex
157
raise NotImplementedError(f"Unknown error type {error_input.error_on}")
159
@parametrize("contiguous", [True, False])
160
@parametrize("with_lrsched", [True, False])
161
@optims(optim_db, dtypes=[torch.float32])
162
def test_forloop_goes_right_direction(
163
self, device, dtype, optim_info, contiguous, with_lrsched
165
optim_cls = optim_info.optim_cls
166
schedulers_constructors = (
167
optim_info.scheduler_inputs if with_lrsched else [None]
170
for schedulers_constructor in schedulers_constructors:
173
optim_inputs = optim_info.optim_inputs_func(device=device)
174
for optim_input in optim_inputs:
175
if "foreach" in optim_info.supported_impls:
176
optim_input.kwargs["foreach"] = False
178
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
179
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
182
torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0]
185
torch.randn((10, 2), device=device, dtype=dtype)[..., 0]
187
input = torch.randn(5, device=device, dtype=dtype)
189
optimizer = optim_cls([weight, bias], **optim_input.kwargs)
192
for s in (schedulers_constructor if schedulers_constructor else [])
196
optimizer.zero_grad()
197
loss = (weight.mv(input) + bias).pow(2).sum()
199
if optim_info.only_supports_sparse_grads:
202
weight.grad = weight.grad.to_sparse()
203
bias.grad = bias.grad.to_sparse()
206
initial_value = closure().item()
208
if optim_info.step_requires_closure:
209
loss = optimizer.step(closure)
214
for scheduler in schedulers:
215
if isinstance(scheduler, ReduceLROnPlateau):
220
if optim_input.kwargs.get("maximize", False):
221
self.assertGreater(closure().item(), initial_value)
223
self.assertLess(closure().item(), initial_value)
226
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
227
@parametrize("with_lrsched", [True, False])
228
@optims(optim_db, dtypes=[torch.float32])
229
def test_forloop_goes_right_direction_multigpu(
230
self, device, dtype, optim_info, with_lrsched
232
optim_cls = optim_info.optim_cls
233
schedulers_constructors = (
234
optim_info.scheduler_inputs if with_lrsched else [None]
236
for schedulers_constructor in schedulers_constructors:
239
optim_inputs = optim_info.optim_inputs_func(device=device)
240
for optim_input in optim_inputs:
241
if "foreach" in optim_info.supported_impls:
242
optim_input.kwargs["foreach"] = False
244
weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
245
bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
246
inpt = torch.randn(5, device="cuda:0", dtype=dtype)
248
optimizer = optim_cls([weight, bias], **optim_input.kwargs)
251
for s in (schedulers_constructor if schedulers_constructor else [])
255
optimizer.zero_grad()
256
loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum()
258
if optim_info.only_supports_sparse_grads:
261
weight.grad = weight.grad.to_sparse()
262
bias.grad = bias.grad.to_sparse()
265
initial_value = closure().item()
267
loss = optimizer.step(closure)
268
for scheduler in schedulers:
269
if isinstance(scheduler, ReduceLROnPlateau):
274
if optim_input.kwargs.get("maximize", False):
275
self.assertGreater(closure().item(), initial_value)
277
self.assertLess(closure().item(), initial_value)
279
@optims(optim_db, dtypes=[torch.float32])
280
def test_param_group_with_lrscheduler_goes_right_direction(
281
self, device, dtype, optim_info
283
optim_cls = optim_info.optim_cls
285
for schedulers_c in optim_info.scheduler_inputs:
286
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
287
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
288
inpt = torch.randn(5, device=device, dtype=dtype)
291
lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
292
optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
293
schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]
296
optimizer.zero_grad()
297
loss = (weight.mv(inpt) + bias).pow(2).sum()
299
if optim_info.only_supports_sparse_grads:
302
weight.grad = weight.grad.to_sparse()
303
bias.grad = bias.grad.to_sparse()
306
initial_value = closure().item()
308
loss = optimizer.step(closure)
309
for scheduler in schedulers:
310
if isinstance(scheduler, ReduceLROnPlateau):
315
self.assertLess(closure().item(), initial_value)
317
@optims(optim_db, dtypes=[torch.float32])
318
def test_tensor_lr(self, device, dtype, optim_info):
319
optim_cls = optim_info.optim_cls
322
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
323
device, dtype, optim_info, skip=("differentiable",)
325
for optim_input in all_optim_inputs:
326
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
327
weight_c = weight.clone().detach().requires_grad_(True)
328
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
329
bias_c = bias.clone().detach().requires_grad_(True)
330
inpt = torch.randn(5, device=device, dtype=dtype)
332
kwargs = optim_input.kwargs
336
kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3
337
optimizer_r = optim_cls([weight, bias], **kwargs)
340
kwargs["lr"] = torch.tensor(kwargs["lr"])
341
optimizer = optim_cls([weight_c, bias_c], **kwargs)
342
except ValueError as e:
343
self.assertRegex(str(e), ".*lr as a Tensor is not supported.*")
346
def closure(optim, w, b, i):
348
loss = (w.mv(i) + b).pow(2).sum()
350
if optim_info.only_supports_sparse_grads:
353
w.grad = w.grad.to_sparse()
354
b.grad = b.grad.to_sparse()
358
if optim_info.step_requires_closure:
360
functools.partial(closure, optimizer_r, weight, bias, inpt)
363
functools.partial(closure, optimizer, weight_c, bias_c, inpt)
366
closure(optimizer_r, weight, bias, inpt)
367
closure(optimizer, weight_c, bias_c, inpt)
369
self.assertEqual(weight, weight_c)
370
self.assertEqual(bias, bias_c)
372
@parametrize("with_lrsched", [True, False])
374
[o for o in optim_db if o.supports_sparse or o.only_supports_sparse_grads],
375
dtypes=[torch.float64],
377
def test_rosenbrock_sparse(self, device, dtype, optim_info, with_lrsched):
378
optim_cls = optim_info.optim_cls
382
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
383
device, dtype, optim_info, skip=("differentiable", "fused")
385
kwarg_updates, schedulers_constructors = optim_info.metadata_for_sparse
387
if with_lrsched and len(schedulers_constructors) == 0:
390
supported_inputs = []
391
if len(kwarg_updates) != 0:
393
for i in all_optim_inputs:
394
for k in kwarg_updates:
397
hashable_kwargs = tuple(sorted(i.kwargs.items()))
398
if len(i.kwargs) > 0 and hashable_kwargs not in seen:
399
supported_inputs.append(i)
400
seen.add(hashable_kwargs)
401
if "lr" in kwarg_updates:
402
i.kwargs["lr"] = kwarg_updates["lr"]
404
supported_inputs = all_optim_inputs
406
for optim_input in supported_inputs:
407
kwargs = optim_input.kwargs
408
multi_tensor = kwargs.get("foreach", False)
413
torch.tensor([1.5, 1.5]),
414
torch.tensor([1.5, 1.5], dtype=dtype),
417
params_t = [torch.tensor([1.5, 1.5])]
419
params = [Parameter(param_t) for param_t in params_t]
420
optimizer = optim_cls(params, **kwargs)
422
s(optimizer) for s in (schedulers_constructors if with_lrsched else [])
425
if not optim_info.only_supports_sparse_grads:
426
params_c = [Parameter(param_t.clone()) for param_t in params_t]
427
optimizer_c = optim_cls(params_c, **kwargs)
430
for s in (schedulers_constructors if with_lrsched else [])
433
solution = torch.tensor([1, 1])
434
with torch.no_grad():
435
initial_dist = sum(param.dist(solution) for param in params)
437
def get_grad(param, sparse_grad, w):
438
grad = drosenbrock(param)
445
i = torch.tensor([[0, 0]], dtype=torch.int64)
447
v = torch.tensor([x / 4.0, x - x / 4.0])
449
i = torch.tensor([[1, 1]], dtype=torch.int64)
451
v = torch.tensor([y - y / 4.0, y / 4.0])
452
grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
455
grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
457
grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
460
def eval(params, sparse_grad, w):
461
optimizer.zero_grad()
463
loss = sum(rosenbrock(param) for param in params)
465
loss = rosenbrock(params[0])
468
grads_out = [get_grad(param, sparse_grad, w) for param in params]
469
with torch.no_grad():
470
params[0].grad = grads_out[0]
472
params[1].grad = grads_out[1].to(dtype=dtype)
475
for i in range(1800):
478
optimizer.step(functools.partial(eval, params, True, w))
479
for scheduler in schedulers:
480
if isinstance(scheduler, ReduceLROnPlateau):
481
scheduler.step(rosenbrock(params[0]))
484
if not optim_info.only_supports_sparse_grads:
485
optimizer_c.step(functools.partial(eval, params_c, False, w))
486
for scheduler in schedulers_c:
487
if isinstance(scheduler, ReduceLROnPlateau):
488
scheduler.step(rosenbrock(params_c[0]))
493
self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)
495
if not kwargs.get("maximize", False):
496
self.assertLessEqual(
497
sum(param.dist(solution) for param in params), initial_dist
500
self.assertGreaterEqual(
501
sum(rosenbrock(param) for param in params),
502
sum(rosenbrock(param_t) for param_t in params_t),
506
@optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
507
def test_complex(self, device, dtype, optim_info):
508
optim_cls = optim_info.optim_cls
511
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
512
device, dtype, optim_info, skip=("differentiable", "fused")
514
for optim_input in all_optim_inputs:
517
torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
518
torch.randn(10, device=device, dtype=dtype, requires_grad=True),
520
10, 5, device=device, dtype=torch.float32, requires_grad=True
525
torch.view_as_real(param).detach().clone().requires_grad_()
526
if param.is_complex()
527
else param.detach().clone().requires_grad_()
529
for param in complex_params
532
complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
533
real_optimizer = optim_cls(real_params, **optim_input.kwargs)
539
for param in real_params:
540
grad = torch.randn_like(param)
542
real_steps.append(param.detach().clone())
543
grads_losses.append(grad.clone())
544
loss = torch.randn(1)
545
grads_losses.append(loss.clone())
548
def complex_closure():
549
for param in complex_params:
550
if torch.is_complex(param):
551
grad = torch.view_as_complex(grads_losses.pop(0))
552
complex_steps.append(torch.view_as_real_copy(param.detach()))
554
grad = grads_losses.pop(0)
555
complex_steps.append(param.detach().clone())
557
return grads_losses.pop(0)
560
if optim_info.step_requires_closure:
562
real_optimizer.step(real_closure)
563
complex_optimizer.step(complex_closure)
568
real_optimizer.step()
569
complex_optimizer.step()
572
complex_params_asreal = [
573
torch.view_as_real(param) if param.is_complex() else param
574
for param in complex_params
576
self.assertEqual(real_params, complex_params_asreal)
580
self.assertEqual(complex_steps, real_steps)
583
@optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
584
def test_complex_2d(self, device, dtype, optim_info):
585
optim_cls = optim_info.optim_cls
588
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
589
device, dtype, optim_info, skip=("differentiable", "fused")
591
for optim_input in all_optim_inputs:
592
if optim_info.step_requires_closure:
603
torch.manual_seed(2024)
605
a1 = torch.randn(2, device=device, dtype=dtype, requires_grad=True)
606
a1_real = a1.real.clone().detach()
607
a1_imag = a1.imag.clone().detach()
608
a1_real.requires_grad_()
609
a1_imag.requires_grad_()
610
optim1 = optim_cls([a1], **optim_input.kwargs)
611
optim2 = optim_cls([a1_real, a1_imag], **optim_input.kwargs)
613
a1_reals = TensorTracker()
614
a1_imags = TensorTracker()
615
a1_grad_reals = TensorTracker()
616
a1_grad_imags = TensorTracker()
617
losses = TensorTracker()
621
loss = rosenbrock(a1).abs()
625
a1_reals.add(a1.real)
626
a1_imags.add(a1.imag)
627
a1_grad_reals.add(a1.grad.real)
628
a1_grad_imags.add(a1.grad.imag)
636
a1_reals.pop_check_set(a1_real, self)
637
a1_imags.pop_check_set(a1_imag, self)
638
a2 = torch.complex(a1_real, a1_imag)
639
loss = rosenbrock(a2).abs()
640
losses.pop_check_set(loss, self)
642
a1_grad_reals.pop_check_set(a1_real.grad, self)
643
a1_grad_imags.pop_check_set(a1_imag.grad, self)
647
if optim_info.step_requires_closure:
649
optim1.step(closure1)
650
optim2.step(closure2)
657
self.assertEqual(a1.real, a1_real)
658
self.assertEqual(a1.imag, a1_imag)
660
self.assertTrue(a1_reals.all_popped())
661
self.assertTrue(a1_imags.all_popped())
662
self.assertTrue(a1_grad_reals.all_popped())
663
self.assertTrue(a1_grad_imags.all_popped())
664
self.assertTrue(losses.all_popped())
666
def _compare_between(
667
self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None
672
if assert_eq_kwargs is None:
673
assert_eq_kwargs = {}
675
tracker = TensorTracker(assert_eq_kwargs)
676
for i in range(kIterations):
677
state, updated_params = [], []
678
if not isinstance(inputs, list):
679
inputs = [inputs, inputs]
680
for input, model, optimizer in zip(inputs, models, optimizers):
681
optimizer.zero_grad()
686
model[2].requires_grad_(False)
689
model[2].requires_grad_(True)
693
output = model(input)
698
state.append(optimizer.state)
699
updated_params.append(model.parameters())
701
og_state, new_state = state
702
for og_p, new_p in zip(updated_params[0], updated_params[1]):
704
tracker.pop_check_set(new_p, self)
707
og_p_state = og_state[og_p]
708
new_p_state = new_state[new_p]
709
if assert_step_dtype is not None:
710
if torch.is_tensor(og_p_state.get("step", None)):
711
self.assertEqual(og_p_state["step"].dtype, assert_step_dtype)
712
if torch.is_tensor(new_p_state.get("step", None)):
713
self.assertEqual(new_p_state["step"].dtype, assert_step_dtype)
715
tracker.add(og_p_state[k])
716
tracker.pop_check_set(new_p_state[k], self)
718
self.assertTrue(tracker.all_popped())
720
def _test_derived_optimizers(
726
reduced_precision=False,
727
assert_step_dtype=None,
730
Given a flag 'fused' or 'foreach', test for parity of optimizer state
731
and updated parameters between when the flag is set to True and False
732
for provided optimizer configurations.
734
assert flag in ("foreach", "fused")
735
assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION
737
optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype)
738
optim_cls = optim_info.optim_cls
739
for optim_input in optim_inputs:
740
models, optimizers = [], []
741
kwargs = deepcopy(optim_input.kwargs)
742
if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
745
for flag_value in (False, True):
746
kwargs[flag] = flag_value
747
input = torch.tensor(
748
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
752
model = torch.nn.Sequential(
753
torch.nn.Linear(2, 3),
755
torch.nn.Linear(3, 1),
758
model.to(dtype=dtype, device=device)
763
empty_param = torch.empty(
764
(), device=device, dtype=dtype, requires_grad=True
766
empty_param.grad = torch.rand_like(empty_param)
767
params = list(model.parameters()) + [empty_param]
769
optimizer = optim_cls(params, **kwargs)
771
optimizers.append(optimizer)
773
self._compare_between(
774
input, models, optimizers, assert_eq_kwargs, assert_step_dtype
779
[optim for optim in optim_db if "foreach" in optim.supported_impls],
780
dtypes=[torch.float64],
782
def test_foreach_matches_forloop(self, device, dtype, optim_info):
783
self._test_derived_optimizers(device, dtype, optim_info, "foreach")
786
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
787
@parametrize("impl", ["foreach", "fused"])
791
for optim in optim_db
792
if "foreach" in optim.supported_impls or "fused" in optim.supported_impls
795
def test_mixed_device_dtype(self, device, dtype, optim_info, impl):
797
Similar in essence to _test_derived_optimizers above. The main difference is that
798
_test_derived_optimizers uses model parameters whereas we randomly pass in
799
parameters of different dtypes and devices here. We need multiple GPUs (vs just a
800
CPU and GPU) because fused adam only works on GPUs. (Thus we only run the tests
801
that call into this helper when TEST_MULTIGPU.)
803
assert impl in ("foreach", "fused")
804
if impl == "foreach" and "foreach" not in optim_info.supported_impls:
805
return unittest.skip(
806
f"foreach not supported for {optim_info.optim_cls.__name__}"
808
elif impl == "fused" and "cuda" not in optim_info.supports_fused_on:
809
return unittest.skip(
810
f"fused not supported for {optim_info.optim_cls.__name__} on cuda"
814
torch.rand(2, 3, dtype=torch.float64, device="cuda:0", requires_grad=True),
815
torch.rand(2, 3, dtype=torch.float32, device="cuda:0", requires_grad=True),
816
torch.rand(2, 3, dtype=torch.float16, device="cuda:0", requires_grad=True),
817
torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:0", requires_grad=True),
818
torch.rand(2, 3, dtype=torch.float64, device="cuda:1", requires_grad=True),
819
torch.rand(2, 3, dtype=torch.float32, device="cuda:1", requires_grad=True),
820
torch.rand(2, 3, dtype=torch.float16, device="cuda:1", requires_grad=True),
821
torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:1", requires_grad=True),
823
1024, (2, 3), dtype=torch.int64, device="cuda:1", requires_grad=False
829
p.grad = torch.rand_like(p, device=p.device, dtype=p.dtype)
831
kIterations = 7 if impl == "foreach" else 1
832
optim_inputs = optim_info.optim_inputs_func(device=device)
833
optim_cls = optim_info.optim_cls
834
for optim_input in optim_inputs:
835
updated_params, state = [], []
836
kwargs = deepcopy(optim_input.kwargs)
837
if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
840
for use_impl in (False, True):
841
kwargs[impl] = use_impl
844
p_clone = p.clone().detach()
846
p_clone.requires_grad = True
847
p_clone.grad = p.grad.clone().detach()
848
params_clone.append(p_clone)
850
optimizer = optim_cls(params_clone, **kwargs)
851
for _ in range(kIterations):
854
state.append(optimizer.state)
855
updated_params.append(params_clone)
857
og_state, new_state = state
858
for og_p, new_p in zip(updated_params[0], updated_params[1]):
861
single_rtol, single_atol = torch.testing._comparison.get_tolerances(
862
new_p.dtype, rtol=None, atol=None
864
rtol = 5 * single_rtol
865
atol = 5 * single_atol
867
self.assertEqual(og_p, new_p, rtol=rtol, atol=atol)
870
og_p_state = og_state[og_p]
871
new_p_state = new_state[new_p]
874
actual = new_p_state[k]
875
self.assertEqual(og_p_state[k], actual, rtol=rtol, atol=atol)
879
[optim for optim in optim_db if "foreach" in optim.supported_impls],
880
dtypes=[torch.float64],
882
def test_set_default_dtype_works_with_foreach(self, device, dtype, optim_info):
886
old_default_dtype = torch.get_default_dtype()
887
for default_dtype in [torch.float64, torch.float16]:
889
torch.set_default_dtype(default_dtype)
890
self._test_derived_optimizers(
895
reduced_precision=default_dtype == torch.float16,
898
if default_dtype == torch.float64
903
torch.set_default_dtype(old_default_dtype)
906
@largeTensorTest("72GB", "cuda")
908
[optim for optim in optim_db if "foreach" in optim.supported_impls],
909
dtypes=[torch.float16],
911
def test_foreach_large_tensor(self, device, dtype, optim_info):
912
optim_cls = optim_info.optim_cls
913
optim_inputs = optim_info.optim_inputs_func(device=device)
914
for optim_input in optim_inputs:
915
params = [torch.ones(2**32, device=device, dtype=dtype)]
916
params[0].grad = torch.zeros_like(params[0])
917
optimizer = optim_cls(params, foreach=True, **optim_input.kwargs)
922
[optim for optim in optim_db if "foreach" in optim.supported_impls],
923
dtypes=[torch.float32],
925
def test_peak_memory_foreach(self, device, dtype, optim_info):
927
optim_inputs = optim_info.optim_inputs_func(device=device)
928
optim_cls = optim_info.optim_cls
929
for optim_input in optim_inputs:
930
kwargs = deepcopy(optim_input.kwargs)
932
for flag_value in (False, True):
933
kwargs["foreach"] = flag_value
938
param = torch.rand(16, 8, device=device, dtype=dtype)
939
params = [torch.rand_like(param) for _ in range(nparams)]
941
optimizer = optim_cls(params, **kwargs)
944
p.grad = torch.rand_like(p)
950
torch.cuda.reset_peak_memory_stats()
953
max_mems.append(torch.cuda.max_memory_allocated())
955
st_max_mem, mt_max_mem = max_mems
956
intermediate_size = nparams * param.nelement() * param.element_size()
960
if optimizer.param_groups[0].get(
962
) or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
967
if optim_cls.__name__ == "NAdam":
970
if TEST_WITH_TORCHDYNAMO:
977
if optim_cls.__name__ == "RAdam":
980
if TEST_WITH_TORCHDYNAMO:
987
elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop", "Adafactor"]:
994
if optim_cls.__name__ == "Adafactor" and kwargs.get("maximize", False):
1000
if TEST_WITH_TORCHDYNAMO:
1003
expected_max_mem = st_max_mem + intermediate_size * nintermediates
1009
expected_max_mem *= 1.02
1011
self.assertLessEqual(mt_max_mem, expected_max_mem)
1014
[optim for optim in optim_db if "fused" in optim.supported_impls],
1015
dtypes=floating_types_and(
1020
def test_fused_matches_forloop(self, device, dtype, optim_info):
1021
if _get_device_type(device) not in optim_info.supports_fused_on:
1023
f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1025
if _get_device_type(device) == "mps" and dtype not in (
1029
self.skipTest("MPS supports only torch.float16 and torch.float32")
1030
self._test_derived_optimizers(device, dtype, optim_info, "fused")
1033
[optim for optim in optim_db if "fused" in optim.supported_impls],
1034
dtypes=(torch.float32,),
1036
def test_fused_error_on_params_on_meta(self, device, dtype, optim_info):
1037
if _get_device_type(device) not in optim_info.supports_fused_on:
1039
f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1042
with torch.device("meta"):
1043
model = torch.nn.Sequential(
1044
torch.nn.Linear(2, 3),
1046
torch.nn.Linear(3, 1),
1050
optimizer = optim_info.optim_cls(model.parameters(), fused=True)
1051
with torch.device("meta"):
1052
for p in model.parameters():
1053
p.grad = torch.rand_like(p)
1055
with self.assertRaisesRegex(
1057
"`fused=True` requires all the params to be floating point Tensors",
1061
optimizer.zero_grad(set_to_none=True)
1062
model.to_empty(device=device)
1063
for p in model.parameters():
1064
p.grad = torch.rand_like(p)
1067
@onlyNativeDeviceTypes
1068
@largeTensorTest("64GB")
1070
[optim for optim in optim_db if "fused" in optim.supported_impls],
1071
dtypes=[torch.float16],
1073
def test_fused_large_tensor(self, device, dtype, optim_info):
1074
if device not in optim_info.supports_fused_on:
1076
f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1078
optim_cls = optim_info.optim_cls
1079
optim_inputs = optim_info.optim_inputs_func(device=device)
1080
for optim_input in optim_inputs:
1081
params = [torch.ones(2**32, device=device, dtype=dtype)]
1082
params[0].grad = torch.zeros_like(params[0])
1083
optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
1088
[optim for optim in optim_db if "fused" in optim.supported_impls],
1089
dtypes=[torch.float32],
1091
def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info):
1092
if device not in optim_info.supports_fused_on:
1094
f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1096
optim_cls = optim_info.optim_cls
1097
optim_inputs = optim_info.optim_inputs_func(device=device)
1099
for optim_input in optim_inputs:
1100
for no_grad_scale in (False, True):
1102
torch.ones((1,), device=device, dtype=dtype)
1103
for _ in range(num_params)
1105
params_c = [param.clone().detach() for param in params]
1107
p.grad = torch.ones_like(p)
1108
optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
1109
optimizer.grad_scale = (
1112
else torch.ones((1,), dtype=dtype, device=device)
1114
optimizer.found_inf = torch.ones((), dtype=dtype, device=device)
1117
if "step" in optimizer.state[p]:
1119
torch.zeros((), dtype=dtype, device=device),
1120
optimizer.state[p]["step"],
1122
self.assertEqual(params, params_c)
1124
@parametrize("impl", ["fused", "capturable"])
1126
[optim for optim in optim_db if "fused" in optim.supported_impls],
1127
dtypes=[torch.float32],
1129
def test_cpu_load_state_dict(self, device, dtype, impl, optim_info):
1135
optim_cls = optim_info.optim_cls
1136
opt_name = optim_cls.__name__
1137
if opt_name in ("SGD", "Adagrad") and impl == "capturable":
1139
self.skipTest("SGD does not currently support capturable")
1140
if _get_device_type(device) == "cpu":
1141
self.skipTest("Test is only for non-cpu devices")
1144
and _get_device_type(device) not in optim_info.supports_fused_on
1146
self.skipTest(f"{device} is not supported for fused on {opt_name}")
1147
elif impl == "capturable" and _get_device_type(device) == "mps":
1148
self.skipTest("MPS does not support capturable")
1150
cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu")
1151
for optim_input in cpu_optim_inputs:
1152
param = torch.tensor([0.1, 0.2], dtype=dtype, device="cpu")
1153
optimizer = optim_cls([param], **optim_input.kwargs)
1154
param.grad = torch.rand_like(param)
1156
optim_state_dict_cpu = deepcopy(optimizer.state_dict())
1157
optim_state_dict_cpu["param_groups"][0][impl] = True
1160
optim_input.kwargs[impl] = True
1161
param_device = param.clone().detach().to(device=device)
1162
optimizer_device = optim_cls([param_device], **optim_input.kwargs)
1163
optimizer_device.load_state_dict(optim_state_dict_cpu)
1164
optimizer_device.zero_grad()
1165
param_device.grad = torch.rand_like(param_device)
1166
optimizer_device.step()
1168
@optims(optim_db, dtypes=[torch.float32])
1169
def test_param_groups_weight_decay(self, device, dtype, optim_info):
1170
optim_cls = optim_info.optim_cls
1172
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1173
device, dtype, optim_info, skip=("differentiable",)
1175
for optim_input in all_optim_inputs:
1176
weight_kwargs = optim_input.kwargs
1177
bias_kwargs = deepcopy(optim_input.kwargs)
1178
bias_kwargs["weight_decay"] = 0.0
1180
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
1181
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
1182
input = torch.randn(5, device=device, dtype=dtype)
1184
optimizer = optim_cls(
1186
dict(params=[weight], **weight_kwargs),
1187
dict(params=[bias], **bias_kwargs),
1191
loss = (weight.mv(input) + bias).pow(2).sum()
1192
initial_value = loss.item()
1194
optimizer.zero_grad()
1195
loss = (weight.mv(input) + bias).pow(2).sum()
1197
if optim_info.only_supports_sparse_grads:
1200
weight.grad = weight.grad.to_sparse()
1201
bias.grad = bias.grad.to_sparse()
1205
if optim_input.kwargs.get("maximize", False):
1206
self.assertGreater(loss.item(), initial_value)
1208
self.assertLess(loss.item(), initial_value)
1210
@optims(optim_db, dtypes=[torch.float32])
1211
def test_param_groups_lr(self, device, dtype, optim_info):
1212
optim_cls = optim_info.optim_cls
1214
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1215
device, dtype, optim_info, skip=("differentiable",)
1217
for optim_input in all_optim_inputs:
1219
if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0:
1220
optim_input.kwargs["lr"] = 1e-3
1221
outer_kwargs = {"lr": 1e-28}
1222
if optim_cls.__name__ == "Rprop":
1224
outer_kwargs["step_sizes"] = (0, 50)
1226
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
1227
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
1228
irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype))
1229
irrelevant_clone = irrelevant.clone()
1230
input = torch.randn(5, device=device, dtype=dtype)
1231
optimizer = optim_cls(
1233
dict(params=[weight, bias], **optim_input.kwargs),
1234
dict(params=[irrelevant]),
1239
loss = (weight.mv(input) + bias).pow(2).sum()
1240
initial_value = loss.item()
1242
optimizer.zero_grad()
1243
loss = (weight.mv(input) + bias).pow(2).sum()
1245
irrelevant.grad = torch.rand_like(irrelevant)
1246
if optim_info.only_supports_sparse_grads:
1249
weight.grad = weight.grad.to_sparse()
1250
bias.grad = bias.grad.to_sparse()
1251
irrelevant.grad = irrelevant.grad.to_sparse()
1255
if optim_input.kwargs.get("maximize", False):
1256
self.assertGreater(loss.item(), initial_value)
1258
self.assertLess(loss.item(), initial_value)
1261
self.assertEqual(irrelevant, irrelevant_clone)
1263
@optims(optim_db, dtypes=[torch.float32])
1264
def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info):
1265
optim_cls = optim_info.optim_cls
1266
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1267
device, dtype, optim_info
1270
torch.randn(2, 3, requires_grad=False, device=device, dtype=dtype)
1273
old_params = [p.clone().detach() for p in params]
1276
return torch.tensor([1], device=device, dtype=dtype)
1278
for optim_input in all_optim_inputs:
1279
optimizer = optim_cls(params, **optim_input.kwargs)
1280
optimizer.step(closure)
1282
@optims(optim_db, dtypes=[torch.float32])
1283
def test_step_is_noop_for_zero_grads(self, device, dtype, optim_info):
1284
optim_cls = optim_info.optim_cls
1285
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1286
device, dtype, optim_info
1288
param = torch.randn((5, 1), device=device, dtype=dtype, requires_grad=True)
1289
old_param = param.clone().detach()
1292
return torch.tensor([1], device=device, dtype=dtype)
1294
for optim_input in all_optim_inputs:
1295
kwargs = optim_input.kwargs
1299
if kwargs.get("weight_decay", 0) != 0:
1303
if optim_cls.__name__ == "AdamW":
1306
if isinstance(kwargs.get("lr", 1e-5), torch.Tensor)
1310
if kwargs.get("differentiable", False):
1311
params = [param.clone()]
1315
optimizer = optim_cls(params, **kwargs)
1316
if optim_info.only_supports_sparse_grads:
1320
i = torch.empty((1, 0), device=device, dtype=dtype)
1321
v = torch.empty((0, 1), device=device, dtype=dtype)
1322
params[0].grad = torch.sparse_coo_tensor(
1323
i, v, (5, 1), device=device, dtype=dtype
1326
params[0].grad = torch.zeros_like(params[0])
1327
optimizer.step(closure)
1328
self.assertEqual(old_param, params[0])
1330
@optims(optim_db, dtypes=[torch.float32])
def test_optimizer_can_be_printed(self, device, dtype, optim_info):
    """Smoke test: repr() must not raise for any optimizer configuration."""
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    params = [
        Parameter(torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype))
        for _ in range(2)
    ]
    for optim_input in all_optim_inputs:
        optimizer = optim_cls(params, **optim_input.kwargs)
        # Only exercising __repr__; any exception here fails the test.
        optimizer.__repr__()
1344
@optims(optim_db, dtypes=[torch.float32])
1345
def test_state_dict_deterministic(self, device, dtype, optim_info):
1346
optim_cls = optim_info.optim_cls
1349
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1350
device, dtype, optim_info, skip=("differentiable",)
1353
torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
1355
bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
1356
input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
1357
params = [weight, bias]
1359
def fwd_bwd(optim, w, b, i):
1361
loss = (w.mv(i) + b).pow(2).sum()
1363
if optim_info.only_supports_sparse_grads:
1364
if w.grad is not None:
1365
w.grad = w.grad.to_sparse()
1366
if b.grad is not None:
1367
b.grad = b.grad.to_sparse()
1370
for optim_input in all_optim_inputs:
1371
optimizer = optim_cls(params, **optim_input.kwargs)
1372
closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
1376
if optim_info.step_requires_closure:
1377
optimizer.step(closure)
1383
with torch.no_grad():
1384
weight_c = Parameter(weight.clone())
1385
bias_c = Parameter(bias.clone())
1387
optimizer_c = optim_cls([weight_c, bias_c], **optim_input.kwargs)
1388
closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)
1391
optimizer_c.load_state_dict(deepcopy(optimizer.state_dict()))
1395
if optim_info.step_requires_closure:
1396
optimizer.step(closure)
1397
optimizer_c.step(closure_c)
1404
self.assertEqual(weight, weight_c)
1405
self.assertEqual(bias, bias_c)
1408
self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
1411
optimizer_c.param_groups.extend(optimizer_c.param_groups)
1413
optimizer.state_dict()["param_groups"][-1],
1414
optimizer_c.state_dict()["param_groups"][-1],
1417
@optims(optim_db, dtypes=[torch.float32])
1418
def test_can_load_older_state_dict(self, device, dtype, optim_info):
1419
optim_cls = optim_info.optim_cls
1422
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1423
device, dtype, optim_info, skip=("differentiable",)
1425
for optim_input in all_optim_inputs:
1426
torch.manual_seed(1)
1427
model = torch.nn.Sequential(
1428
torch.nn.Conv2d(4, 2, 1, stride=2),
1429
torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
1431
model.to(dtype=dtype, device=device)
1432
input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
1433
optimizer = optim_cls(model.parameters(), **optim_input.kwargs)
1435
def fwd_bwd(optim, mod, i):
1442
if optim_info.step_requires_closure:
1443
optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
1445
fwd_bwd(optimizer, model, input)
1449
old_state_dict = deepcopy(optimizer.state_dict())
1450
old_state_dict_pg = old_state_dict["param_groups"]
1451
for group in old_state_dict_pg:
1452
for flag in optim_info.not_og_supported_flags:
1456
optimizer.load_state_dict(old_state_dict)
1459
if optim_info.step_requires_closure:
1460
optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
1462
fwd_bwd(optimizer, model, input)
1465
@optims(optim_db, dtypes=[torch.float32])
1466
def test_save_load_equality_with_weights_only(self, device, dtype, optim_info):
1467
optim_cls = optim_info.optim_cls
1470
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1471
device, dtype, optim_info, skip=("differentiable",)
1474
torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
1476
bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
1477
input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
1478
params = [weight, bias]
1480
def fwd_bwd(optim, w, b, i):
1482
loss = (w.mv(i) + b).pow(2).sum()
1484
if optim_info.only_supports_sparse_grads:
1485
weight.grad = weight.grad.to_sparse()
1486
bias.grad = bias.grad.to_sparse()
1489
for optim_input in all_optim_inputs:
1490
optimizer = optim_cls(params, **optim_input.kwargs)
1491
closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
1495
optimizer.step(closure)
1497
sd = optimizer.state_dict()
1500
with tempfile.TemporaryFile() as f:
1503
sd_copy = torch.load(f)
1504
self.assertEqual(sd_copy, sd)
1507
sd_copy_wo = torch.load(f, weights_only=True)
1508
self.assertEqual(sd_copy_wo, sd)
1510
@optims(optim_db, dtypes=[torch.float32])
1511
def test_load_nontensor_step(self, device, dtype, optim_info):
1512
optim_cls = optim_info.optim_cls
1515
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1516
device, dtype, optim_info, skip=("differentiable",)
1519
Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
1522
p.grad = torch.rand_like(p)
1523
if optim_info.only_supports_sparse_grads:
1526
p.grad = p.grad.to_sparse()
1529
closure_loss = torch.rand(1, device=device, dtype=dtype)
1532
return closure_loss if optim_info.step_requires_closure else None
1534
for optim_input in all_optim_inputs:
1535
kwargs = optim_input.kwargs
1536
optimizer = optim_cls(params, **optim_input.kwargs)
1538
optimizer.step(closure)
1539
state_dict = deepcopy(optimizer.state_dict())
1540
for p_state in state_dict["state"].values():
1541
if "step" in p_state and torch.is_tensor(p_state["step"]):
1542
p_state["step"] = p_state["step"].item()
1543
optimizer.load_state_dict(state_dict)
1544
optimizer.step(closure)
1547
@optims(optim_db, dtypes=[torch.float32])
1548
def test_state_dict_with_cuda_params(self, device, dtype, optim_info):
1549
optim_cls = optim_info.optim_cls
1553
cpu_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1554
"cpu", dtype, optim_info, skip=("differentiable",)
1558
closure_loss = torch.rand(1, device=device, dtype=dtype)
1561
return closure_loss if optim_info.step_requires_closure else None
1563
for optim_input in cpu_optim_inputs:
1565
"fused" in optim_input.kwargs
1566
and "cuda" not in optim_info.supports_fused_on
1569
f"cuda is not supported for fused on {optim_cls.__name__}"
1572
Parameter(torch.randn(2, 3, device="cpu", dtype=dtype))
1576
p.grad = torch.randn_like(p)
1577
if optim_info.only_supports_sparse_grads:
1580
p.grad = p.grad.to_sparse()
1582
optimizer = optim_cls(params, **optim_input.kwargs)
1585
optimizer.step(closure)
1587
with torch.no_grad():
1588
params_cuda = [p.to(device="cuda") for p in params]
1589
for i, p in enumerate(params_cuda):
1590
p.grad = params[i].grad.to(device="cuda")
1591
optimizer_cuda = optim_cls(params_cuda, **optim_input.kwargs)
1593
state_dict_cpu = deepcopy(optimizer.state_dict())
1594
state_dict_cuda = deepcopy(optimizer.state_dict())
1595
optimizer_cuda.load_state_dict(state_dict_cuda)
1598
self.assertEqual(state_dict_cpu, state_dict_cuda)
1601
capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
1602
fused = state_dict_cpu["param_groups"][0].get("fused", False)
1603
new_state_dict = optimizer_cuda.state_dict()
1604
for state_cpu, state_cuda in zip(
1605
state_dict_cpu["state"].values(), new_state_dict["state"].values()
1607
if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
1609
state_cuda["step"].device.type,
1610
"cuda" if capturable or fused else "cpu",
1614
optimizer.step(closure)
1615
optimizer_cuda.step(closure)
1616
self.assertEqual(params, params_cuda)
1617
self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())
1620
def _state_dict_pre_hook(optimizer: Optimizer) -> None:
1621
optimizer.state["test"] = 1
1624
def _state_dict_post_hook(
1625
optimizer: Optimizer, state_dict: Dict[str, Any]
1626
) -> Dict[str, Any]:
1627
if "test" in state_dict["state"]:
1628
state_dict["state"].pop("test")
1629
state_dict["ran_state_dict_pre_hook"] = True
1631
state_dict["ran_state_dict_pre_hook"] = False
1634
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_pre_hook(self, device, dtype, optim_info):
    """A registered state_dict pre-hook must run before state_dict() returns."""
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for optim_input in all_optim_inputs:
        param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        optim = optim_cls([param], **optim_input.kwargs)
        optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
        state_dict = optim.state_dict()
        # The pre-hook planted "test" into optim.state; it must appear in the dict.
        self.assertEqual(state_dict["state"]["test"], 1)
1647
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_post_hook(self, device, dtype, optim_info):
    """A registered state_dict post-hook must run after state_dict() is built."""
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for optim_input in all_optim_inputs:
        param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        optim = optim_cls([param], **optim_input.kwargs)
        optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
        state_dict = optim.state_dict()
        # No pre-hook was registered, so the post-hook records False.
        self.assertFalse(state_dict["ran_state_dict_pre_hook"])
1660
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
    """Pre- and post-hooks on state_dict() must compose: the post-hook sees
    (and removes) the marker the pre-hook planted."""
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for optim_input in all_optim_inputs:
        param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        optim = optim_cls([param], **optim_input.kwargs)
        optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
        optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
        state_dict = optim.state_dict()
        # The post-hook stripped the marker and recorded that the pre-hook ran.
        self.assertFalse("test" in state_dict["state"])
        self.assertTrue(state_dict["ran_state_dict_pre_hook"])
1676
def _load_state_dict_pre_hook1(
1677
optimizer: Optimizer, state_dict: Dict[str, Any]
1679
state_dict["param_groups"][0]["lr"] = 0.002
1682
def _load_state_dict_pre_hook2(
1683
optimizer: Optimizer, state_dict: Dict[str, Any]
1684
) -> Dict[str, Any]:
1687
my_state_dict = deepcopy(state_dict)
1688
my_state_dict["param_groups"][0]["lr"] = 0.003
1689
return my_state_dict
1692
def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
1693
optimizer.state["ran_load_state_dict_pre_hook2"] = (
1694
optimizer.param_groups[0]["lr"] == 0.003
1696
optimizer.state["ran_load_state_dict_post_hook"] = True
1698
@optims(optim_db, dtypes=[torch.float32])
def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
    """Pre-hooks run in registration order unless prepend=True; a later
    in-place hook overrides an earlier copy-returning one."""
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for optim_input in all_optim_inputs:
        param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        optim = optim_cls([param], **optim_input.kwargs)
        state_dict = optim.state_dict()

        # Hook 1 mutates the dict in place, forcing lr to 0.002.
        optim.register_load_state_dict_pre_hook(
            self.__class__._load_state_dict_pre_hook1
        )
        optim.load_state_dict(state_dict)
        self.assertEqual(optim.param_groups[0]["lr"], 0.002)

        # Hook 2 (lr -> 0.003) is prepended, so hook 1 still runs last and
        # its lr of 0.002 wins.
        optim.register_load_state_dict_pre_hook(
            self.__class__._load_state_dict_pre_hook2, prepend=True
        )
        optim.load_state_dict(state_dict)
        self.assertEqual(optim.param_groups[0]["lr"], 0.002)
1723
@optims(optim_db, dtypes=[torch.float32])
def test_load_state_dict_post_hook(self, device, dtype, optim_info):
    """A load_state_dict post-hook must run after load_state_dict()."""
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for optim_input in all_optim_inputs:
        param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        optim = optim_cls([param], **optim_input.kwargs)

        optim.register_load_state_dict_post_hook(
            self.__class__._load_state_dict_post_hook
        )
        optim.load_state_dict(optim.state_dict())
        # No pre-hook registered, so lr was never rewritten to 0.003.
        self.assertFalse(optim.state["ran_load_state_dict_pre_hook2"])
        self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
1740
@optims(optim_db, dtypes=[torch.float32])
def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
    """Pre- and post-hooks on load_state_dict() must compose: the post-hook
    observes the lr rewritten by the copy-returning pre-hook."""
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for optim_input in all_optim_inputs:
        param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        optim = optim_cls([param], **optim_input.kwargs)

        optim.register_load_state_dict_pre_hook(
            self.__class__._load_state_dict_pre_hook2
        )
        optim.register_load_state_dict_post_hook(
            self.__class__._load_state_dict_post_hook
        )
        optim.load_state_dict(optim.state_dict())
        # The pre-hook's returned dict (lr=0.003) was loaded, and the
        # post-hook detected it.
        self.assertTrue(optim.state["ran_load_state_dict_pre_hook2"])
        self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
1760
@optims(optim_db, dtypes=[torch.float32])
1761
def test_step_post_hook(self, device, dtype, optim_info):
1762
def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1766
params = [torch.tensor([1, 1], device=device, dtype=dtype)]
1768
def dummy_closure():
1771
closure = dummy_closure if optim_info.step_requires_closure else None
1773
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1774
device, dtype, optim_info
1776
for optim_input in all_optim_inputs:
1777
optim = optim_info.optim_cls(params, **optim_input.kwargs)
1779
hook_handle = optim.register_step_post_hook(post_hook)
1784
self.assertEqual(data, 6)
1787
hook_handle.remove()
1790
self.assertEqual(data, 6)
1792
@optims(optim_db, dtypes=[torch.float32])
1793
def test_step_pre_hook(self, device, dtype, optim_info):
1794
def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1798
params = [torch.tensor([1, 1], device=device, dtype=dtype)]
1800
def dummy_closure():
1803
closure = dummy_closure if optim_info.step_requires_closure else None
1805
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1806
device, dtype, optim_info
1808
for optim_input in all_optim_inputs:
1809
optim = optim_info.optim_cls(params, **optim_input.kwargs)
1811
hook_handle = optim.register_step_pre_hook(pre_hook)
1816
self.assertEqual(data, 9)
1819
hook_handle.remove()
1822
self.assertEqual(data, 9)
1824
@optims(optim_db, dtypes=[torch.float32])
1825
def test_step_all_hooks(self, device, dtype, optim_info):
1826
def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1830
def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1834
def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1838
def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1842
params = [torch.tensor([1, 1], device=device, dtype=dtype)]
1844
def dummy_closure():
1847
closure = dummy_closure if optim_info.step_requires_closure else None
1849
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1850
device, dtype, optim_info
1852
for optim_input in all_optim_inputs:
1853
optim = optim_info.optim_cls(params, **optim_input.kwargs)
1854
optim2 = SGD(params)
1858
global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook)
1859
global_post_handle = register_optimizer_step_post_hook(global_post_hook)
1862
first_pre_handle = optim.register_step_pre_hook(local_pre_hook)
1863
first_post_handle = optim.register_step_post_hook(local_post_hook)
1864
second_pre_handle = optim2.register_step_pre_hook(local_pre_hook)
1865
second_post_handle = optim2.register_step_post_hook(local_post_hook)
1868
self.assertListEqual(data, [0, 1, 2, 5])
1869
optim2.step(closure)
1870
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
1872
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
1875
global_pre_handle.remove()
1876
global_post_handle.remove()
1877
first_pre_handle.remove()
1878
first_post_handle.remove()
1879
second_pre_handle.remove()
1880
second_post_handle.remove()
1883
optim2.step(closure)
1884
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
1886
@optims(optim_db, dtypes=[torch.float32])
1887
def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
1888
optim_cls = optim_info.optim_cls
1891
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1892
device, dtype, optim_info, skip=("differentiable",)
1896
Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
1899
p.grad = torch.rand_like(p)
1900
if optim_info.only_supports_sparse_grads:
1903
p.grad = p.grad.to_sparse()
1907
return 1 if optim_info.step_requires_closure else None
1909
def getPublicAttrs(obj):
1910
return {k for k in obj.__dict__ if not k.startswith("_")}
1912
for optim_input in all_optim_inputs:
1913
optimizer = optim_cls(params, **optim_input.kwargs)
1917
if optim_info.step_requires_closure:
1918
optimizer.step(closure)
1924
getPublicAttrs(optimizer), getPublicAttrs(deepcopy(optimizer))
1928
[optim for optim in optim_db if optim.step_requires_closure],
1929
dtypes=[torch.float32],
1931
def test_second_order_optims_return_consistent_types(
1932
self, device, dtype, optim_info
1935
optim_cls = optim_info.optim_cls
1937
torch.randn(10, 5, device=device, dtype=dtype),
1938
torch.randn(10, device=device, dtype=dtype),
1942
return torch.tensor([10], device=device, dtype=dtype)
1944
for optim_input in optim_info.optim_inputs_func(device=device):
1947
kwargs = optim_input.kwargs
1948
kwargs["tolerance_grad"] = math.inf
1949
optim_inf = optim_cls(params, **kwargs)
1950
kwargs["tolerance_grad"] = -math.inf
1951
optim_neg_inf = optim_cls(params, **kwargs)
1953
res1 = optim_inf.step(closure)
1954
res2 = optim_neg_inf.step(closure)
1955
self.assertEqual(type(res1), type(res2))
1961
for optim in optim_db
1962
if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on
1964
dtypes=floating_types_and(
1969
def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
1970
optim_cls = optim_info.optim_cls
1971
optim_inputs = optim_info.optim_inputs_func(device="cpu")
1972
for optim_input in optim_inputs:
1973
inpts, models, optimizers = [], [], []
1974
for dev in ("cpu", "cuda"):
1975
kwargs = optim_input.kwargs
1976
kwargs["fused"] = True
1977
inpt = torch.tensor(
1978
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
1981
torch.manual_seed(1)
1982
model = torch.nn.Sequential(
1983
torch.nn.Linear(2, 3),
1985
torch.nn.Linear(3, 1),
1988
model.to(dtype=dtype, device=dev)
1993
empty_param = torch.empty(
1994
(), device=dev, dtype=dtype, requires_grad=True
1996
empty_param.grad = torch.rand_like(empty_param)
1997
params = list(model.parameters()) + [empty_param]
1999
optimizer = optim_cls(params, **kwargs)
2001
models.append(model)
2002
optimizers.append(optimizer)
2003
self._compare_between(inpts, models, optimizers)
2010
if ("foreach" in o.supported_impls and o.optim_cls.__name__ != "Adafactor")
2012
dtypes=[torch.float32],
2014
def test_defaults_changed_to_foreach(self, device, dtype, optim_info):
2017
optim_cls = optim_info.optim_cls
2018
model = torch.nn.Linear(5, 5)
2019
model.to(dtype=dtype, device=device)
2020
inpt = torch.rand(2, 5, dtype=dtype, device=device)
2024
module = inspect.getmodule(optim_cls)
2026
for optim_input in optim_info.optim_inputs_func(device=device):
2027
optim = optim_cls(model.parameters(), **optim_input.kwargs)
2029
output = model(inpt)
2033
module, f"_multi_tensor_{optim_cls.__name__.lower()}"
2034
) as mocked_foreach_impl:
2036
self.assertTrue(mocked_foreach_impl.called)
2038
@optims(optim_db, dtypes=[torch.float32])
2039
def test_non_empty_state(self, device, dtype, optim_info):
2041
optim_cls = optim_info.optim_cls
2042
model = torch.nn.Linear(5, 5)
2043
model.to(dtype=dtype, device=device)
2044
inpt = torch.rand(2, 5, dtype=dtype, device=device)
2046
for optim_input in optim_info.optim_inputs_func(device=device):
2047
optim = optim_cls(model.parameters(), **optim_input.kwargs)
2049
output = model(inpt)
2053
if optim_info.only_supports_sparse_grads:
2054
for param in model.parameters():
2055
if param.grad is not None:
2056
param.grad = param.grad.to_sparse()
2058
if optim_info.step_requires_closure:
2059
optim.step(lambda: 1.0)
2063
for state in optim.state.values():
2064
self.assertGreater(len(state), 0)
2067
instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)
2070
if __name__ == "__main__":