pytorch

Форк
0
/
serialize.cpp 
1094 строки · 37.1 Кб
1
#include <gtest/gtest.h>
2

3
#include <c10/util/flat_hash_map.h>
4
#include <c10/util/irange.h>
5
#include <c10/util/tempfile.h>
6

7
#include <torch/torch.h>
8

9
#include <test/cpp/api/support.h>
10

11
#include <cstdio>
12
#include <memory>
13
#include <sstream>
14
#include <string>
15
#include <vector>
16

17
using namespace torch::test;
18
using namespace torch::nn;
19
using namespace torch::optim;
20

21
namespace {
22
Sequential xor_model() {
23
  return Sequential(
24
      Linear(2, 8),
25
      Functional(at::sigmoid),
26
      Linear(8, 1),
27
      Functional(at::sigmoid));
28
}
29

30
torch::Tensor save_and_load(torch::Tensor input) {
31
  std::stringstream stream;
32
  torch::save(input, stream);
33
  torch::Tensor tensor;
34
  torch::load(tensor, stream);
35
  return tensor;
36
}
37
} // namespace
38

39
template <typename DerivedOptions>
40
void is_optimizer_param_group_equal(
41
    const OptimizerParamGroup& lhs,
42
    const OptimizerParamGroup& rhs) {
43
  const auto& lhs_params = lhs.params();
44
  const auto& rhs_params = rhs.params();
45

46
  ASSERT_TRUE(lhs_params.size() == rhs_params.size());
47
  for (const auto j : c10::irange(lhs_params.size())) {
48
    ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j]));
49
  }
50
  ASSERT_TRUE(
51
      static_cast<const DerivedOptions&>(lhs.options()) ==
52
      static_cast<const DerivedOptions&>(rhs.options()));
53
}
54

55
template <typename DerivedOptimizerParamState>
56
void is_optimizer_state_equal(
57
    const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
58
        lhs_state,
59
    const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
60
        rhs_state) {
61
  ASSERT_TRUE(lhs_state.size() == rhs_state.size());
62
  for (const auto& value : lhs_state) {
63
    auto found = rhs_state.find(value.first);
64
    ASSERT_TRUE(found != rhs_state.end());
65
    const DerivedOptimizerParamState& lhs_curr_state =
66
        static_cast<const DerivedOptimizerParamState&>(*(value.second.get()));
67
    const DerivedOptimizerParamState& rhs_curr_state =
68
        static_cast<const DerivedOptimizerParamState&>(*(found->second.get()));
69
    ASSERT_TRUE(lhs_curr_state == rhs_curr_state);
70
  }
71
}
72

73
template <
74
    typename OptimizerClass,
75
    typename DerivedOptimizerOptions,
76
    typename DerivedOptimizerParamState>
77
void test_serialize_optimizer(
78
    DerivedOptimizerOptions options,
79
    bool only_has_global_state = false) {
80
  torch::manual_seed(0);
81
  auto model1 = Linear(5, 2);
82
  auto model2 = Linear(5, 2);
83
  auto model3 = Linear(5, 2);
84

85
  // Models 1, 2, 3 will have the same parameters.
86
  auto model_tempfile = c10::make_tempfile();
87
  torch::save(model1, model_tempfile.name);
88
  torch::load(model2, model_tempfile.name);
89
  torch::load(model3, model_tempfile.name);
90

91
  auto param1 = model1->named_parameters();
92
  auto param2 = model2->named_parameters();
93
  auto param3 = model3->named_parameters();
94
  for (const auto& p : param1) {
95
    ASSERT_TRUE(p->allclose(param2[p.key()]));
96
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
97
  }
98
  // Make some optimizers
99
  auto optim1 = OptimizerClass(
100
      {torch::optim::OptimizerParamGroup(model1->parameters())}, options);
101
  auto optim2 = OptimizerClass(model2->parameters(), options);
102
  auto optim2_2 = OptimizerClass(model2->parameters(), options);
103
  auto optim3 = OptimizerClass(model3->parameters(), options);
104
  auto optim3_2 = OptimizerClass(model3->parameters(), options);
105
  for (auto& param_group : optim3_2.param_groups()) {
106
    const double lr = param_group.options().get_lr();
107
    // change the learning rate, which will be overwritten by the loading
108
    // otherwise, test cannot check if options are saved and loaded correctly
109
    param_group.options().set_lr(lr + 0.01);
110
  }
111

112
  auto x = torch::ones({10, 5});
113

114
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
115
    optimizer.zero_grad();
116
    auto y = model->forward(x).sum();
117
    y.backward();
118
    auto closure = []() { return torch::tensor({10}); };
119
    optimizer.step(closure);
120
  };
121

122
  // Do 2 steps of model1
123
  step(optim1, model1);
124
  step(optim1, model1);
