pytorch / optim.cpp
#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/optim_baseline.h>
#include <test/cpp/api/support.h>

#include <cmath>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

using namespace torch::nn;
using namespace torch::optim;

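// Convergence harness shared by the XORConvergence_* tests below: trains a
// small 2-8-1 sigmoid MLP on randomly generated XOR batches and returns true
// once the exponentially smoothed loss drops below 0.1, or false if that has
// not happened within kMaximumNumberOfEpochs steps. Typical use:
//   ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1)));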
template <typename OptimizerClass, typename Options>
bool test_optimizer_xor(Options options) {
  torch::manual_seed(0);

  Sequential model(
      Linear(2, 8),
      Functional(torch::sigmoid),
      Linear(8, 1),
      Functional(torch::sigmoid));

  const int64_t kBatchSize = 200;
  const int64_t kMaximumNumberOfEpochs = 3000;

  OptimizerClass optimizer(model->parameters(), options);

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto inputs = torch::empty({kBatchSize, 2});
    auto labels = torch::empty({kBatchSize});
    for (const auto i : c10::irange(kBatchSize)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }

    inputs.set_requires_grad(true);

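    // Wrapping the forward/backward pass in a closure matches the
    // Optimizer::step(LossClosure) interface. LBFGS actually uses the closure
    // (it may re-evaluate the loss during its search), while the first-order
    // optimizers simply invoke it once and use the returned loss.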
    auto step = [&](OptimizerClass& optimizer,
                    Sequential model,
                    torch::Tensor inputs,
                    torch::Tensor labels) {
      auto closure = [&]() {
        optimizer.zero_grad();
        auto x = model->forward(inputs);
        auto loss = torch::binary_cross_entropy(x, labels);
        loss.backward();
        return loss;
      };
      return optimizer.step(closure);
    };

    torch::Tensor loss = step(optimizer, model, inputs, labels);

    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
    if (epoch > kMaximumNumberOfEpochs) {
      std::cout << "Loss is too high after epoch " << epoch << ": "
                << running_loss << std::endl;
      return false;
    }
    epoch++;
  }
  return true;
}

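// Copies hand-picked values into a named parameter. requires_grad is switched
// off around the copy so that the assignment itself is not recorded by
// autograd, then switched back on for the optimization steps that follow.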
template <typename Parameters>
void assign_parameter(
    const Parameters& parameters,
    const char* name,
    torch::Tensor new_tensor) {
  auto parameter = parameters[name];
  parameter.set_requires_grad(false);
  parameter.flatten().copy_(new_tensor);
  parameter.set_requires_grad(true);
}

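// Runs a fixed 2-3-1 double-precision model for kIterations optimizer steps
// and, every kSampleEvery iterations, compares the flattened parameters
// against the reference values from optim_baseline.h (the
// expected_parameters:: tables, presumably recorded from the corresponding
// Python optimizers).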
template <typename OptimizerClass, typename Options>
void check_exact_values(
    Options options,
    std::vector<std::vector<torch::Tensor>> expected_parameters) {
  const size_t kIterations = 1001;
  const size_t kSampleEvery = 100;

  torch::manual_seed(0);

  Sequential model(
      Linear(2, 3),
      Functional(torch::sigmoid),
      Linear(3, 1),
      Functional(torch::sigmoid));

  model->to(torch::kFloat64);

  // Use exact input values because matching random values is hard.
  auto parameters = model->named_parameters();
  assign_parameter(
      parameters,
      "0.weight",
      torch::tensor(
          {-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976},
          torch::kFloat64));
  assign_parameter(
      parameters,
      "0.bias",
      torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
  assign_parameter(
      parameters,
      "2.weight",
      torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
  assign_parameter(
      parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));

  auto optimizer = OptimizerClass(parameters.values(), options);
  torch::Tensor input =
      torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64)
          .reshape({3, 2});

  for (const auto i : c10::irange(kIterations)) {
    optimizer.zero_grad();
    auto output = model->forward(input);
    auto loss = output.sum();
    loss.backward();

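    // The gradients have already been computed by the explicit backward()
    // above; the constant closure below is passed only so that every
    // optimizer under test (notably LBFGS, which expects a closure) can be
    // stepped through the same step(closure) call.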
    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);

    if (i % kSampleEvery == 0) {
      ASSERT_TRUE(
          expected_parameters.at(i / kSampleEvery).size() == parameters.size());
      for (const auto p : c10::irange(parameters.size())) {
        ASSERT_TRUE(parameters[p]->defined());
        // Always compare using double dtype, regardless of the original dtype
        // of the tensors
        auto computed = parameters[p]->flatten().to(torch::kFloat64);
        auto expected =
            expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
        if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
          std::cout << "Iteration " << i << ": " << computed
                    << " != " << expected << " (parameter " << p << ")"
                    << std::endl;
          ASSERT_TRUE(false);
        }
      }
    }
  }
}

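// Exercises the accessor surface of Optimizer: defaults(), param_groups(),
// add_param_group() (which must reject parameters that already belong to
// another group), and state(), in both non-const and const forms.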
TEST(OptimTest, OptimizerAccessors) {
  auto options = AdagradOptions(1.0);
  std::vector<torch::Tensor> params;
  for (const auto i : c10::irange(3)) {
    (void)i; // Suppress unused variable warning
    params.push_back(torch::randn(10));
  }
  auto optimizer = Adagrad(params, options);
  // test for defaults() method with non-const reference
  auto& options_ = static_cast<AdagradOptions&>(optimizer.defaults());
  ASSERT_TRUE(options == options_);
  // test for param_groups() with non-const reference return
  auto& params_groups = optimizer.param_groups();
  // NOLINTNEXTLINE(modernize-use-emplace)
  params_groups.push_back(OptimizerParamGroup(params));
  auto& params_1 = params_groups[1].params();
  for (const auto i : c10::irange(params_1.size())) {
    ASSERT_TRUE(torch::equal(params[i], params_1[i]));
  }

  // test for add_param_group() when one or more params existing in another
  // param_group are passed in the new param group to be added
  ASSERT_THROWS_WITH(
      optimizer.add_param_group(OptimizerParamGroup(params)),
      "some parameters appear in more than one parameter group");

  // test for state() with non-const reference return
  auto& state_ = static_cast<AdagradParamState&>(
      *(optimizer.state()[params_1[0].unsafeGetTensorImpl()]));
  state_.step(state_.step() + 1);

  const auto& optimizer_ = Adagrad(params, options);
  optimizer_.defaults();
  // test for param_groups() with const reference return
  (void)optimizer_.param_groups();
  // test for state() with const reference return
  optimizer_.state();
}

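// Runs `func` while capturing warnings and asserts that exactly one
// deprecation warning (containing "will be removed") was emitted.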
#define OLD_INTERFACE_WARNING_CHECK(func)       \
  {                                             \
    torch::test::WarningCapture warnings;       \
    func;                                       \
    ASSERT_EQ(                                  \
        torch::test::count_substr_occurrences(  \
            warnings.str(), "will be removed"), \
        1);                                     \
  }

struct MyOptimizerOptions
    : public OptimizerCloneableOptions<MyOptimizerOptions> {
  MyOptimizerOptions(double lr = 1.0) : lr_(lr){};
  TORCH_ARG(double, lr) = 1.0;
};

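// Defines a minimal Optimizer subclass and checks that the deprecated
// interface (size(), add_parameters(), parameters()) still works but emits a
// deprecation warning for every call.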
TEST(OptimTest, OldInterface) {
  struct MyOptimizer : Optimizer {
    using Optimizer::Optimizer;
    torch::Tensor step(LossClosure closure = nullptr) override {
      return {};
    }
    explicit MyOptimizer(
        std::vector<at::Tensor> params,
        MyOptimizerOptions defaults = {})
        : // NOLINTNEXTLINE(performance-move-const-arg)
          Optimizer(
              {std::move(OptimizerParamGroup(params))},
              std::make_unique<MyOptimizerOptions>(defaults)) {}
  };
  std::vector<torch::Tensor> parameters = {
      torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
  {
    MyOptimizer optimizer(parameters);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());
  }
  {
    std::vector<at::Tensor> params;
    MyOptimizer optimizer(params);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, 0);

    OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());

    std::vector<torch::Tensor> params_;
    OLD_INTERFACE_WARNING_CHECK(params_ = optimizer.parameters());
    for (const auto p : c10::irange(size)) {
      ASSERT_TRUE(params_[p].allclose(parameters[p]));
    }
  }
  {
    Linear linear(3, 4);
    MyOptimizer optimizer(linear->parameters());

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, linear->parameters().size());
  }
}

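// Convergence smoke tests: each optimizer must drive the XOR model below the
// loss threshold within the epoch budget of test_optimizer_xor.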
TEST(OptimTest, XORConvergence_SGD) {
  ASSERT_TRUE(test_optimizer_xor<SGD>(
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_LBFGS) {
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0)));
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe")));
}

TEST(OptimTest, XORConvergence_Adagrad) {
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
}

TEST(OptimTest, XORConvergence_RMSprop) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
}

TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(
      RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_Adam) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(
      AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
}

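// The ProducesPyTorchValues_* tests compare the parameters produced by the
// C++ optimizers against the baselines in optim_baseline.h for a variety of
// option combinations.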
TEST(OptimTest, ProducesPyTorchValues_Adam) {
  check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adam_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
      expected_parameters::Adam_with_weight_decay_and_amsgrad());
}

TEST(OptimTest, XORConvergence_AdamW) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1)));
}

TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1).amsgrad(true)));
}

TEST(OptimTest, ProducesPyTorchValues_AdamW) {
  check_exact_values<AdamW>(AdamWOptions(1.0), expected_parameters::AdamW());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithoutWeightDecay) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).weight_decay(0),
      expected_parameters::AdamW_without_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithAMSGrad) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).amsgrad(true),
      expected_parameters::AdamW_with_amsgrad());
}

TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0), expected_parameters::Adagrad());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adagrad_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
      expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1), expected_parameters::RMSprop());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-2),
      expected_parameters::RMSprop_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
      expected_parameters::RMSprop_with_weight_decay_and_centered());
}

TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
      expected_parameters::
          RMSprop_with_weight_decay_and_centered_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGD) {
  check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2),
      expected_parameters::SGD_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
      expected_parameters::SGD_with_weight_decay_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
      expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS) {
  check_exact_values<LBFGS>(LBFGSOptions(1.0), expected_parameters::LBFGS());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS_with_line_search) {
  check_exact_values<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe"),
      expected_parameters::LBFGS_with_line_search());
}

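// zero_grad() is expected to leave the gradients undefined again (reset
// rather than merely filled with zeros), which is what the final assertions
// check.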
TEST(OptimTest, ZeroGrad) {
  torch::manual_seed(0);

  Linear model(2, 8);
  SGD optimizer(model->parameters(), 0.1);

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }

  auto output = model->forward(torch::ones({5, 2}));
  auto loss = output.sum();
  loss.backward();

  for (const auto& parameter : model->parameters()) {
    ASSERT_TRUE(parameter.grad().defined());
    ASSERT_GT(parameter.grad().sum().item<float>(), 0);
  }

  optimizer.zero_grad();

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }
}

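// Plain SGD updates p <- p - lr * grad. With lr = 1.0 and gradients of all
// ones, a single step should therefore decrease every parameter by exactly
// 1.0, which is what the allclose checks below assert.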
TEST(OptimTest, ExternalVectorOfParameters) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};
  std::vector<torch::Tensor> original_parameters = {
      parameters[0].clone(), parameters[1].clone(), parameters[2].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  SGD optimizer(parameters, 1.0);

  optimizer.step();

  ASSERT_TRUE(parameters[0].allclose(original_parameters[0] - 1.0));
  ASSERT_TRUE(parameters[1].allclose(original_parameters[1] - 1.0));
  ASSERT_TRUE(parameters[2].allclose(original_parameters[2] - 1.0));
}

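// Checks that parameters can be added to LBFGS after construction via the
// deprecated add_parameters() and that a subsequent step with a trivial
// closure does not throw.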
TEST(OptimTest, AddParameter_LBFGS) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
  std::vector<torch::Tensor> original_parameters = {parameters[0].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
  OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

  optimizer.step([]() { return torch::tensor(1); });

  // The only requirement is that the step above does not throw.
}

// Check whether the learning rate of the optimizer's first parameter group
// matches the expected learning rate given in the epoch:learning-rate map
void check_lr_change(
    Optimizer& optimizer,
    LRScheduler& lr_scheduler,
    std::map<unsigned, double> expected_epoch_lrs) {
  // Find the maximum epoch in the map
  unsigned kIterations = std::max_element(
                             expected_epoch_lrs.begin(),
                             expected_epoch_lrs.end(),
                             [](const std::pair<unsigned, double>& a,
                                const std::pair<unsigned, double>& b) -> bool {
                               return a.first < b.first;
                             })
                             ->first;

  for (unsigned i = 0; i <= kIterations; i++) {
    const auto epoch_iter = expected_epoch_lrs.find(i);
    if (epoch_iter != expected_epoch_lrs.end()) {
      // Compare the two floating-point learning rates
      ASSERT_TRUE(
          fabs(
              epoch_iter->second -
              optimizer.param_groups()[0].options().get_lr()) <
          std::numeric_limits<double>::epsilon());
    }
    optimizer.step();
    lr_scheduler.step();
  }
}

// Very similar to check_lr_change, but for ReduceLROnPlateauScheduler,
// which does not inherit from LRScheduler and requires a metrics
// input to step().
void check_lr_change_for_reduce_on_plateau(
    Optimizer& optimizer,
    ReduceLROnPlateauScheduler& lr_scheduler,
    std::map<unsigned, double> expected_epoch_lrs) {
  // Find the maximum epoch in the map
  unsigned kIterations = std::max_element(
                             expected_epoch_lrs.begin(),
                             expected_epoch_lrs.end(),
                             [](const std::pair<unsigned, double>& a,
                                const std::pair<unsigned, double>& b) -> bool {
                               return a.first < b.first;
                             })
                             ->first;

  for (unsigned i = 0; i <= kIterations; i++) {
    const auto epoch_iter = expected_epoch_lrs.find(i);
    if (epoch_iter != expected_epoch_lrs.end()) {
      // Compare the two floating-point learning rates
      ASSERT_TRUE(
          fabs(
              epoch_iter->second -
              optimizer.param_groups()[0].options().get_lr()) <
          std::numeric_limits<double>::epsilon());
    }
    optimizer.step();
    lr_scheduler.step(5.0);
  }
}

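// StepLR multiplies the learning rate by gamma every step_size epochs, so
// with step_size = 20 and gamma = 0.5 the rate should drop from 1e-3 to 5e-4
// at epoch 20 and still be 5e-4 at epoch 25.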
TEST(OptimTest, CheckLRChange_StepLR_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));

  const unsigned step_size = 20;
  const double gamma = 0.5;
  StepLR step_lr_scheduler(optimizer, step_size, gamma);

  // The learning rate should have halved at epoch 20
  const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};

  check_lr_change(optimizer, step_lr_scheduler, expected_epoch_lrs);
}

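// The scheduler is fed a constant metric (5.0), so in `min` mode it never
// sees an improvement and should cut the learning rate by `factor` once
// `patience` epochs have passed, i.e. from 1e-3 to 5e-4 by epoch 25.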
TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));
  const float factor = 0.5;
  const int patience = 20;
  ReduceLROnPlateauScheduler reduce_lr_on_plateau_scheduler(
      optimizer,
      ReduceLROnPlateauScheduler::SchedulerMode::min,
      factor,
      patience);

  // The learning rate should have halved at epoch 20
  const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};

  check_lr_change_for_reduce_on_plateau(
      optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
}