#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/optim_baseline.h>
#include <test/cpp/api/support.h>

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <utility>
#include <vector>
using namespace torch::nn;
using namespace torch::optim;
template <typename OptimizerClass, typename Options>
bool test_optimizer_xor(Options options) {
  torch::manual_seed(0);

  Sequential model(
      Linear(2, 8),
      Functional(torch::sigmoid),
      Linear(8, 1),
      Functional(torch::sigmoid));

  const int64_t kBatchSize = 200;
  const int64_t kMaximumNumberOfEpochs = 3000;

  OptimizerClass optimizer(model->parameters(), options);

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto inputs = torch::empty({kBatchSize, 2});
    auto labels = torch::empty({kBatchSize});
    for (const auto i : c10::irange(kBatchSize)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }

    inputs.set_requires_grad(true);

    auto step = [&](OptimizerClass& optimizer,
                    Sequential model,
                    torch::Tensor inputs,
                    torch::Tensor labels) {
      auto closure = [&]() {
        optimizer.zero_grad();
        auto x = model->forward(inputs);
        auto loss = torch::binary_cross_entropy(x, labels);
        loss.backward();
        return loss;
      };
      return optimizer.step(closure);
    };

    torch::Tensor loss = step(optimizer, model, inputs, labels);

    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
    if (epoch > kMaximumNumberOfEpochs) {
      std::cout << "Loss is too high after epoch " << epoch << ": "
                << running_loss << std::endl;
      return false;
    }
    epoch++;
  }
  return true;
}
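// Overwrites a named parameter in place with fixed values; gradient tracking
// is temporarily disabled so the copy itself is not recorded by autograd.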
template <typename Parameters>
void assign_parameter(
    const Parameters& parameters,
    const char* name,
    torch::Tensor new_tensor) {
  auto parameter = parameters[name];
  parameter.set_requires_grad(false);
  parameter.flatten().copy_(new_tensor);
  parameter.set_requires_grad(true);
}
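// Runs a small fixed-seed model for kIterations optimizer steps and, every
// kSampleEvery iterations, compares the parameters against the reference
// trajectories in test/cpp/api/optim_baseline.h (the values that the
// ProducesPyTorchValues_* tests pass in as expected_parameters).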
template <typename OptimizerClass, typename Options>
void check_exact_values(
    Options options,
    std::vector<std::vector<torch::Tensor>> expected_parameters) {
  const size_t kIterations = 1001;
  const size_t kSampleEvery = 100;

  torch::manual_seed(0);

  Sequential model(
      Linear(2, 3),
      Functional(torch::sigmoid),
      Linear(3, 1),
      Functional(torch::sigmoid));

  model->to(torch::kFloat64);

  // Use exact input values because matching random values is hard.
  auto parameters = model->named_parameters();
  assign_parameter(
      parameters,
      "0.weight",
      torch::tensor(
          {-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976},
          torch::kFloat64));
  assign_parameter(
      parameters,
      "0.bias",
      torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
  assign_parameter(
      parameters,
      "2.weight",
      torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
  assign_parameter(
      parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));

  auto optimizer = OptimizerClass(parameters.values(), options);
  torch::Tensor input =
      torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64)
          .reshape({3, 2});

  for (const auto i : c10::irange(kIterations)) {
    optimizer.zero_grad();
    auto output = model->forward(input);
    auto loss = output.sum();
    loss.backward();

    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);

    if (i % kSampleEvery == 0) {
      ASSERT_TRUE(
          expected_parameters.at(i / kSampleEvery).size() == parameters.size());
      for (const auto p : c10::irange(parameters.size())) {
        ASSERT_TRUE(parameters[p]->defined());
        // Always compare using double dtype, regardless of the original dtype
        // of the tensors, to avoid precision-related false negatives.
        auto computed = parameters[p]->flatten().to(torch::kFloat64);
        auto expected =
            expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
        if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
          std::cout << "Iteration " << i << ": " << computed
                    << " != " << expected << " (parameter " << p << ")"
                    << std::endl;
          ASSERT_TRUE(false);
        }
      }
    }
  }
}
TEST(OptimTest, OptimizerAccessors) {
  auto options = AdagradOptions(1.0);
  std::vector<torch::Tensor> params;
  for (const auto i : c10::irange(3)) {
    (void)i; // Suppress unused variable warning
    params.push_back(torch::randn(10));
  }
  auto optimizer = Adagrad(params, options);
  // test for defaults() method with non-const reference
  auto& options_ = static_cast<AdagradOptions&>(optimizer.defaults());
  ASSERT_TRUE(options == options_);
  // test for param_groups() with non-const reference return
  auto& params_groups = optimizer.param_groups();
  // NOLINTNEXTLINE(modernize-use-emplace)
  params_groups.push_back(OptimizerParamGroup(params));
  auto& params_1 = params_groups[1].params();
  for (const auto i : c10::irange(params_1.size())) {
    ASSERT_TRUE(torch::equal(params[i], params_1[i]));
  }

  // test for add_param_group() when one or more params existing in another
  // param_group are passed in the new param group to be added
  ASSERT_THROWS_WITH(
      optimizer.add_param_group(OptimizerParamGroup(params)),
      "some parameters appear in more than one parameter group");

  // test for state() with non-const reference return
  auto& state_ = static_cast<AdagradParamState&>(
      *(optimizer.state()[params_1[0].unsafeGetTensorImpl()]));
  state_.step(state_.step() + 1);

  // test for defaults() method with const reference return
  const auto& optimizer_ = Adagrad(params, options);
  optimizer_.defaults();
  // test for param_groups() with const reference return
  (void)optimizer_.param_groups();
  // test for state() with const reference return
  optimizer_.state();
}
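// Helpers for the deprecated-interface tests below: the macro runs `func`
// while capturing warnings and asserts that exactly one "will be removed"
// deprecation warning was emitted.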
#define OLD_INTERFACE_WARNING_CHECK(func)       \
  {                                             \
    torch::test::WarningCapture warnings;       \
    func;                                       \
    ASSERT_EQ(                                  \
        torch::test::count_substr_occurrences(  \
            warnings.str(), "will be removed"), \
        1);                                     \
  }

struct MyOptimizerOptions
    : public OptimizerCloneableOptions<MyOptimizerOptions> {
  MyOptimizerOptions(double lr = 1.0) : lr_(lr){};
  TORCH_ARG(double, lr) = 1.0;
};
TEST(OptimTest, OldInterface) {
  struct MyOptimizer : Optimizer {
    using Optimizer::Optimizer;
    torch::Tensor step(LossClosure closure = nullptr) override {
      return {};
    }
    explicit MyOptimizer(
        std::vector<at::Tensor> params,
        MyOptimizerOptions defaults = {})
        : // NOLINTNEXTLINE(performance-move-const-arg)
          Optimizer(
              {std::move(OptimizerParamGroup(params))},
              std::make_unique<MyOptimizerOptions>(defaults)) {}
  };
  std::vector<torch::Tensor> parameters = {
      torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
  {
    MyOptimizer optimizer(parameters);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());
  }
  {
    std::vector<at::Tensor> params;
    MyOptimizer optimizer(params);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, 0);

    OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());

    std::vector<torch::Tensor> params_;
    OLD_INTERFACE_WARNING_CHECK(params_ = optimizer.parameters());
    for (const auto p : c10::irange(size)) {
      ASSERT_TRUE(params_[p].allclose(parameters[p]));
    }
  }
  {
    Linear linear(3, 4);
    MyOptimizer optimizer(linear->parameters());

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, linear->parameters().size());
  }
}
TEST(OptimTest, XORConvergence_SGD) {
  ASSERT_TRUE(test_optimizer_xor<SGD>(
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_LBFGS) {
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0)));
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe")));
}

TEST(OptimTest, XORConvergence_Adagrad) {
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
}

TEST(OptimTest, XORConvergence_RMSprop) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
}

TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(
      RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_Adam) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(
      AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
}
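// The ProducesPyTorchValues_* tests below drive check_exact_values() with the
// reference parameter values from test/cpp/api/optim_baseline.h.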
TEST(OptimTest, ProducesPyTorchValues_Adam) {
  check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adam_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
      expected_parameters::Adam_with_weight_decay_and_amsgrad());
}

TEST(OptimTest, XORConvergence_AdamW) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1)));
}

TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1).amsgrad(true)));
}

TEST(OptimTest, ProducesPyTorchValues_AdamW) {
  check_exact_values<AdamW>(AdamWOptions(1.0), expected_parameters::AdamW());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithoutWeightDecay) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).weight_decay(0),
      expected_parameters::AdamW_without_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithAMSGrad) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).amsgrad(true),
      expected_parameters::AdamW_with_amsgrad());
}

TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0), expected_parameters::Adagrad());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adagrad_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
      expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
}
TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1), expected_parameters::RMSprop());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-2),
      expected_parameters::RMSprop_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
      expected_parameters::RMSprop_with_weight_decay_and_centered());
}

TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
      expected_parameters::
          RMSprop_with_weight_decay_and_centered_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGD) {
  check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2),
      expected_parameters::SGD_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
      expected_parameters::SGD_with_weight_decay_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
      expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS) {
  check_exact_values<LBFGS>(LBFGSOptions(1.0), expected_parameters::LBFGS());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS_with_line_search) {
  check_exact_values<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe"),
      expected_parameters::LBFGS_with_line_search());
}
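// The remaining tests exercise behavior that does not depend on reference
// values: clearing gradients, optimizing externally owned tensors, adding
// parameters to LBFGS, and learning-rate scheduling.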
TEST(OptimTest, ZeroGrad) {
  torch::manual_seed(0);

  Linear model(2, 8);
  SGD optimizer(model->parameters(), 0.1);

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }

  auto output = model->forward(torch::ones({5, 2}));
  auto loss = output.sum();
  loss.backward();

  for (const auto& parameter : model->parameters()) {
    ASSERT_TRUE(parameter.grad().defined());
    ASSERT_GT(parameter.grad().sum().item<float>(), 0);
  }

  optimizer.zero_grad();

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }
}

TEST(OptimTest, ExternalVectorOfParameters) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};
  std::vector<torch::Tensor> original_parameters = {
      parameters[0].clone(), parameters[1].clone(), parameters[2].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  SGD optimizer(parameters, 1.0);

  optimizer.step();

  // With lr = 1.0 and unit gradients, one SGD step subtracts exactly 1.0.
  ASSERT_TRUE(parameters[0].allclose(original_parameters[0] - 1.0));
  ASSERT_TRUE(parameters[1].allclose(original_parameters[1] - 1.0));
  ASSERT_TRUE(parameters[2].allclose(original_parameters[2] - 1.0));
}

TEST(OptimTest, AddParameter_LBFGS) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
  std::vector<torch::Tensor> original_parameters = {parameters[0].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
  OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

  optimizer.step([]() { return torch::tensor(1); });

  // REQUIRE this doesn't throw
}
// Check whether the learning rate of the optimizer's parameter groups matches
// the expected learning rates given in the epoch:learning-rate map
void check_lr_change(
    Optimizer& optimizer,
    LRScheduler& lr_scheduler,
    std::map<unsigned, double> expected_epoch_lrs) {
  // Find the last epoch in the map (the entry with the smallest learning rate,
  // since the schedules under test only ever decay the learning rate)
  unsigned kIterations = std::max_element(
                             expected_epoch_lrs.begin(),
                             expected_epoch_lrs.end(),
                             [](const std::pair<unsigned, double>& a,
                                const std::pair<unsigned, double>& b) -> bool {
                               return a.second > b.second;
                             })
                             ->first;

  for (unsigned i = 0; i <= kIterations; i++) {
    const auto epoch_iter = expected_epoch_lrs.find(i);
    if (epoch_iter != expected_epoch_lrs.end()) {
      // Compare the two floating point learning rates for (near-)equality
      ASSERT_TRUE(
          fabs(
              epoch_iter->second -
              optimizer.param_groups()[0].options().get_lr()) <
          std::numeric_limits<double>::epsilon());
    }
    optimizer.step();
    lr_scheduler.step();
  }
}
// Very similar to check_lr_change, but for ReduceLROnPlateauScheduler, which
// does not inherit from LRScheduler and requires a metrics argument to step()
void check_lr_change_for_reduce_on_plateau(
    Optimizer& optimizer,
    ReduceLROnPlateauScheduler& lr_scheduler,
    std::map<unsigned, double> expected_epoch_lrs) {
  // Find the last epoch in the map (the entry with the smallest learning rate,
  // since the schedules under test only ever decay the learning rate)
  unsigned kIterations = std::max_element(
                             expected_epoch_lrs.begin(),
                             expected_epoch_lrs.end(),
                             [](const std::pair<unsigned, double>& a,
                                const std::pair<unsigned, double>& b) -> bool {
                               return a.second > b.second;
                             })
                             ->first;

  for (unsigned i = 0; i <= kIterations; i++) {
    const auto epoch_iter = expected_epoch_lrs.find(i);
    if (epoch_iter != expected_epoch_lrs.end()) {
      // Compare the two floating point learning rates for (near-)equality
      ASSERT_TRUE(
          fabs(
              epoch_iter->second -
              optimizer.param_groups()[0].options().get_lr()) <
          std::numeric_limits<double>::epsilon());
    }
    optimizer.step();
    lr_scheduler.step(5.0);
  }
}
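// StepLR multiplies the learning rate by gamma every step_size epochs, so with
// lr = 1e-3, step_size = 20, and gamma = 0.5 the expected learning rate at
// epoch 25 is 5e-4. ReduceLROnPlateau applies the same factor once the metric
// has stopped improving for `patience` epochs.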
TEST(OptimTest, CheckLRChange_StepLR_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));

  const unsigned step_size = 20;
  const double gamma = 0.5;
  StepLR step_lr_scheduler(optimizer, step_size, gamma);

  // The learning rate should have halved at epoch 20
  const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};

  check_lr_change(optimizer, step_lr_scheduler, expected_epoch_lrs);
}

TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));
  const float factor = 0.5;
  const int patience = 20;
  ReduceLROnPlateauScheduler reduce_lr_on_plateau_scheduler(
      optimizer,
      ReduceLROnPlateauScheduler::SchedulerMode::min,
      factor,
      patience);

  // The learning rate should have halved at epoch 20
  const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};

  check_lr_change_for_reduce_on_plateau(
      optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
}