125

126
  // Do 2 steps of model 2 without saving the optimizer
127
  step(optim2, model2);
128
  step(optim2_2, model2);
129

130
  // Do 1 step of model 3
131
  step(optim3, model3);
132

133
  // save the optimizer
134
  auto optim_tempfile = c10::make_tempfile();
135
  torch::save(optim3, optim_tempfile.name);
136
  torch::load(optim3_2, optim_tempfile.name);
137

138
  auto& optim3_2_param_groups = optim3_2.param_groups();
139
  auto& optim3_param_groups = optim3.param_groups();
140
  auto& optim3_2_state = optim3_2.state();
141
  auto& optim3_state = optim3.state();
142

143
  // optim3_2 and optim1 should have param_groups and state of size 1 and
144
  // state_size respectively
145
  ASSERT_TRUE(optim3_2_param_groups.size() == 1);
146
  // state_size = 2 for all optimizers except LBFGS as LBFGS only maintains one
147
  // global state
148
  unsigned state_size = only_has_global_state ? 1 : 2;
149
  ASSERT_TRUE(optim3_2_state.size() == state_size);
150

151
  // optim3_2 and optim1 should have param_groups and state of same size
152
  ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size());
153
  ASSERT_TRUE(optim3_2_state.size() == optim3_state.size());
154

155
  // checking correctness of serialization logic for optimizer.param_groups_ and
156
  // optimizer.state_
157
  for (const auto i : c10::irange(optim3_2_param_groups.size())) {
158
    is_optimizer_param_group_equal<DerivedOptimizerOptions>(
159
        optim3_2_param_groups[i], optim3_param_groups[i]);
160
    is_optimizer_state_equal<DerivedOptimizerParamState>(
161
        optim3_2_state, optim3_state);
162
  }
163

164
  // Do step2 for model 3
165
  step(optim3_2, model3);
166

167
  param1 = model1->named_parameters();
168
  param2 = model2->named_parameters();
169
  param3 = model3->named_parameters();
170
  for (const auto& p : param1) {
171
    const auto& name = p.key();
172
    // Model 1 and 3 should be the same
173
    ASSERT_TRUE(
174
        param1[name].norm().item<float>() == param3[name].norm().item<float>());
175
    ASSERT_TRUE(
176
        param1[name].norm().item<float>() != param2[name].norm().item<float>());
177
  }
178
}
179

