1
#include <gtest/gtest.h>
3
#include <c10/util/flat_hash_map.h>
4
#include <c10/util/irange.h>
5
#include <c10/util/tempfile.h>
7
#include <torch/torch.h>
9
#include <test/cpp/api/support.h>
17
using namespace torch::test;
18
using namespace torch::nn;
19
using namespace torch::optim;
22
// Builds the small Sequential model used by the XOR tests; sigmoid
// activations are visible here, but the other layer lines and the return
// statement are missing from this extraction (stray line numbers remain).
// NOTE(review): garbled chunk — restore full body from upstream source.
Sequential xor_model() {
25
Functional(at::sigmoid),
27
Functional(at::sigmoid));
30
// Round-trips a tensor through an in-memory stringstream via torch::save /
// torch::load and returns the reloaded tensor.
// NOTE(review): garbled chunk — the declaration of `tensor`, the return
// statement, and the closing brace are missing from this extraction.
torch::Tensor save_and_load(torch::Tensor input) {
31
std::stringstream stream;
32
torch::save(input, stream);
34
torch::load(tensor, stream);
39
// Asserts that two OptimizerParamGroups are equal: same number of params,
// element-wise torch::equal params, and equal options once downcast to the
// concrete DerivedOptions type.
// NOTE(review): garbled chunk — the ASSERT_TRUE( wrapper around the options
// comparison and closing braces are missing from this extraction.
template <typename DerivedOptions>
40
void is_optimizer_param_group_equal(
41
const OptimizerParamGroup& lhs,
42
const OptimizerParamGroup& rhs) {
43
const auto& lhs_params = lhs.params();
44
const auto& rhs_params = rhs.params();
46
ASSERT_TRUE(lhs_params.size() == rhs_params.size());
47
for (const auto j : c10::irange(lhs_params.size())) {
48
ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j]));
51
static_cast<const DerivedOptions&>(lhs.options()) ==
52
static_cast<const DerivedOptions&>(rhs.options()));
55
// Asserts that two optimizer state maps (TensorImpl* -> OptimizerParamState)
// are equal: same size, same keys, and per-key equality after downcasting
// both states to the concrete DerivedOptimizerParamState type.
// NOTE(review): garbled chunk — the parameter names (presumably lhs_state /
// rhs_state, used below) and closing braces are missing from this extraction.
template <typename DerivedOptimizerParamState>
56
void is_optimizer_state_equal(
57
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
59
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
61
ASSERT_TRUE(lhs_state.size() == rhs_state.size());
62
for (const auto& value : lhs_state) {
63
auto found = rhs_state.find(value.first);
64
ASSERT_TRUE(found != rhs_state.end());
65
const DerivedOptimizerParamState& lhs_curr_state =
66
static_cast<const DerivedOptimizerParamState&>(*(value.second.get()));
67
const DerivedOptimizerParamState& rhs_curr_state =
68
static_cast<const DerivedOptimizerParamState&>(*(found->second.get()));
69
ASSERT_TRUE(lhs_curr_state == rhs_curr_state);
74
// Shared harness for the per-optimizer serialization tests: clones a Linear
// model three ways, steps independent optimizers, saves optim3 and reloads it
// into optim3_2 (which was given a perturbed lr), then asserts param-group and
// state equality, and finally that the reloaded optimizer advances model3's
// parameters to match model1's while model2 (stepped twice with a fresh
// optimizer) diverges.
// `only_has_global_state` selects the expected state-map size (1 vs 2) for
// optimizers like LBFGS that keep one global state entry.
// NOTE(review): garbled chunk — the leading `template <` line, loop bodies,
// ASSERT wrappers, and closing braces are missing from this extraction.
typename OptimizerClass,
75
typename DerivedOptimizerOptions,
76
typename DerivedOptimizerParamState>
77
void test_serialize_optimizer(
78
DerivedOptimizerOptions options,
79
bool only_has_global_state = false) {
80
torch::manual_seed(0);
81
auto model1 = Linear(5, 2);
82
auto model2 = Linear(5, 2);
83
auto model3 = Linear(5, 2);
86
auto model_tempfile = c10::make_tempfile();
87
torch::save(model1, model_tempfile.name);
88
torch::load(model2, model_tempfile.name);
89
torch::load(model3, model_tempfile.name);
91
auto param1 = model1->named_parameters();
92
auto param2 = model2->named_parameters();
93
auto param3 = model3->named_parameters();
94
for (const auto& p : param1) {
95
ASSERT_TRUE(p->allclose(param2[p.key()]));
96
ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
99
auto optim1 = OptimizerClass(
100
{torch::optim::OptimizerParamGroup(model1->parameters())}, options);
101
auto optim2 = OptimizerClass(model2->parameters(), options);
102
auto optim2_2 = OptimizerClass(model2->parameters(), options);
103
auto optim3 = OptimizerClass(model3->parameters(), options);
104
auto optim3_2 = OptimizerClass(model3->parameters(), options);
105
for (auto& param_group : optim3_2.param_groups()) {
106
const double lr = param_group.options().get_lr();
109
param_group.options().set_lr(lr + 0.01);
112
auto x = torch::ones({10, 5});
114
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
115
optimizer.zero_grad();
116
auto y = model->forward(x).sum();
118
auto closure = []() { return torch::tensor({10}); };
119
optimizer.step(closure);
123
step(optim1, model1);
124
step(optim1, model1);
127
step(optim2, model2);
128
step(optim2_2, model2);
131
step(optim3, model3);
134
auto optim_tempfile = c10::make_tempfile();
135
torch::save(optim3, optim_tempfile.name);
136
torch::load(optim3_2, optim_tempfile.name);
138
auto& optim3_2_param_groups = optim3_2.param_groups();
139
auto& optim3_param_groups = optim3.param_groups();
140
auto& optim3_2_state = optim3_2.state();
141
auto& optim3_state = optim3.state();
145
ASSERT_TRUE(optim3_2_param_groups.size() == 1);
148
unsigned state_size = only_has_global_state ? 1 : 2;
149
ASSERT_TRUE(optim3_2_state.size() == state_size);
152
ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size());
153
ASSERT_TRUE(optim3_2_state.size() == optim3_state.size());
157
for (const auto i : c10::irange(optim3_2_param_groups.size())) {
158
is_optimizer_param_group_equal<DerivedOptimizerOptions>(
159
optim3_2_param_groups[i], optim3_param_groups[i]);
160
is_optimizer_state_equal<DerivedOptimizerParamState>(
161
optim3_2_state, optim3_state);
165
step(optim3_2, model3);
167
param1 = model1->named_parameters();
168
param2 = model2->named_parameters();
169
param3 = model3->named_parameters();
170
for (const auto& p : param1) {
171
const auto& name = p.key();
174
param1[name].norm().item<float>() == param3[name].norm().item<float>());
176
param1[name].norm().item<float>() != param2[name].norm().item<float>());
182
// Writes a single int64 value into the output archive under `key`, wrapped
// as a c10::IValue (helper for emitting the legacy serialization format).
// NOTE(review): garbled chunk — the first signature line (presumably
// `void write_int_value(`) and the closing brace are missing.
torch::serialize::OutputArchive& archive,
183
const std::string& key,
184
const int64_t& value) {
185
archive.write(key, c10::IValue(value));
188
// Writes a container of tensors into the archive in the legacy layout:
// "<key>/size" holds the element count, then each tensor is written under
// "<key>/<index>" (with is_buffer = true).
// NOTE(review): garbled chunk — the `archive.write(` call lines and closing
// braces are missing from this extraction.
template <typename BufferContainer>
189
void write_tensors_to_archive(
190
torch::serialize::OutputArchive& archive,
191
const std::string& key,
192
const BufferContainer& buffers) {
194
key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
195
for (const auto index : c10::irange(buffers.size())) {
197
key + "/" + std::to_string(index), buffers[index], true);
202
// Converts a vector of int64 step counters into scalar tensors and writes
// them via write_tensors_to_archive, matching the legacy on-disk layout.
// NOTE(review): garbled chunk — loop/function closing braces are missing.
void write_step_buffers(
203
torch::serialize::OutputArchive& archive,
204
const std::string& key,
205
const std::vector<int64_t>& steps) {
206
std::vector<torch::Tensor> tensors;
207
tensors.reserve(steps.size());
208
for (const auto& step : steps) {
209
tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
211
write_tensors_to_archive(archive, key, tensors);
214
// Invokes `funcname(optimizer, filename)` inside a WarningCapture scope and
// asserts that exactly one "old serialization" warning was emitted — used by
// the Optim_* tests when loading the legacy archive format.
// NOTE(review): garbled chunk — interleaved line numbers break the macro's
// backslash continuations, and the surrounding braces/ASSERT_EQ wrapper are
// missing; restore from upstream source.
#define OLD_SERIALIZATION_LOGIC_WARNING_CHECK(funcname, optimizer, filename) \
216
WarningCapture warnings; \
217
funcname(optimizer, filename); \
219
count_substr_occurrences(warnings.str(), "old serialization"), 1); \
222
// Writes three int IValues ("element/0".."element/2") to a tempfile archive,
// reloads it, and verifies InputArchive::keys() returns all three keys in
// order. NOTE(review): garbled chunk — closing braces missing.
TEST(SerializeTest, KeysFunc) {
223
auto tempfile = c10::make_tempfile();
224
torch::serialize::OutputArchive output_archive;
225
for (const auto i : c10::irange(3)) {
226
output_archive.write(
227
"element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
229
output_archive.save_to(tempfile.name);
230
torch::serialize::InputArchive input_archive;
231
input_archive.load_from(tempfile.name);
232
std::vector<std::string> keys = input_archive.keys();
233
ASSERT_EQ(keys.size(), 3);
234
for (const auto i : c10::irange(keys.size())) {
235
ASSERT_EQ(keys[i], "element/" + std::to_string(i));
239
// Verifies InputArchive::try_read(): returns false for a nonexistent key and
// true (with the correct value) for an existing one.
// NOTE(review): garbled chunk — the declaration of `ivalue` and closing
// braces are missing from this extraction.
TEST(SerializeTest, TryReadFunc) {
240
auto tempfile = c10::make_tempfile();
241
torch::serialize::OutputArchive output_archive;
242
for (const auto i : c10::irange(3)) {
243
output_archive.write(
244
"element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
246
output_archive.save_to(tempfile.name);
247
torch::serialize::InputArchive input_archive;
248
input_archive.load_from(tempfile.name);
250
ASSERT_FALSE(input_archive.try_read("1", ivalue));
251
ASSERT_TRUE(input_archive.try_read("element/1", ivalue));
252
ASSERT_EQ(ivalue.toInt(), 1);
255
// Basic tensor round-trip through an in-memory stream: reloaded tensor is
// defined, same shape, and allclose to the original.
TEST(SerializeTest, Basic) {
256
torch::manual_seed(0);
258
auto x = torch::randn({5, 5});
259
auto y = save_and_load(x);
261
ASSERT_TRUE(y.defined());
262
ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
263
ASSERT_TRUE(x.allclose(y));
266
// Round-trips tensors carrying "math bits" views — conjugate, negative view,
// and both combined — and checks each survives serialization; also checks
// that ZeroTensor serialization throws (it is an internal-only format).
// NOTE(review): garbled chunk — the scope braces that make each
// `expected`/`actual` redeclaration legal are missing from this extraction.
TEST(SerializeTest, MathBits) {
267
torch::manual_seed(0);
269
auto options = torch::TensorOptions{}.dtype(torch::kComplexFloat);
270
auto x = torch::randn({5, 5}, options);
272
auto expected = torch::conj(x);
273
auto actual = save_and_load(expected);
275
ASSERT_TRUE(actual.defined());
276
ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
277
ASSERT_TRUE(actual.allclose(expected));
281
auto expected = torch::_neg_view(x);
282
auto actual = save_and_load(expected);
284
ASSERT_TRUE(actual.defined());
285
ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
286
ASSERT_TRUE(actual.allclose(expected));
290
auto expected = torch::conj(torch::_neg_view(x));
291
auto actual = save_and_load(expected);
293
ASSERT_TRUE(actual.defined());
294
ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
295
ASSERT_TRUE(actual.allclose(expected));
302
auto t = torch::_efficientzerotensor({5, 5});
303
ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,");
307
// Same round-trip as Basic but through a temporary file instead of a stream.
// NOTE(review): garbled chunk — the declaration of `y` is missing.
TEST(SerializeTest, BasicToFile) {
308
torch::manual_seed(0);
310
auto x = torch::randn({5, 5});
312
auto tempfile = c10::make_tempfile();
313
torch::save(x, tempfile.name);
316
torch::load(y, tempfile.name);
318
ASSERT_TRUE(y.defined());
319
ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
320
ASSERT_TRUE(x.allclose(y));
323
// Round-trips a tensor via the callback-based save/load overloads: save
// through a write callback into a std::string, load once from the raw
// buffer, then again (as `z`) through a read callback + size callback.
// NOTE(review): garbled chunk — declarations of `y`/`z`, the load call for
// `z`, the callback's return, and closing braces are missing.
TEST(SerializeTest, BasicViaFunc) {
324
torch::manual_seed(0);
326
auto x = torch::randn({5, 5});
328
std::string serialized;
329
torch::save(x, [&](const void* buf, size_t n) {
330
serialized.append(reinterpret_cast<const char*>(buf), n);
334
torch::load(y, serialized.data(), serialized.size());
336
ASSERT_TRUE(y.defined());
337
ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
338
ASSERT_TRUE(x.allclose(y));
343
[&](uint64_t pos, void* buf, size_t n) -> size_t {
344
if (pos >= serialized.size())
347
std::min(static_cast<size_t>(pos) + n, serialized.size()) - pos;
348
memcpy(buf, serialized.data() + pos, nbytes);
351
[&]() -> size_t { return serialized.size(); });
352
ASSERT_TRUE(z.defined());
353
ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
354
ASSERT_TRUE(x.allclose(z));
357
// Round-trips a tensor after a resize (the resize_ call itself is missing
// from this extraction) and checks shape and values survive.
TEST(SerializeTest, Resized) {
358
torch::manual_seed(0);
360
auto x = torch::randn({11, 5});
362
auto y = save_and_load(x);
364
ASSERT_TRUE(y.defined());
365
ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
366
ASSERT_TRUE(x.allclose(y));
369
// Round-trips a row-sliced (non-full-storage) tensor and checks shape and
// values survive serialization.
TEST(SerializeTest, Sliced) {
370
torch::manual_seed(0);
372
auto x = torch::randn({11, 5});
373
x = x.slice(0, 1, 5);
374
auto y = save_and_load(x);
376
ASSERT_TRUE(y.defined());
377
ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
378
ASSERT_TRUE(x.allclose(y));
381
// Round-trips a column-sliced (non-contiguous) tensor and checks shape and
// values survive serialization.
TEST(SerializeTest, NonContiguous) {
382
torch::manual_seed(0);
384
auto x = torch::randn({11, 5});
385
x = x.slice(1, 1, 4);
386
auto y = save_and_load(x);
388
ASSERT_TRUE(y.defined());
389
ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
390
ASSERT_TRUE(x.allclose(y));
393
// Builds a nested module M -> A -> B with configurable child names, saves a
// model named (a,b,c), and asserts that loading into models whose buffer or
// submodule names differ fails with the expected "No such serialized ..."
// messages. NOTE(review): garbled chunk — struct closing braces and the
// ASSERT_THROWS_WITH( wrapper lines are missing from this extraction.
TEST(SerializeTest, ErrorOnMissingKey) {
394
struct B : torch::nn::Module {
395
B(const std::string& name_c) {
396
register_buffer(name_c, torch::ones(5, torch::kFloat));
399
struct A : torch::nn::Module {
400
A(const std::string& name_b, const std::string& name_c) {
401
register_module(name_b, std::make_shared<B>(name_c));
404
struct M : torch::nn::Module {
405
M(const std::string& name_a,
406
const std::string& name_b,
407
const std::string& name_c) {
408
register_module(name_a, std::make_shared<A>(name_b, name_c));
413
auto model1 = std::make_shared<M>("a", "b", "c");
414
auto model2 = std::make_shared<M>("a", "b", "x");
415
auto model3 = std::make_shared<M>("a", "x", "c");
417
std::stringstream stream;
418
torch::save(model1, stream);
421
torch::load(model2, stream), "No such serialized tensor 'a.b.x'");
422
stream.seekg(0, stream.beg);
424
torch::load(model3, stream), "No such serialized submodule: 'a.x'");
427
// Trains an XOR model with SGD until loss < 0.1 (bounded by 3000 epochs),
// saves it to a tempfile, loads it into a fresh model, and checks the loaded
// model still achieves low loss on a 100-sample batch.
// NOTE(review): garbled chunk — the epoch counter, loss.backward()/step()
// lines, and closing braces are missing from this extraction.
TEST(SerializeTest, XOR) {
429
auto getLoss = [](Sequential model, uint32_t batch_size) {
430
auto inputs = torch::empty({batch_size, 2});
431
auto labels = torch::empty({batch_size});
432
for (const auto i : c10::irange(batch_size)) {
433
inputs[i] = torch::randint(2, {2}, torch::kInt64);
434
labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
436
auto x = model->forward<torch::Tensor>(inputs);
437
return torch::binary_cross_entropy(x, labels);
440
auto model = xor_model();
441
auto model2 = xor_model();
442
auto model3 = xor_model();
443
auto optimizer = torch::optim::SGD(
445
torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
448
float running_loss = 1;
450
while (running_loss > 0.1) {
451
torch::Tensor loss = getLoss(model, 4);
452
optimizer.zero_grad();
457
running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
458
ASSERT_LT(epoch, 3000);
462
auto tempfile = c10::make_tempfile();
463
torch::save(model, tempfile.name);
464
torch::load(model2, tempfile.name);
466
auto loss = getLoss(model2, 100);
467
ASSERT_LT(loss.item<float>(), 0.1);
470
// End-to-end SGD optimizer serialization: three identically-initialized
// Linear models; optim1 steps twice, optim3 steps once, then optim3's state
// is saved and loaded into optim3_2 which performs the second step. After
// that, model3's parameter norms must equal model1's (state was restored),
// while model2 — stepped twice with two *fresh* optimizers, losing momentum —
// must differ. NOTE(review): garbled chunk — backward()/step() lines inside
// the lambda, the ASSERT_TRUE( wrappers around the norm comparisons, and
// closing braces are missing from this extraction.
TEST(SerializeTest, Optim) {
471
auto model1 = Linear(5, 2);
472
auto model2 = Linear(5, 2);
473
auto model3 = Linear(5, 2);
476
auto model_tempfile = c10::make_tempfile();
477
torch::save(model1, model_tempfile.name);
478
torch::load(model2, model_tempfile.name);
479
torch::load(model3, model_tempfile.name);
481
auto param1 = model1->named_parameters();
482
auto param2 = model2->named_parameters();
483
auto param3 = model3->named_parameters();
484
for (const auto& p : param1) {
485
ASSERT_TRUE(p->allclose(param2[p.key()]));
486
ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
490
auto optim1 = torch::optim::SGD(
491
model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
492
auto optim2 = torch::optim::SGD(
493
model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
494
auto optim2_2 = torch::optim::SGD(
495
model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
496
auto optim3 = torch::optim::SGD(
497
model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
498
auto optim3_2 = torch::optim::SGD(
499
model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
501
auto x = torch::ones({10, 5});
503
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
504
optimizer.zero_grad();
505
auto y = model->forward(x).sum();
511
step(optim1, model1);
512
step(optim1, model1);
515
step(optim2, model2);
516
step(optim2_2, model2);
519
step(optim3, model3);
521
auto optim_tempfile = c10::make_tempfile();
522
torch::save(optim3, optim_tempfile.name);
523
torch::load(optim3_2, optim_tempfile.name);
524
step(optim3_2, model3);
526
param1 = model1->named_parameters();
527
param2 = model2->named_parameters();
528
param3 = model3->named_parameters();
529
for (const auto& p : param1) {
530
const auto& name = p.key();
533
param1[name].norm().item<float>() == param3[name].norm().item<float>());
535
param1[name].norm().item<float>() != param2[name].norm().item<float>());
539
// Adagrad: runs the generic harness, then checks backward compatibility with
// the legacy archive layout — extracts sum/step buffers from a stepped
// optimizer, writes them in the old format by hand, loads into a fresh
// optimizer (expecting the "old serialization" warning), and asserts state
// equality. NOTE(review): garbled chunk — the lambda's backward()/step()
// lines, the `auto optim1_2 =` declaration line, and closing braces are
// missing from this extraction.
TEST(SerializeTest, Optim_Adagrad) {
540
test_serialize_optimizer<Adagrad, AdagradOptions, AdagradParamState>(
541
AdagradOptions(1e-1));
544
auto model1 = Linear(5, 2);
545
auto optim1 = torch::optim::Adagrad(
546
model1->parameters(), torch::optim::AdagradOptions(1e-1));
548
auto x = torch::ones({10, 5});
549
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
550
optimizer.zero_grad();
551
auto y = model->forward(x).sum();
555
step(optim1, model1);
557
Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1));
560
std::vector<torch::Tensor> sum_buffers;
562
std::vector<int64_t> step_buffers;
563
const auto& params_ = optim1.param_groups()[0].params();
564
const auto& optim1_state = optim1.state();
565
for (const auto& param : params_) {
566
auto key_ = param.unsafeGetTensorImpl();
567
const AdagradParamState& curr_state_ =
568
static_cast<const AdagradParamState&>(*(optim1_state.at(key_).get()));
569
sum_buffers.emplace_back(curr_state_.sum());
570
step_buffers.emplace_back(curr_state_.step());
573
auto optim_tempfile_old_format = c10::make_tempfile();
574
torch::serialize::OutputArchive output_archive;
575
write_tensors_to_archive(output_archive, "sum_buffers", sum_buffers);
576
write_step_buffers(output_archive, "step_buffers", step_buffers);
577
output_archive.save_to(optim_tempfile_old_format.name);
578
OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
579
torch::load, optim1_2, optim_tempfile_old_format.name);
580
is_optimizer_state_equal<AdagradParamState>(optim1.state(), optim1_2.state());
583
// SGD: generic harness plus legacy-format compatibility. A parameter with no
// gradient (the extra randn tensor) is appended so the last param has no
// state; momentum buffers are collected for all other params, written in the
// old layout, reloaded with the expected warning, and state compared.
// NOTE(review): garbled chunk — lambda backward()/step() lines, the
// `auto optim1_2 =` declaration, and closing braces are missing.
TEST(SerializeTest, Optim_SGD) {
584
test_serialize_optimizer<SGD, SGDOptions, SGDParamState>(
585
SGDOptions(1e-1).momentum(0.9));
588
auto model1 = Linear(5, 2);
589
auto model1_params = model1->parameters();
592
model1_params.emplace_back(torch::randn({2, 3}));
593
auto optim1 = torch::optim::SGD(
594
model1_params, torch::optim::SGDOptions(0.01).momentum(0.9));
596
auto x = torch::ones({10, 5});
597
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
598
optimizer.zero_grad();
599
auto y = model->forward(x).sum();
603
step(optim1, model1);
605
std::vector<at::Tensor> momentum_buffers;
606
int64_t iteration_{0};
607
const auto& params_ = optim1.param_groups()[0].params();
608
const auto& optim1_state = optim1.state();
609
for (const auto i : c10::irange(params_.size())) {
610
if (i != (params_.size() - 1)) {
611
auto key_ = params_[i].unsafeGetTensorImpl();
612
const SGDParamState& curr_state_ =
613
static_cast<const SGDParamState&>(*(optim1_state.at(key_).get()));
614
momentum_buffers.emplace_back(curr_state_.momentum_buffer());
617
ASSERT_TRUE(momentum_buffers.size() == (params_.size() - 1));
619
auto optim_tempfile_old_format = c10::make_tempfile();
620
torch::serialize::OutputArchive output_archive;
621
write_tensors_to_archive(
622
output_archive, "momentum_buffers", momentum_buffers);
623
write_int_value(output_archive, "iteration_", iteration_);
624
output_archive.save_to(optim_tempfile_old_format.name);
626
SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9));
627
OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
628
torch::load, optim1_2, optim_tempfile_old_format.name);
629
is_optimizer_state_equal<SGDParamState>(optim1.state(), optim1_2.state());
632
// Adam: generic harness plus legacy-format compatibility. Collects step /
// exp_avg / exp_avg_sq (and max_exp_avg_sq when defined) for every param
// that has state, writes them in the old layout, reloads with the expected
// warning, and asserts state equality.
// NOTE(review): garbled chunk — lambda backward()/step() lines and closing
// braces are missing from this extraction.
TEST(SerializeTest, Optim_Adam) {
633
test_serialize_optimizer<Adam, AdamOptions, AdamParamState>(
634
AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5));
637
auto model1 = Linear(5, 2);
638
auto model1_params = model1->parameters();
641
model1_params.emplace_back(torch::randn({2, 3}));
642
auto optim1 = torch::optim::Adam(
643
model1_params, torch::optim::AdamOptions().weight_decay(0.5));
645
auto x = torch::ones({10, 5});
646
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
647
optimizer.zero_grad();
648
auto y = model->forward(x).sum();
652
step(optim1, model1);
654
std::vector<int64_t> step_buffers;
655
std::vector<at::Tensor> exp_average_buffers;
656
std::vector<at::Tensor> exp_average_sq_buffers;
657
std::vector<at::Tensor> max_exp_average_sq_buffers;
658
const auto& params_ = optim1.param_groups()[0].params();
659
const auto& optim1_state = optim1.state();
660
for (const auto i : c10::irange(params_.size())) {
661
if (i != (params_.size() - 1)) {
662
auto key_ = params_[i].unsafeGetTensorImpl();
663
const AdamParamState& curr_state_ =
664
static_cast<const AdamParamState&>(*(optim1_state.at(key_).get()));
665
step_buffers.emplace_back(curr_state_.step());
666
exp_average_buffers.emplace_back(curr_state_.exp_avg());
667
exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
668
if (curr_state_.max_exp_avg_sq().defined()) {
669
max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
674
auto optim_tempfile_old_format = c10::make_tempfile();
675
torch::serialize::OutputArchive output_archive;
676
write_step_buffers(output_archive, "step_buffers", step_buffers);
677
write_tensors_to_archive(
678
output_archive, "exp_average_buffers", exp_average_buffers);
679
write_tensors_to_archive(
680
output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
681
write_tensors_to_archive(
682
output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
683
output_archive.save_to(optim_tempfile_old_format.name);
684
auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions());
685
OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
686
torch::load, optim1_2, optim_tempfile_old_format.name);
687
is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
690
// AdamW: mirrors the Optim_Adam test — generic harness plus legacy-format
// compatibility with step / exp_avg / exp_avg_sq / max_exp_avg_sq buffers.
// NOTE(review): garbled chunk — lambda backward()/step() lines and closing
// braces are missing from this extraction.
TEST(SerializeTest, Optim_AdamW) {
691
test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(
692
AdamWOptions().lr(0.99999).amsgrad(true).betas(
693
std::make_tuple(0.999, 0.1)));
696
auto model1 = Linear(5, 2);
697
auto model1_params = model1->parameters();
700
model1_params.emplace_back(torch::randn({2, 3}));
701
auto optim1 = torch::optim::AdamW(
702
model1_params, torch::optim::AdamWOptions().weight_decay(0.5));
704
auto x = torch::ones({10, 5});
705
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
706
optimizer.zero_grad();
707
auto y = model->forward(x).sum();
711
step(optim1, model1);
713
std::vector<int64_t> step_buffers;
714
std::vector<at::Tensor> exp_average_buffers;
715
std::vector<at::Tensor> exp_average_sq_buffers;
716
std::vector<at::Tensor> max_exp_average_sq_buffers;
717
const auto& params_ = optim1.param_groups()[0].params();
718
const auto& optim1_state = optim1.state();
719
for (const auto i : c10::irange(params_.size())) {
720
if (i != (params_.size() - 1)) {
721
auto key_ = params_[i].unsafeGetTensorImpl();
722
const AdamWParamState& curr_state_ =
723
static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
724
step_buffers.emplace_back(curr_state_.step());
725
exp_average_buffers.emplace_back(curr_state_.exp_avg());
726
exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
727
if (curr_state_.max_exp_avg_sq().defined()) {
728
max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
733
auto optim_tempfile_old_format = c10::make_tempfile();
734
torch::serialize::OutputArchive output_archive;
735
write_step_buffers(output_archive, "step_buffers", step_buffers);
736
write_tensors_to_archive(
737
output_archive, "exp_average_buffers", exp_average_buffers);
738
write_tensors_to_archive(
739
output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
740
write_tensors_to_archive(
741
output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
742
output_archive.save_to(optim_tempfile_old_format.name);
743
auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
744
OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
745
torch::load, optim1_2, optim_tempfile_old_format.name);
746
is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
749
// RMSprop: generic harness plus legacy-format compatibility. Collects
// square_avg (always), momentum_buffer and grad_avg (when defined), writes
// the old layout, reloads with the expected warning; because the old format
// has no step counter, the loop near the end copies step() from the original
// state into the reloaded state before comparing.
// NOTE(review): garbled chunk — lambda backward()/step() lines and closing
// braces are missing from this extraction.
TEST(SerializeTest, Optim_RMSprop) {
750
auto options = RMSpropOptions(0.1).momentum(0.9).centered(true);
751
test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options);
754
auto model1 = Linear(5, 2);
755
auto model1_params = model1->parameters();
759
model1_params.emplace_back(torch::randn({2, 3}));
760
auto optim1 = torch::optim::RMSprop(model1_params, options);
762
auto x = torch::ones({10, 5});
763
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
764
optimizer.zero_grad();
765
auto y = model->forward(x).sum();
769
step(optim1, model1);
771
std::vector<at::Tensor> square_average_buffers;
772
std::vector<at::Tensor> momentum_buffers;
773
std::vector<at::Tensor> grad_average_buffers;
774
const auto& params_ = optim1.param_groups()[0].params();
775
const auto& optim1_state = optim1.state();
776
for (const auto i : c10::irange(params_.size())) {
777
if (i != (params_.size() - 1)) {
778
auto key_ = params_[i].unsafeGetTensorImpl();
779
const RMSpropParamState& curr_state_ =
780
static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
781
square_average_buffers.emplace_back(curr_state_.square_avg());
782
if (curr_state_.momentum_buffer().defined()) {
783
momentum_buffers.emplace_back(curr_state_.momentum_buffer());
785
if (curr_state_.grad_avg().defined()) {
786
grad_average_buffers.emplace_back(curr_state_.grad_avg());
791
auto optim_tempfile_old_format = c10::make_tempfile();
792
torch::serialize::OutputArchive output_archive;
793
write_tensors_to_archive(
794
output_archive, "square_average_buffers", square_average_buffers);
795
write_tensors_to_archive(
796
output_archive, "momentum_buffers", momentum_buffers);
797
write_tensors_to_archive(
798
output_archive, "grad_average_buffers", grad_average_buffers);
799
output_archive.save_to(optim_tempfile_old_format.name);
800
auto optim1_2 = RMSprop(model1_params, options);
801
OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
802
torch::load, optim1_2, optim_tempfile_old_format.name);
803
const auto& params1_2_ = optim1_2.param_groups()[0].params();
804
auto& optim1_2_state = optim1_2.state();
806
for (const auto i : c10::irange(params1_2_.size())) {
807
if (i != (params1_2_.size() - 1)) {
808
auto key_ = params_[i].unsafeGetTensorImpl();
809
const RMSpropParamState& curr_state_ =
810
static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
811
RMSpropParamState& curr_state1_2_ =
812
static_cast<RMSpropParamState&>(*(optim1_2_state.at(key_).get()));
813
curr_state1_2_.step(curr_state_.step());
816
is_optimizer_state_equal<RMSpropParamState>(optim1.state(), optim1_2.state());
819
// LBFGS: generic harness (with only_has_global_state = true, since LBFGS
// keeps one global state entry) plus legacy-format compatibility. Dumps the
// global state tensors (d, t, H_diag, prev_flat_grad, prev_loss, old_dirs,
// old_stps) in the old layout, reloads with the expected warning, then
// manually copies the fields the old format never stored (func_evals,
// n_iter, ro, al) before comparing states.
// NOTE(review): garbled chunk — the `auto optim1 =` declaration, lambda
// backward() line, old_stps assignment, and closing braces are missing.
TEST(SerializeTest, Optim_LBFGS) {
820
test_serialize_optimizer<LBFGS, LBFGSOptions, LBFGSParamState>(
821
LBFGSOptions(), true);
823
auto model1 = Linear(5, 2);
824
auto model1_params = model1->parameters();
827
model1_params.emplace_back(torch::randn({2, 3}));
829
torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions());
831
auto x = torch::ones({10, 5});
832
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
833
optimizer.zero_grad();
834
auto y = model->forward(x).sum();
836
auto closure = []() { return torch::tensor({10}); };
837
optimizer.step(closure);
840
step(optim1, model1);
842
at::Tensor d, t, H_diag, prev_flat_grad, prev_loss;
843
std::deque<at::Tensor> old_dirs, old_stps;
845
const auto& params_ = optim1.param_groups()[0].params();
846
auto key_ = params_[0].unsafeGetTensorImpl();
847
const auto& optim1_state =
848
static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));
849
d = optim1_state.d();
850
t = at::tensor(optim1_state.t());
851
H_diag = optim1_state.H_diag();
852
prev_flat_grad = optim1_state.prev_flat_grad();
853
prev_loss = at::tensor(optim1_state.prev_loss());
854
old_dirs = optim1_state.old_dirs();
857
auto optim_tempfile_old_format = c10::make_tempfile();
858
torch::serialize::OutputArchive output_archive;
859
output_archive.write("d", d, true);
860
output_archive.write("t", t, true);
861
output_archive.write("H_diag", H_diag, true);
862
output_archive.write("prev_flat_grad", prev_flat_grad, true);
863
output_archive.write("prev_loss", prev_loss, true);
864
write_tensors_to_archive(output_archive, "old_dirs", old_dirs);
865
write_tensors_to_archive(output_archive, "old_stps", old_stps);
866
output_archive.save_to(optim_tempfile_old_format.name);
868
auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions());
869
OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
870
torch::load, optim1_2, optim_tempfile_old_format.name);
872
const auto& params1_2_ = optim1_2.param_groups()[0].params();
873
auto param_key = params1_2_[0].unsafeGetTensorImpl();
874
auto& optim1_2_state =
875
static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));
878
optim1_2_state.func_evals(optim1_state.func_evals());
879
optim1_2_state.n_iter(optim1_state.n_iter());
880
optim1_2_state.ro(optim1_state.ro());
881
optim1_2_state.al(optim1_state.al());
883
is_optimizer_state_equal<LBFGSParamState>(optim1.state(), optim1_2.state());
886
// CUDA variant of the XOR test: trains on CPU, verifies a loaded copy on
// CPU, moves it to CUDA and verifies again, then saves the CUDA-resident
// model and loads it into a third (CPU) model, checking loss stays low —
// i.e. device placement survives a save/load round trip.
// NOTE(review): garbled chunk — the getLoss batch_size parameter line, the
// is_cuda guard braces, epoch counter, backward()/step() lines, and closing
// braces are missing from this extraction.
TEST(SerializeTest, XOR_CUDA) {
887
torch::manual_seed(0);
889
auto getLoss = [](Sequential model,
891
bool is_cuda = false) {
892
auto inputs = torch::empty({batch_size, 2});
893
auto labels = torch::empty({batch_size});
895
inputs = inputs.cuda();
896
labels = labels.cuda();
898
for (const auto i : c10::irange(batch_size)) {
899
inputs[i] = torch::randint(2, {2}, torch::kInt64);
900
labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
902
auto x = model->forward<torch::Tensor>(inputs);
903
return torch::binary_cross_entropy(x, labels);
906
auto model = xor_model();
907
auto model2 = xor_model();
908
auto model3 = xor_model();
909
auto optimizer = torch::optim::SGD(
911
torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
914
float running_loss = 1;
916
while (running_loss > 0.1) {
917
torch::Tensor loss = getLoss(model, 4);
918
optimizer.zero_grad();
923
running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
924
ASSERT_LT(epoch, 3000);
928
auto tempfile = c10::make_tempfile();
929
torch::save(model, tempfile.name);
930
torch::load(model2, tempfile.name);
932
auto loss = getLoss(model2, 100);
933
ASSERT_LT(loss.item<float>(), 0.1);
935
model2->to(torch::kCUDA);
936
loss = getLoss(model2, 100, true);
937
ASSERT_LT(loss.item<float>(), 0.1);
939
auto tempfile2 = c10::make_tempfile();
940
torch::save(model2, tempfile2.name);
941
torch::load(model3, tempfile2.name);
943
loss = getLoss(model3, 100, true);
944
ASSERT_LT(loss.item<float>(), 0.1);
949
// Verifies that a module hierarchy M -> A -> {B (empty), C (one buffer)}
// serializes and deserializes even though intermediate modules A and B have
// no parameters or buffers of their own: after the round trip, the buffer
// "a.c.foo" (five int32 ones) must sum to 5.
// NOTE(review): garbled chunk — the leading `TEST(SerializeTest,` line,
// struct constructors/closing braces, and the torch::load(in, ss) call are
// missing from this extraction.
CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
950
struct C : torch::nn::Module {
952
register_buffer("foo", torch::ones(5, torch::kInt32));
955
struct B : torch::nn::Module {};
956
struct A : torch::nn::Module {
958
register_module("b", std::make_shared<B>());
959
register_module("c", std::make_shared<C>());
962
struct M : torch::nn::Module {
964
register_module("a", std::make_shared<A>());
968
auto out = std::make_shared<M>();
969
std::stringstream ss;
970
torch::save(out, ss);
971
auto in = std::make_shared<M>();
974
const int output = in->named_buffers()["a.c.foo"].sum().item<int>();
975
ASSERT_EQ(output, 5);
978
// Round-trips a std::vector of two differently-shaped tensors through a
// stringstream and checks each element survives (defined, same shape,
// allclose). NOTE(review): garbled chunk — the per-element x/y bindings
// inside the loop and closing braces are missing from this extraction.
TEST(SerializeTest, VectorOfTensors) {
979
torch::manual_seed(0);
981
std::vector<torch::Tensor> x_vec = {
982
torch::randn({1, 2}), torch::randn({3, 4})};
984
std::stringstream stream;
985
torch::save(x_vec, stream);
987
std::vector<torch::Tensor> y_vec;
988
torch::load(y_vec, stream);
990
for (const auto i : c10::irange(x_vec.size())) {
993
ASSERT_TRUE(y.defined());
994
ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
995
ASSERT_TRUE(x.allclose(y));
999
// Round-trips a single c10::IValue through OutputArchive/InputArchive and a
// tempfile, checks the value reads back, and asserts that reading a missing
// key throws with a "does not have a field with name" message.
// NOTE(review): garbled chunk — the ASSERT_THROWS_WITH( wrapper line and
// closing brace are missing from this extraction.
TEST(SerializeTest, IValue) {
1000
c10::IValue ivalue(1);
1001
auto tempfile = c10::make_tempfile();
1002
torch::serialize::OutputArchive output_archive;
1003
output_archive.write("value", ivalue);
1004
output_archive.save_to(tempfile.name);
1006
torch::serialize::InputArchive input_archive;
1007
input_archive.load_from(tempfile.name);
1008
c10::IValue ivalue_out;
1009
input_archive.read("value", ivalue_out);
1010
ASSERT_EQ(ivalue_out.toInt(), 1);
1013
input_archive.read("bad_key", ivalue_out),
1014
"does not have a field with name");
1020
// A Functional submodule (stateless, unserializable) must simply be skipped
// on save: after saving a module that registers only a Functional "relu",
// the archive must not contain a "relu" entry.
// NOTE(review): garbled chunk — the struct A constructor line, closing
// braces, and intermediate lines are missing from this extraction.
TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) {
1021
struct A : torch::nn::Module {
1023
register_module("relu", torch::nn::Functional(torch::relu));
1027
auto out = std::make_shared<A>();
1028
std::stringstream ss;
1029
torch::save(out, ss);
1031
torch::serialize::InputArchive archive;
1032
archive.load_from(ss);
1033
torch::serialize::InputArchive relu_archive;
1038
ASSERT_FALSE(archive.try_read("relu", relu_archive));
1044
// Counterpart to the save-side test: Functional submodules ("relu1",
// "relu2") are absent from the saved archive, but real state ("b.foo",
// filled with 1s) is present; loading into a fresh module must ignore the
// missing Functional entries and still restore the buffer (sum == 5).
// NOTE(review): garbled chunk — struct constructors/closing braces are
// missing from this extraction.
TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) {
1045
struct B : torch::nn::Module {
1047
register_module("relu1", torch::nn::Functional(torch::relu));
1048
register_buffer("foo", torch::zeros(5, torch::kInt32));
1051
struct A : torch::nn::Module {
1053
register_module("b", std::make_shared<B>());
1054
register_module("relu2", torch::nn::Functional(torch::relu));
1058
auto out = std::make_shared<A>();
1061
out->named_buffers()["b.foo"].fill_(1);
1062
auto tempfile = c10::make_tempfile();
1063
torch::save(out, tempfile.name);
1065
torch::serialize::InputArchive archive;
1066
archive.load_from(tempfile.name);
1067
torch::serialize::InputArchive archive_b;
1068
torch::serialize::InputArchive archive_relu;
1069
torch::Tensor tensor_foo;
1071
ASSERT_TRUE(archive.try_read("b", archive_b));
1072
ASSERT_TRUE(archive_b.try_read("foo", tensor_foo, true));
1076
ASSERT_FALSE(archive_b.try_read("relu1", archive_relu));
1080
ASSERT_FALSE(archive.try_read("relu2", archive_relu));
1082
auto in = std::make_shared<A>();
1087
torch::load(in, tempfile.name);
1090
const int output = in->named_buffers()["b.foo"].sum().item<int>();
1093
ASSERT_EQ(output, 5);