import torch
from torch.optim import (
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    NAdam,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
)
from torch.testing._internal.common_utils import (
    gradcheck,
    load_tests,
    skipIfTorchDynamo,
    TestCase,
)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests


def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    # `ignored` receives the values of `opt_differentiable_state` so that
    # `gradcheck` tracks the state tensors as function inputs; gradcheck
    # cannot unpack the values inside the `opt_differentiable_state` dict
    # on its own.
    p = p.clone()
    p.grad = grad
    opt_differentiable_state = {
        k: v.clone() if isinstance(v, torch.Tensor) else v
        for k, v in opt_differentiable_state.items()
    }
    opt = opt_class([p], **kwargs)
    opt.state[p].update(opt_differentiable_state)
    opt.step()
    return (p,) + tuple(
        v
        for v in opt.state[p].values()
        if isinstance(v, torch.Tensor) and v.requires_grad
    )
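

# Each test below drives `gradcheck` through a single optimizer step via
# `_diff_fn`. Inputs and state tensors are float64 because gradcheck compares
# analytical gradients against finite differences, which needs double
# precision to stay within tolerance; `differentiable=True` keeps the update
# on the autograd graph instead of running it under `torch.no_grad()`.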
@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):
    def test_sgd(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state = {"momentum_buffer": mbuff}
        gradcheck(
            _diff_fn,
            (p, grad, state, SGD, {"lr": 0.9, "differentiable": True}, *state.values()),
        )

    def test_adam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adam,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_rmsprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step"] = torch.zeros((), dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["momentum_buffer"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # Keep `grad_avg` small: the sqrt ops in the centered update can
        # produce NaNs for large values.
        state["grad_avg"] = 1e-2 * torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RMSprop,
                {
                    "lr": 0.9,
                    "maximize": True,
                    "momentum": 0.9,
                    "differentiable": True,
                    "centered": True,
                    "weight_decay": 0.1,
                },
                *state.values(),
            ),
        )

    def test_adadelta(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adadelta,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adagrad(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adagrad,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamax(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adamax,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    @skipIfTorchDynamo(
        "The inplace mu update fails with dynamo, "
        "since this is only happening when differentiable is enabled, skipping for now"
    )
    def test_asgd(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step`, `eta`, and `mu` are not continuous variables (even though we
        # define them as floats), and so they shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64)
        state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64)
        state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (p, grad, state, ASGD, {"lr": 0.9, "differentiable": True}, *state.values()),
        )

    def test_rprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (p, grad, state, Rprop, {"lr": 0.9, "differentiable": True}, *state.values()),
        )

    def test_adamw(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                AdamW,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_nadam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (p, grad, state, NAdam, {"lr": 0.9, "differentiable": True}, *state.values()),
        )
        # Repeat with decoupled (AdamW-style) weight decay.
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                NAdam,
                {"lr": 0.9, "decoupled_weight_decay": True, "differentiable": True},
                *state.values(),
            ),
        )

    def test_radam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (p, grad, state, RAdam, {"lr": 0.9, "differentiable": True}, *state.values()),
        )
        # Repeat with weight decay in its decoupled (AdamW-style) form.
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RAdam,
                {
                    "lr": 0.9,
                    "weight_decay": 0.1,
                    "decoupled_weight_decay": True,
                    "differentiable": True,
                },
                *state.values(),
            ),
        )


if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")