180
/// Utility function to save a value of `int64_t` type.
181
void write_int_value(
182
    torch::serialize::OutputArchive& archive,
183
    const std::string& key,
184
    const int64_t& value) {
185
  archive.write(key, c10::IValue(value));
186
}
187
// Utility function to save a vector of buffers.
188
template <typename BufferContainer>
189
void write_tensors_to_archive(
190
    torch::serialize::OutputArchive& archive,
191
    const std::string& key,
192
    const BufferContainer& buffers) {
193
  archive.write(
194
      key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
195
  for (const auto index : c10::irange(buffers.size())) {
196
    archive.write(
197
        key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
198
  }
199
}
200

201
// Utility function to save a vector of step buffers.
202
void write_step_buffers(
203
    torch::serialize::OutputArchive& archive,
204
    const std::string& key,
205
    const std::vector<int64_t>& steps) {
206
  std::vector<torch::Tensor> tensors;
207
  tensors.reserve(steps.size());
208
  for (const auto& step : steps) {
209
    tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
210
  }
211
  write_tensors_to_archive(archive, key, tensors);
212
}
213

214
// Calls `funcname(optimizer, filename)` (in this file: torch::load on an
// archive hand-written in the old optimizer format) while a WarningCapture is
// alive. It then asserts that the captured text contains "old serialization"
// exactly once.
// NOTE(review): assumes WarningCapture collects warnings emitted during its
// lifetime; confirm against test/cpp/api/support.h.
//
// The body is wrapped in do { ... } while (0) so the macro expands to a single
// statement. That makes `if (c) MACRO(...); else ...` well-formed, which a
// bare { ... } block followed by ';' is not.
#define OLD_SERIALIZATION_LOGIC_WARNING_CHECK(funcname, optimizer, filename) \
  do {                                                                       \
    WarningCapture warnings;                                                 \
    funcname(optimizer, filename);                                           \
    ASSERT_EQ(                                                               \
        count_substr_occurrences(warnings.str(), "old serialization"), 1);   \
  } while (0)
221

222
TEST(SerializeTest, KeysFunc) {
  // Write three int IValues under "element/0".."element/2". Then check that
  // InputArchive::keys() returns exactly those names, in that order.
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive writer;
  for (const auto i : c10::irange(3)) {
    const std::string element_key = "element/" + std::to_string(i);
    writer.write(element_key, c10::IValue(static_cast<int64_t>(i)));
  }
  writer.save_to(tempfile.name);

  torch::serialize::InputArchive reader;
  reader.load_from(tempfile.name);
  const std::vector<std::string> keys = reader.keys();
  ASSERT_EQ(keys.size(), 3);
  for (const auto idx : c10::irange(keys.size())) {
    const std::string expected = "element/" + std::to_string(idx);
    ASSERT_EQ(keys[idx], expected);
  }
}
238

239
TEST(SerializeTest, TryReadFunc) {
  // try_read must report a missing key with `false` and fill the IValue for
  // a present key.
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive writer;
  for (const auto i : c10::irange(3)) {
    const std::string element_key = "element/" + std::to_string(i);
    writer.write(element_key, c10::IValue(static_cast<int64_t>(i)));
  }
  writer.save_to(tempfile.name);

  torch::serialize::InputArchive reader;
  reader.load_from(tempfile.name);
  c10::IValue value;
  ASSERT_FALSE(reader.try_read("1", value));
  ASSERT_TRUE(reader.try_read("element/1", value));
  ASSERT_EQ(value.toInt(), 1);
}
254

255
TEST(SerializeTest, Basic) {
  torch::manual_seed(0);

  // A dense random 5x5 tensor must survive an in-memory round trip.
  const auto original = torch::randn({5, 5});
  const auto restored = save_and_load(original);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(original.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(original.allclose(restored));
}
265

266
TEST(SerializeTest, MathBits) {
  torch::manual_seed(0);

  auto options = torch::TensorOptions{}.dtype(torch::kComplexFloat);
  auto x = torch::randn({5, 5}, options);

  // Tensors carrying the conjugate bit, the negative bit, and both must all
  // round-trip with their values intact.
  for (const auto& expected :
       {torch::conj(x),
        torch::_neg_view(x),
        torch::conj(torch::_neg_view(x))}) {
    const auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    // We don't support serializing `ZeroTensor` as it is not public facing yet.
    // If in future, `ZeroTensor` serialization is supported, this test should
    // start failing!
    auto t = torch::_efficientzerotensor({5, 5});
    ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,");
  }
}
306

307
TEST(SerializeTest, BasicToFile) {
  torch::manual_seed(0);

  // Same round trip as `Basic`, but through a temporary file on disk.
  const auto original = torch::randn({5, 5});

  const auto tempfile = c10::make_tempfile();
  torch::save(original, tempfile.name);

  torch::Tensor restored;
  torch::load(restored, tempfile.name);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(original.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(original.allclose(restored));
}
322

323
TEST(SerializeTest, BasicViaFunc) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  // Save through a writer callback that appends every chunk to a string.
  std::string serialized;
  torch::save(x, [&](const void* buf, size_t n) {
    serialized.append(static_cast<const char*>(buf), n);
    return n;
  });

  // Load from a raw (pointer, size) buffer.
  torch::Tensor y;
  torch::load(y, serialized.data(), serialized.size());

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));

  // Load through positional read + size callbacks.
  torch::Tensor z;
  torch::load(
      z,
      [&](uint64_t pos, void* buf, size_t n) -> size_t {
        if (pos >= serialized.size()) {
          return 0;
        }
        // std::string::copy copies min(n, size() - pos) bytes and returns the
        // count. Unlike computing `pos + n`, it cannot overflow for large n,
        // and it needs no <cstring>/<algorithm> (memcpy/std::min).
        return serialized.copy(
            static_cast<char*>(buf), n, static_cast<size_t>(pos));
      },
      [&]() -> size_t { return serialized.size(); });
  ASSERT_TRUE(z.defined());
  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
  ASSERT_TRUE(x.allclose(z));
}
356

357
TEST(SerializeTest, Resized) {
  torch::manual_seed(0);

  // Shrink an 11x5 tensor in place to 5x5, so that it no longer spans its
  // whole storage. Only the visible 5x5 part must round-trip.
  auto shrunk = torch::randn({11, 5});
  shrunk.resize_({5, 5});
  const auto restored = save_and_load(shrunk);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(shrunk.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(shrunk.allclose(restored));
}
368

369
TEST(SerializeTest, Sliced) {
  torch::manual_seed(0);

  // Rows [1, 5) of an 11x5 tensor: a view that starts at a non-zero offset.
  const auto rows = torch::randn({11, 5}).slice(0, 1, 5);
  const auto restored = save_and_load(rows);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(rows.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(rows.allclose(restored));
}
380

381
TEST(SerializeTest, NonContiguous) {
  torch::manual_seed(0);

  // Columns [1, 4) of an 11x5 tensor: a view whose rows are not adjacent in
  // memory.
  const auto columns = torch::randn({11, 5}).slice(1, 1, 4);
  const auto restored = save_and_load(columns);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(columns.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(columns.allclose(restored));
}
392

393
TEST(SerializeTest, ErrorOnMissingKey) {
  // Three-level hierarchy M -> A -> B with configurable child names, so that
  // instances can differ only below the top level.
  struct B : torch::nn::Module {
    explicit B(const std::string& name_c) {
      register_buffer(name_c, torch::ones(5, torch::kFloat));
    }
  };
  struct A : torch::nn::Module {
    A(const std::string& name_b, const std::string& name_c) {
      register_module(name_b, std::make_shared<B>(name_c));
    }
  };
  struct M : torch::nn::Module {
    M(const std::string& name_a,
      const std::string& name_b,
      const std::string& name_c) {
      register_module(name_a, std::make_shared<A>(name_b, name_c));
    }
  };

  // create a hierarchy of models with names differing below the top level
  auto model1 = std::make_shared<M>("a", "b", "c");
  auto model2 = std::make_shared<M>("a", "b", "x");
  auto model3 = std::make_shared<M>("a", "x", "c");

  std::stringstream stream;
  torch::save(model1, stream);
  // We want the errors to contain hierarchy information, too.
  ASSERT_THROWS_WITH(
      torch::load(model2, stream), "No such serialized tensor 'a.b.x'");
  // The failed load may have left error/EOF flags set on the stream. While
  // failbit is set, seekg does nothing, so clear the flags before rewinding.
  stream.clear();
  stream.seekg(0, std::ios::beg);
  ASSERT_THROWS_WITH(
      torch::load(model3, stream), "No such serialized submodule: 'a.x'");
}
426

427
TEST(SerializeTest, XOR) {
  // We better be able to save and load an XOR model!
  // Seed explicitly so that convergence does not depend on how much of the
  // global RNG stream earlier tests happened to consume.
  torch::manual_seed(0);

  // Builds a random batch of binary input pairs, labels each pair with its
  // XOR, and returns the binary cross-entropy of `model`'s predictions.
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (const auto i : c10::irange(batch_size)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  // Train until an exponential moving average of the loss drops below 0.1,
  // failing if that takes 3000 epochs or more.
  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  // The trained weights, loaded into a fresh model, must still solve XOR.
  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);
}
469

470
TEST(SerializeTest, Optim) {
  // Three Linear models that will start from identical weights.
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Copy model1's parameters into models 2 and 3 via a file round trip.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& key = p.key();
    ASSERT_TRUE(p->allclose(param2[key]));
    ASSERT_TRUE(param2[key].allclose(param3[key]));
  }

  // Momentum gives SGD per-parameter state, which is what must round-trip.
  const auto sgd_options = torch::optim::SGDOptions(1e-1).momentum(0.9);
  auto optim1 = torch::optim::SGD(model1->parameters(), sgd_options);
  auto optim2 = torch::optim::SGD(model2->parameters(), sgd_options);
  auto optim2_2 = torch::optim::SGD(model2->parameters(), sgd_options);
  auto optim3 = torch::optim::SGD(model3->parameters(), sgd_options);
  auto optim3_2 = torch::optim::SGD(model3->parameters(), sgd_options);

  auto x = torch::ones({10, 5});

  // One optimization step of `model` on the constant input `x`.
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    model->forward(x).sum().backward();
    optimizer.step();
  };

  // model1: two steps with a single optimizer (reference).
  step(optim1, model1);
  step(optim1, model1);

  // model2: two steps with two different optimizers, so no state carries over.
  step(optim2, model2);
  step(optim2_2, model2);

  // model3: one step, save/load the optimizer, then the second step.
  step(optim3, model3);

  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);
  step(optim3_2, model3);

  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& name = p.key();
    const float reference_norm = param1[name].norm().item<float>();
    // Restored state => model3 follows model1. Lost state => model2 diverges.
    ASSERT_TRUE(reference_norm == param3[name].norm().item<float>());
    ASSERT_TRUE(reference_norm != param2[name].norm().item<float>());
  }
}
538

539
TEST(SerializeTest, Optim_Adagrad) {
  test_serialize_optimizer<Adagrad, AdagradOptions, AdagradParamState>(
      AdagradOptions(1e-1));

  // Backward-compatibility check: write optim1's state by hand in the old
  // layout, load it into a fresh optimizer, and compare the states.
  auto model1 = Linear(5, 2);
  auto optim1 = torch::optim::Adagrad(
      model1->parameters(), torch::optim::AdagradOptions(1e-1));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    model->forward(x).sum().backward();
    optimizer.step();
  };
  step(optim1, model1);
  auto optim1_2 =
      Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1));

  // Collect optim1's per-parameter `sum` tensors and `step` counters, in
  // parameter order.
  std::vector<torch::Tensor> sum_buffers;
  std::vector<int64_t> step_buffers;
  const auto& optim1_state = optim1.state();
  for (const auto& param : optim1.param_groups()[0].params()) {
    const auto& param_state = static_cast<const AdagradParamState&>(
        *optim1_state.at(param.unsafeGetTensorImpl()));
    sum_buffers.emplace_back(param_state.sum());
    step_buffers.emplace_back(param_state.step());
  }

  // Write them in the old format and load that archive into optim1_2.
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive legacy_archive;
  write_tensors_to_archive(legacy_archive, "sum_buffers", sum_buffers);
  write_step_buffers(legacy_archive, "step_buffers", step_buffers);
  legacy_archive.save_to(optim_tempfile_old_format.name);
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdagradParamState>(optim1.state(), optim1_2.state());
}
582

583
TEST(SerializeTest, Optim_SGD) {
  test_serialize_optimizer<SGD, SGDOptions, SGDParamState>(
      SGDOptions(1e-1).momentum(0.9));

  // Backward-compatibility check: write momentum buffers and an iteration
  // counter in the old layout, load them, and compare against optim1's state.
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // Extra tensor that is not used by model1's forward pass. It never receives
  // a gradient, so it is expected to have no momentum buffer (lazy-init
  // check).
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::SGD(
      model1_params, torch::optim::SGDOptions(0.01).momentum(0.9));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    model->forward(x).sum().backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<at::Tensor> momentum_buffers;
  int64_t iteration_{0};
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  const auto extra_index = params_.size() - 1;
  for (const auto i : c10::irange(params_.size())) {
    if (i == extra_index) {
      continue; // the extra tensor has no state to collect
    }
    const auto& param_state = static_cast<const SGDParamState&>(
        *optim1_state.at(params_[i].unsafeGetTensorImpl()));
    momentum_buffers.emplace_back(param_state.momentum_buffer());
  }
  ASSERT_TRUE(momentum_buffers.size() == (params_.size() - 1));

  // Write the buffers in the old format and load them into a fresh optimizer.
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive legacy_archive;
  write_tensors_to_archive(
      legacy_archive, "momentum_buffers", momentum_buffers);
  write_int_value(legacy_archive, "iteration_", iteration_);
  legacy_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 =
      SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9));
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<SGDParamState>(optim1.state(), optim1_2.state());
}
631

632
TEST(SerializeTest, Optim_Adam) {
  test_serialize_optimizer<Adam, AdamOptions, AdamParamState>(
      AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5));

  // Backward-compatibility check: write optim1's per-parameter buffers in the
  // old layout, load them into a fresh optimizer, and compare the states.
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // Extra tensor that is not used by model1's forward pass. It never receives
  // a gradient, so it is expected to have no buffer entries (lazy-init check).
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::Adam(
      model1_params, torch::optim::AdamOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    model->forward(x).sum().backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  const auto extra_index = params_.size() - 1;
  for (const auto i : c10::irange(params_.size())) {
    if (i == extra_index) {
      continue; // skip the extra tensor
    }
    const auto& param_state = static_cast<const AdamParamState&>(
        *optim1_state.at(params_[i].unsafeGetTensorImpl()));
    step_buffers.emplace_back(param_state.step());
    exp_average_buffers.emplace_back(param_state.exp_avg());
    exp_average_sq_buffers.emplace_back(param_state.exp_avg_sq());
    // max_exp_avg_sq is only collected when it is defined.
    if (param_state.max_exp_avg_sq().defined()) {
      max_exp_average_sq_buffers.emplace_back(param_state.max_exp_avg_sq());
    }
  }

  // Write the buffers in the old format and load them into a fresh optimizer.
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive legacy_archive;
  write_step_buffers(legacy_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      legacy_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      legacy_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      legacy_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  legacy_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
}
689

690
TEST(SerializeTest, Optim_AdamW) {
  test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(
      AdamWOptions().lr(0.99999).amsgrad(true).betas(
          std::make_tuple(0.999, 0.1)));

  // Backward-compatibility check: write optim1's per-parameter buffers in the
  // old layout, load them into a fresh optimizer, and compare the states.
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // Extra tensor that is not used by model1's forward pass. It never receives
  // a gradient, so it is expected to have no buffer entries (lazy-init check).
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::AdamW(
      model1_params, torch::optim::AdamWOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    model->forward(x).sum().backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  const auto extra_index = params_.size() - 1;
  for (const auto i : c10::irange(params_.size())) {
    if (i == extra_index) {
      continue; // skip the extra tensor
    }
    const auto& param_state = static_cast<const AdamWParamState&>(
        *optim1_state.at(params_[i].unsafeGetTensorImpl()));
    step_buffers.emplace_back(param_state.step());
    exp_average_buffers.emplace_back(param_state.exp_avg());
    exp_average_sq_buffers.emplace_back(param_state.exp_avg_sq());
    // max_exp_avg_sq is only collected when it is defined.
    if (param_state.max_exp_avg_sq().defined()) {
      max_exp_average_sq_buffers.emplace_back(param_state.max_exp_avg_sq());
    }
  }

  // Write the buffers in the old format and load them into a fresh optimizer.
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive legacy_archive;
  write_step_buffers(legacy_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      legacy_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      legacy_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      legacy_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  legacy_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
}
748

749
TEST(SerializeTest, Optim_RMSprop) {
  auto options = RMSpropOptions(0.1).momentum(0.9).centered(true);
  test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options);

  // bc compatibility check: write optim1's per-parameter buffers in the old
  // layout, load them into a fresh optimizer, and compare the states.
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();

  // Extra tensor that is not used by model1's forward pass (lazy-init check):
  // it is expected to have no momentum buffer entry.
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::RMSprop(model1_params, options);

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<at::Tensor> square_average_buffers;
  std::vector<at::Tensor> momentum_buffers;
  std::vector<at::Tensor> grad_average_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const RMSpropParamState& curr_state_ =
          static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
      square_average_buffers.emplace_back(curr_state_.square_avg());
      if (curr_state_.momentum_buffer().defined()) {
        momentum_buffers.emplace_back(curr_state_.momentum_buffer());
      }
      if (curr_state_.grad_avg().defined()) {
        grad_average_buffers.emplace_back(curr_state_.grad_avg());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(
      output_archive, "square_average_buffers", square_average_buffers);
  write_tensors_to_archive(
      output_archive, "momentum_buffers", momentum_buffers);
  write_tensors_to_archive(
      output_archive, "grad_average_buffers", grad_average_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = RMSprop(model1_params, options);
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  const auto& params1_2_ = optim1_2.param_groups()[0].params();
  auto& optim1_2_state = optim1_2.state();
  // The loop below indexes both param lists with the same index, so they must
  // have the same length.
  ASSERT_EQ(params1_2_.size(), params_.size());
  // old RMSprop didn't track step value: copy it from optim1 before comparing.
  // Each optimizer's state is looked up with the key taken from that
  // optimizer's own param list.
  for (const auto i : c10::irange(params1_2_.size())) {
    if (i != (params1_2_.size() - 1)) {
      const RMSpropParamState& curr_state_ =
          static_cast<const RMSpropParamState&>(
              *(optim1_state.at(params_[i].unsafeGetTensorImpl()).get()));
      RMSpropParamState& curr_state1_2_ = static_cast<RMSpropParamState&>(
          *(optim1_2_state.at(params1_2_[i].unsafeGetTensorImpl()).get()));
      curr_state1_2_.step(curr_state_.step());
    }
  }
  is_optimizer_state_equal<RMSpropParamState>(optim1.state(), optim1_2.state());
}
818

819
TEST(SerializeTest, Optim_LBFGS) {
  test_serialize_optimizer<LBFGS, LBFGSOptions, LBFGSParamState>(
      LBFGSOptions(), true);

  // bc compatibility check: hand-write an archive in the old (flat buffer)
  // LBFGS layout from a live optimizer's state, load it into a second
  // optimizer, and compare the two states.
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have entry in
  // buffers
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 =
      torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions());

  auto x = torch::ones({10, 5});
  // Backprop the summed model output once, then step with a closure that
  // always reports a constant loss.
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);
  };

  step(optim1, model1);

  // The state inspected here is the entry keyed by the first parameter.
  const auto& params_ = optim1.param_groups()[0].params();
  auto key_ = params_[0].unsafeGetTensorImpl();
  const auto& optim1_state =
      static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));

  // Scalar fields (t, prev_loss) are wrapped in tensors so they can be
  // written as archive buffers below.
  const at::Tensor d = optim1_state.d();
  const at::Tensor t = at::tensor(optim1_state.t());
  const at::Tensor H_diag = optim1_state.H_diag();
  const at::Tensor prev_flat_grad = optim1_state.prev_flat_grad();
  const at::Tensor prev_loss = at::tensor(optim1_state.prev_loss());
  std::deque<at::Tensor> old_dirs = optim1_state.old_dirs();
  // Copy old_stps too (previously left empty). Otherwise the reloaded state
  // only matches optim1's when both history deques happen to be empty.
  std::deque<at::Tensor> old_stps = optim1_state.old_stps();

  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  output_archive.write("d", d, /*is_buffer=*/true);
  output_archive.write("t", t, /*is_buffer=*/true);
  output_archive.write("H_diag", H_diag, /*is_buffer=*/true);
  output_archive.write("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
  output_archive.write("prev_loss", prev_loss, /*is_buffer=*/true);
  write_tensors_to_archive(output_archive, "old_dirs", old_dirs);
  write_tensors_to_archive(output_archive, "old_stps", old_stps);
  output_archive.save_to(optim_tempfile_old_format.name);

  auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);

  const auto& params1_2_ = optim1_2.param_groups()[0].params();
  auto param_key = params1_2_[0].unsafeGetTensorImpl();
  auto& optim1_2_state =
      static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));

  // old LBFGS didn't track func_evals, n_iter, ro, al values
  optim1_2_state.func_evals(optim1_state.func_evals());
  optim1_2_state.n_iter(optim1_state.n_iter());
  optim1_2_state.ro(optim1_state.ro());
  optim1_2_state.al(optim1_state.al());

  is_optimizer_state_equal<LBFGSParamState>(optim1.state(), optim1_2.state());
}
885

886
TEST(SerializeTest, XOR_CUDA) {
  torch::manual_seed(0);
  // We better be able to save and load a XOR model!

  // Builds `batch_size` random XOR samples (each row two values drawn from
  // {0, 1}; target = first ^ second), optionally on CUDA, and returns the
  // binary cross-entropy of `net`'s predictions against the targets.
  auto xor_loss = [](Sequential net,
                     uint32_t batch_size,
                     bool is_cuda = false) {
    auto samples = torch::empty({batch_size, 2});
    auto targets = torch::empty({batch_size});
    if (is_cuda) {
      samples = samples.cuda();
      targets = targets.cuda();
    }
    for (const auto row : c10::irange(batch_size)) {
      samples[row] = torch::randint(2, {2}, torch::kInt64);
      const auto first = samples[row][0].item<int64_t>();
      const auto second = samples[row][1].item<int64_t>();
      targets[row] = first ^ second;
    }
    auto predictions = net->forward<torch::Tensor>(samples);
    return torch::binary_cross_entropy(predictions, targets);
  };

  auto trained = xor_model();
  auto reloaded = xor_model();
  auto reloaded_from_cuda = xor_model();
  auto sgd = torch::optim::SGD(
      trained->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  // Train until an exponential moving average of the batch loss reaches 0.1;
  // fail if that takes 3000 or more iterations.
  float running_loss = 1;
  for (int epoch = 0; running_loss > 0.1; ++epoch) {
    torch::Tensor batch_loss = xor_loss(trained, 4);
    sgd.zero_grad();
    batch_loss.backward();
    sgd.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + batch_loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
  }

  // Round-trip the trained model through a file and evaluate it on CPU.
  auto cpu_file = c10::make_tempfile();
  torch::save(trained, cpu_file.name);
  torch::load(reloaded, cpu_file.name);

  auto eval_loss = xor_loss(reloaded, 100);
  ASSERT_LT(eval_loss.item<float>(), 0.1);

  // The reloaded weights must still solve XOR after moving to CUDA.
  reloaded->to(torch::kCUDA);
  eval_loss = xor_loss(reloaded, 100, true);
  ASSERT_LT(eval_loss.item<float>(), 0.1);

  // Save the CUDA model, load it into the third model, evaluate on CUDA.
  auto cuda_file = c10::make_tempfile();
  torch::save(reloaded, cuda_file.name);
  torch::load(reloaded_from_cuda, cuda_file.name);

  eval_loss = xor_loss(reloaded_from_cuda, 100, true);
  ASSERT_LT(eval_loss.item<float>(), 0.1);
}
946

947
TEST(
    SerializeTest,
    CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
  // Module tree: M -> "a" (A) -> { "b" (B: no parameters or buffers),
  //                                "c" (C: int32 buffer "foo") }.
  // The empty intermediate module "b" must not break save/load.
  struct C : torch::nn::Module {
    C() {
      register_buffer("foo", torch::ones(5, torch::kInt32));
    }
  };
  struct B : torch::nn::Module {};
  struct A : torch::nn::Module {
    A() {
      register_module("b", std::make_shared<B>());
      register_module("c", std::make_shared<C>());
    }
  };
  struct M : torch::nn::Module {
    M() {
      register_module("a", std::make_shared<A>());
    }
  };

  auto out = std::make_shared<M>();
  // Give the saved buffer a value different from C's default (all ones).
  // Otherwise the final check would pass even if "a.c.foo" were never read
  // back, because a freshly constructed M already sums to 5.
  out->named_buffers()["a.c.foo"].fill_(2);
  std::stringstream ss;
  torch::save(out, ss);
  auto in = std::make_shared<M>();
  torch::load(in, ss);

  // Five elements, each set to 2 before saving.
  const int output = in->named_buffers()["a.c.foo"].sum().item<int>();
  ASSERT_EQ(output, 10);
}
977

978
TEST(SerializeTest, VectorOfTensors) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> x_vec = {
      torch::randn({1, 2}), torch::randn({3, 4})};

  std::stringstream stream;
  torch::save(x_vec, stream);

  std::vector<torch::Tensor> y_vec;
  torch::load(y_vec, stream);

  // Guard the indexed loop below: if fewer tensors came back, y_vec[i] would
  // read past the end of the vector instead of failing the test.
  ASSERT_EQ(y_vec.size(), x_vec.size());
  for (const auto i : c10::irange(x_vec.size())) {
    auto& x = x_vec[i];
    auto& y = y_vec[i];
    ASSERT_TRUE(y.defined());
    ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
    ASSERT_TRUE(x.allclose(y));
  }
}
998

999
TEST(SerializeTest, IValue) {
  // Write an integer IValue to an archive file and read it back. Then check
  // that reading a key that was never written throws.
  const c10::IValue written(1);
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive writer;
  writer.write("value", written);
  writer.save_to(tempfile.name);

  torch::serialize::InputArchive reader;
  reader.load_from(tempfile.name);
  c10::IValue read_back;
  reader.read("value", read_back);
  ASSERT_EQ(read_back.toInt(), 1);

  // Unknown keys are reported with a descriptive error.
  ASSERT_THROWS_WITH(
      reader.read("bad_key", read_back),
      "does not have a field with name");
}
1016

1017
// NOTE: if a `Module` contains unserializable submodules (e.g.
1018
// `nn::Functional`), we expect those submodules to be skipped when the `Module`
1019
// is being serialized.
1020
TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) {
1021
  struct A : torch::nn::Module {
1022
    A() {
1023
      register_module("relu", torch::nn::Functional(torch::relu));
1024
    }
1025
  };
1026

1027
  auto out = std::make_shared<A>();
1028
  std::stringstream ss;
1029
  torch::save(out, ss);
1030

1031
  torch::serialize::InputArchive archive;
1032
  archive.load_from(ss);
1033
  torch::serialize::InputArchive relu_archive;
1034

1035
  // Submodule with name "relu" should not exist in the `InputArchive`,
1036
  // because the "relu" submodule is an `nn::Functional` and is not
1037
  // serializable.
1038
  ASSERT_FALSE(archive.try_read("relu", relu_archive));
1039
}
1040

1041
// NOTE: If a `Module` contains unserializable submodules (e.g.
1042
// `nn::Functional`), we don't check the existence of those submodules in the
1043
// `InputArchive` when deserializing.
1044
TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) {
1045
  struct B : torch::nn::Module {
1046
    B() {
1047
      register_module("relu1", torch::nn::Functional(torch::relu));
1048
      register_buffer("foo", torch::zeros(5, torch::kInt32));
1049
    }
1050
  };
1051
  struct A : torch::nn::Module {
1052
    A() {
1053
      register_module("b", std::make_shared<B>());
1054
      register_module("relu2", torch::nn::Functional(torch::relu));
1055
    }
1056
  };
1057

1058
  auto out = std::make_shared<A>();
1059
  // Manually change the values of "b.foo", so that we can check whether the
1060
  // buffer contains these values after deserialization.
1061
  out->named_buffers()["b.foo"].fill_(1);
1062
  auto tempfile = c10::make_tempfile();
1063
  torch::save(out, tempfile.name);
1064

1065
  torch::serialize::InputArchive archive;
1066
  archive.load_from(tempfile.name);
1067
  torch::serialize::InputArchive archive_b;
1068
  torch::serialize::InputArchive archive_relu;
1069
  torch::Tensor tensor_foo;
1070

1071
  ASSERT_TRUE(archive.try_read("b", archive_b));
1072
  ASSERT_TRUE(archive_b.try_read("foo", tensor_foo, /*is_buffer=*/true));
1073

1074
  // Submodule with name "relu1" should not exist in `archive_b`, because the
1075
  // "relu1" submodule is an `nn::Functional` and is not serializable.
1076
  ASSERT_FALSE(archive_b.try_read("relu1", archive_relu));
1077

1078
  // Submodule with name "relu2" should not exist in `archive`, because the
1079
  // "relu2" submodule is an `nn::Functional` and is not serializable.
1080
  ASSERT_FALSE(archive.try_read("relu2", archive_relu));
1081

1082
  auto in = std::make_shared<A>();
1083
  // `torch::load(...)` works without error, even though `A` contains the
1084
  // `nn::Functional` submodules while the serialized file doesn't, because the
1085
  // `nn::Functional` submodules are not serializable and thus ignored when
1086
  // deserializing.
1087
  torch::load(in, tempfile.name);
1088

1089
  // Check that the "b.foo" buffer is correctly deserialized from the file.
1090
  const int output = in->named_buffers()["b.foo"].sum().item<int>();
1091
  // `output` should equal to the sum of the values we manually assigned to
1092
  // "b.foo" before serialization.
1093
  ASSERT_EQ(output, 5);
1094
}
1095

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.