#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;
11
struct AGIUnit : torch::nn::Module {};
14
struct AGIUnit : torch::nn::Module {};
15
struct AGIUnit2 : torch::nn::Module {
16
AGIUnit2() : torch::nn::Module("Foo") {}
20
// Test fixture that seeds the RNG so every test below is deterministic.
struct ModuleTest : torch::test::SeedingFixture {};
// Verifies that train()/eval() toggle is_training() as expected.
TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
  Linear module(3, 4);
  ASSERT_TRUE(module->is_training());

  module->eval();
  ASSERT_FALSE(module->is_training());

  module->train();
  ASSERT_TRUE(module->is_training());
}
// Verifies that zero_grad() clears all parameter gradients after a
// backward pass (default behavior sets grads back to undefined).
TEST_F(ModuleTest, ZeroGrad) {
  Linear module(3, 4);
  auto weight = torch::ones({8, 3}, torch::requires_grad());
  auto loss = module(weight).sum();
  loss.backward();
  for (auto& parameter : module->parameters()) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto grad = parameter.grad();
    ASSERT_TRUE(grad.defined());
    ASSERT_NE(grad.sum().item<float>(), 0);
  }
  module->zero_grad();
  for (auto& parameter : module->parameters()) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto grad = parameter.grad();
    // Default zero_grad() sets grads to undefined (set_to_none = true).
    ASSERT_FALSE(grad.defined());
  }
}
// Verifies zero_grad() behavior on a mix of defined and undefined grads:
// with set_to_none=false the defined grad is zeroed in place; with the
// default it is reset to undefined. An untouched parameter's grad stays
// undefined throughout.
TEST_F(ModuleTest, ZeroGradWithUndefined) {
  struct TestModule : torch::nn::Module {
    TestModule() {
      x = register_parameter("x", torch::ones(5, torch::requires_grad()));
      y = register_parameter("y", torch::ones(5, torch::requires_grad()));
    }
    torch::Tensor x, y;
  };

  TestModule module;
  // Only `x` participates in the graph, so only `x` receives a grad.
  auto z = module.x * 2;
  z.sum().backward();

  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  module.zero_grad(false); // set_to_none = false

  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  ASSERT_EQ(module.x.grad().sum().item<float>(), 0);

  module.zero_grad();

  ASSERT_FALSE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());
}
// Submodule names may not be empty or contain dots (dots delimit paths
// in named_parameters()/named_modules()).
TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("name.with.dot", torch::nn::Linear(3, 4)),
      "Submodule name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("", torch::nn::Linear(3, 4)),
      "Submodule name must not be empty");
}
// Registering the same submodule name twice is an error.
TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  model.register_module("linear", torch::nn::Linear(3, 4));
  ASSERT_THROWS_WITH(
      model.register_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' already defined");
}
// replace_module() requires the name to already be registered.
TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) {
  torch::nn::Module model;
  ASSERT_THROWS_WITH(
      model.replace_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' is not defined");
}
// Verifies that replace_module() swaps the registered submodule and that
// the returned holder can be stored back into the member.
TEST_F(ModuleTest, ReplaceModule) {
  struct TestModel : public torch::nn::Module {
    torch::nn::Linear l1{nullptr};
    TestModel() {
      l1 = register_module("l1", torch::nn::Linear(3, 4));
    }
  };
  auto model = std::make_shared<TestModel>();
  model->l1 = model->replace_module("l1", torch::nn::Linear(5, 6));
  // The replacement (5 -> 6) is visible through named_parameters().
  ASSERT_EQ(model->named_parameters()["l1.weight"].size(0), 6);
  ASSERT_EQ(model->l1.get(), model->named_modules()["l1"]->as<Linear>());
}
// Verifies unregister_module(): unknown names throw; a registered module
// can be removed, leaving no children.
TEST_F(ModuleTest, UnregisterModule) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  ASSERT_THROWS_WITH(
      model.unregister_module("linear"),
      "No Module with name `linear` is registered");
  model.register_module("linear", torch::nn::Linear(3, 4));
  model.unregister_module("linear");
  ASSERT_TRUE(model.children().empty());
}
// Parameter names may not be empty or contain dots.
TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("name.with.dot", torch::ones(5)),
      "Parameter name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("", torch::ones(5)),
      "Parameter name must not be empty");
}
// Registering the same parameter name twice is an error.
TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  model.register_parameter("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_parameter("p", torch::ones(5)),
      "Parameter 'p' already defined");
}
// Registering an undefined tensor as a parameter is allowed but the
// parameter is skipped by parameters(); passing the (default)
// requires_grad=true with an undefined tensor emits a warning.
TEST_F(ModuleTest, RegisterParameterUndefinedTensor) {
  struct TestModel : public torch::nn::Module {};
  {
    TestModel model;
    model.register_parameter(
        "undefined_tensor", torch::Tensor(), /*requires_grad=*/false);
    ASSERT_EQ(model.parameters().size(), 0);
  }
  {
    WarningCapture warnings;

    TestModel model;
    model.register_parameter("undefined_tensor", torch::Tensor());
    ASSERT_EQ(model.parameters().size(), 0);

    ASSERT_EQ(
        count_substr_occurrences(
            warnings.str(),
            "Ignoring the `requires_grad=true` function parameter"),
        1);
  }
}
// Buffer names may not be empty or contain dots.
TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("name.with.dot", torch::ones(5)),
      "Buffer name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("", torch::ones(5)),
      "Buffer name must not be empty");
}
// Registering the same buffer name twice is an error.
TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  model.register_buffer("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_buffer("p", torch::ones(5)), "Buffer 'p' already defined");
}
// Verifies Module::name(): demangled type name by default, qualified
// names for namespaced types, and an explicit override via the base-class
// constructor ("Foo").
TEST_F(ModuleTest, CanGetName) {
  // CHECK instead of REQUIRE because demangling may fail.
  AGIUnit agi;
  // Call it twice just to make sure there are no bugs in the lazy
  // initialization semantics.
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit");
  EXPECT_EQ(test::AGIUnit2().name(), "Foo");
}
// Verifies Module::as<T>() through a holder, a shared_ptr<Module>, and a
// Module&: the cast succeeds for the concrete type (and bases) and yields
// nullptr for unrelated types.
TEST_F(ModuleTest, AsCastsModulesCorrectly) {
  Linear module(3, 4);
  ASSERT_EQ(module->as<Linear>(), module.get());
  ASSERT_EQ(module->as<LinearImpl>(), module.get());
  ASSERT_EQ(module->as<Module>(), module.get());
  ASSERT_EQ(module->as<AGIUnit>(), nullptr);

  std::shared_ptr<Module> raw = module.ptr();
  ASSERT_EQ(raw->as<Linear>(), module.get());
  ASSERT_EQ(raw->as<LinearImpl>(), module.get());
  ASSERT_EQ(raw->as<Module>(), module.get());
  ASSERT_EQ(raw->as<AGIUnit>(), nullptr);

  Module& raw_ref = *raw.get();
  ASSERT_EQ(raw_ref.as<Linear>(), module.get());
  ASSERT_EQ(raw_ref.as<LinearImpl>(), module.get());
  ASSERT_EQ(raw_ref.as<Module>(), module.get());
  ASSERT_EQ(raw_ref.as<AGIUnit>(), nullptr);
  if (auto* linear = raw_ref.as<Linear>()) {
    ASSERT_EQ(linear->weight.ndimension(), 2);
  }

  AGIUnit unit;
  ASSERT_EQ(unit.as<Linear>(), nullptr);
  ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
  ASSERT_EQ(unit.as<AGIUnit>(), &unit);
}
void test_DeviceOrDtypeConversionSkipsUndefinedTensor(
231
torch::Device to_device,
232
torch::Dtype to_dtype) {
234
// Case 1: Undefined tensors as parameters
235
Linear module(LinearOptions(10, 20).bias(false));
236
ASSERT_TRUE(module->weight.defined());
237
ASSERT_FALSE(module->bias.defined());
239
module->to(to_device);
240
ASSERT_TRUE(module->weight.defined());
241
ASSERT_EQ(module->weight.device().type(), to_device.type());
242
ASSERT_FALSE(module->bias.defined());
244
module->to(to_dtype);
245
ASSERT_TRUE(module->weight.defined());
246
ASSERT_EQ(module->weight.dtype(), to_dtype);
247
ASSERT_FALSE(module->bias.defined());
250
// Case 2: Undefined tensors as buffers
252
BatchNorm1dOptions(5).track_running_stats(false).affine(true));
253
ASSERT_TRUE(module->weight.defined());
254
ASSERT_FALSE(module->running_mean.defined());
256
module->to(to_device);
257
ASSERT_TRUE(module->weight.defined());
258
ASSERT_EQ(module->weight.device().type(), to_device.type());
259
ASSERT_FALSE(module->running_mean.defined());
261
module->to(to_dtype);
262
ASSERT_TRUE(module->weight.defined());
263
ASSERT_EQ(module->weight.dtype(), to_dtype);
264
ASSERT_FALSE(module->running_mean.defined());
268
// CPU driver for the conversion-skips-undefined helper above.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(torch::kCPU, torch::kDouble);
}
// CUDA driver for the conversion-skips-undefined helper above.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(
      torch::kCUDA, torch::kDouble);
}
// parameters()/named_parameters() and buffers()/named_buffers() must skip
// undefined tensors, and both accessors must alias the same storage.
TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) {
  {
    // bias(false) leaves `bias` undefined, so only `weight` is reported.
    Linear module(LinearOptions(10, 20).bias(false));

    auto params = module->parameters();
    ASSERT_EQ(params.size(), 1);
    auto named_params = module->named_parameters();
    ASSERT_EQ(named_params.size(), 1);

    ASSERT_TRUE(pointer_equal(params[0], named_params["weight"]));
    ASSERT_TRUE(pointer_equal(named_params["weight"], module->weight));
  }
  {
    // No running stats and no affine params -> no buffers at all.
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(false).affine(false));

    auto buffers = module->buffers();
    ASSERT_EQ(buffers.size(), 0);
    auto named_buffers = module->named_buffers();
    ASSERT_EQ(named_buffers.size(), 0);
  }
  {
    // Running stats enabled -> mean, var, and num_batches_tracked buffers.
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(true).affine(false));

    auto buffers = module->buffers();
    ASSERT_EQ(buffers.size(), 3);
    auto named_buffers = module->named_buffers();
    ASSERT_EQ(named_buffers.size(), 3);

    ASSERT_TRUE(pointer_equal(buffers[0], named_buffers["running_mean"]));
    ASSERT_TRUE(
        pointer_equal(named_buffers["running_mean"], module->running_mean));
    ASSERT_TRUE(pointer_equal(buffers[1], named_buffers["running_var"]));
    ASSERT_TRUE(
        pointer_equal(named_buffers["running_var"], module->running_var));
    ASSERT_TRUE(
        pointer_equal(buffers[2], named_buffers["num_batches_tracked"]));
    ASSERT_TRUE(pointer_equal(
        named_buffers["num_batches_tracked"], module->num_batches_tracked));
  }
}
// Verifies Module::to() moves all parameters across devices (CPU -> two
// CUDA devices -> CPU) and converts dtypes.
TEST_F(ModuleTest, Conversion_MultiCUDA) {
  Linear module(128, 64);
  for (auto& parameter : module->parameters()) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCPU));
    ASSERT_EQ(parameter.dtype(), torch::kFloat32);
  }
  {
    module->to({torch::kCUDA, 0});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 0);
    }
    module->to({torch::kCUDA, 1});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 1);
    }
  }
  {
    module->to(torch::Device(torch::kCPU));
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CPU);
    }
  }
  {
    module->to(torch::kFloat64);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.dtype(), torch::kFloat64);
    }
  }
}
// Verifies that conversions also work for parameters that do not require
// grad (including integral dtypes, which cannot require grad).
TEST_F(ModuleTest, Conversion_NoGrad_MultiCUDA) {
  Linear module(128, 64);
  for (auto& parameter : module->parameters()) {
    parameter.requires_grad_(false);
  }
  {
    module->to(torch::kInt32);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.dtype(), torch::kInt32);
    }
  }
  {
    // Combined device + dtype conversion in one call.
    module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 1);
    }
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.dtype(), torch::kUInt8);
    }
  }
}
// clone() on a module that does not override it must throw.
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
  struct UnCloneable : Module {};
  UnCloneable module;
  ASSERT_THROWS_WITH(module.clone(), "clone() has not been implemented");
}
// clone() on a module that overrides it must not throw.
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  struct Cloneable : Module {
    std::shared_ptr<Module> clone(
        const torch::optional<torch::Device>& device =
            torch::nullopt) const override {
      return nullptr;
    }
  };
  Cloneable module;
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW({ module.clone(); });
}
// NOLINTNEXTLINE(bugprone-exception-escape)
395
struct TestDistinctParametersModule
396
: public Cloneable<TestDistinctParametersModule> {
397
TestDistinctParametersModule() {
398
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
401
void reset() override {
402
l1 = register_module("l1", Linear(10, 3));
403
l2 = register_module("l2", Linear(3, 5));
404
l3 = register_module("l3", Linear(5, 100));
405
buffer = register_buffer("buf", torch::ones({2, 2}));
408
Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
409
torch::Tensor buffer;
412
void testDistinctParameters(
413
std::shared_ptr<Module> m1,
414
std::shared_ptr<Module> m2) {
415
auto params1 = m1->named_parameters();
416
auto params2 = m2->named_parameters();
417
ASSERT_EQ(params1.size(), 6);
418
ASSERT_EQ(params2.size(), 6);
419
for (auto& param : params1) {
420
ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
421
ASSERT_TRUE(param->allclose(params2[param.key()]));
424
for (auto& param : params1) {
425
ASSERT_FALSE(param->allclose(params2[param.key()]));
428
auto buffers1 = m1->named_buffers();
429
auto buffers2 = m2->named_buffers();
430
ASSERT_EQ(buffers1.size(), 1);
431
ASSERT_EQ(buffers2.size(), 1);
432
for (auto& buffer : buffers1) {
433
ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()]));
434
ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()]));
437
for (auto& buffer : buffers1) {
438
ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()]));
442
// clone() must deep-copy parameters and buffers (CPU).
TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  auto module2 = module->clone();
  testDistinctParameters(module, module2);
}
// clone(device) must deep-copy parameters on a CUDA device.
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_CUDA) {
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  torch::Device device(torch::kCUDA, 0);
  module->to(device);
  auto module2 = module->clone(device);
  testDistinctParameters(module, module2);
}
// clone(d1) from a module on d0 must leave the original on d0 and place
// the clone on d1, while still producing distinct parameters.
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_MultiCUDA) {
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  torch::Device d0(torch::kCUDA, 0);
  torch::Device d1(torch::kCUDA, 1);
  module->to(d0);
  auto module2 = module->clone(d1);

  for (auto& param : module->parameters()) {
    ASSERT_EQ(param.device(), d0);
  }

  for (auto& param : module2->parameters()) {
    ASSERT_EQ(param.device(), d1);
  }

  // need to move the module back to d0 as allclose expects two tensors on
  // the same device.
  module2->to(d0);
  testDistinctParameters(module, module2);
}
// After clone(), the clone's member tensors must alias the clone's own
// registered parameters (not the original's), while values are copied.
TEST_F(ModuleTest, ClonePreservesExternalReferences) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
  };
  auto module = std::make_shared<TestModule>();
  {
    torch::NoGradGuard no_grad;
    module->weight += 1;
  }
  ASSERT_TRUE(
      pointer_equal(module->weight, module->named_parameters()["weight"]));
  ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight"]));

  auto module2 = std::dynamic_pointer_cast<TestModule>(
      std::shared_ptr<Module>(module->clone()));
  ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
  ASSERT_TRUE(
      pointer_equal(module2->weight, module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module->weight));
  ASSERT_FALSE(
      pointer_equal(module2->weight, module->named_parameters()["weight"]));
}
// clone() must recurse into submodules, deep-copying their parameters and
// copying plain (non-tensor) member values.
TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }

    torch::Tensor weight;
    int value = 0;
  };
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct NestedModule : public Cloneable<NestedModule> {
    NestedModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      module = register_module("module", std::make_shared<TestModule>());
    }
    std::shared_ptr<TestModule> module;
  };

  auto a = std::make_shared<NestedModule>();
  {
    torch::NoGradGuard no_grad;
    a->module->weight += 1;
    a->module->value = 123;
  }

  auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());

  ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
  ASSERT_TRUE(pointer_equal(
      b->module->weight, b->module->named_parameters()["weight"]));
  ASSERT_TRUE(
      b->module->named_parameters()["weight"].allclose(a->module->weight));
  ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
  ASSERT_EQ(b->module->value, a->module->value);
}
// Cloning without an explicit device must keep parameters and buffers on
// the device the original already lives on.
TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }

    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 0);

  m.to(device);

  auto clone = m.clone();
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
593
// NOLINTNEXTLINE(bugprone-exception-escape)
594
struct TestModule : public Cloneable<TestModule> {
596
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
599
void reset() override {
600
l1 = register_module("l1", Linear(10, 3));
601
l2 = register_module("l2", Linear(3, 5));
602
l3 = register_module("l3", Linear(5, 100));
603
buffer = register_buffer("buf", torch::ones({2, 2}));
606
Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
607
torch::Tensor buffer;
611
torch::Device device(torch::kCUDA, 1);
612
// everything is on CPU here
613
auto clone = m.clone(device);
614
for (const auto& parameter : clone->parameters()) {
615
ASSERT_EQ(parameter.device().type(), device.type());
616
ASSERT_EQ(parameter.device().index(), device.index());
618
for (const auto& buffer : clone->buffers()) {
619
ASSERT_EQ(buffer.device().type(), device.type());
620
ASSERT_EQ(buffer.device().index(), device.index());
624
struct ParameterTestModule : Module {
625
ParameterTestModule() {
626
a = register_parameter("a", torch::zeros({2, 2}));
627
b = register_parameter("b", torch::ones({2, 2}));
628
c = register_parameter("c", torch::ones({2, 2}) * 2);
631
torch::Tensor a, b, c;
634
// parameters() and named_parameters() report all three parameters.
TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
  ParameterTestModule module;
  ASSERT_EQ(module.parameters().size(), 3);
  ASSERT_EQ(module.named_parameters().size(), 3);
}
// named_parameters() is keyed by the names used at registration.
TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
  ParameterTestModule module;
  auto parameters = module.named_parameters();
  ASSERT_TRUE(parameters.contains("a"));
  ASSERT_TRUE(parameters.contains("b"));
  ASSERT_TRUE(parameters.contains("c"));
}
struct BufferTestModule : Module {
650
a = register_buffer("a", torch::zeros({2, 2}));
651
b = register_buffer("b", torch::ones({2, 2}));
652
c = register_buffer("c", torch::ones({2, 2}) * 2);
655
torch::Tensor a, b, c;
658
// buffers() and named_buffers() report all three buffers.
TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
  BufferTestModule module;
  ASSERT_EQ(module.buffers().size(), 3);
  ASSERT_EQ(module.named_buffers().size(), 3);
}
// named_buffers() is keyed by the names used at registration.
TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
  BufferTestModule module;
  auto buffers = module.named_buffers();
  ASSERT_TRUE(buffers.contains("a"));
  ASSERT_TRUE(buffers.contains("b"));
  ASSERT_TRUE(buffers.contains("c"));
}
struct AImpl : torch::nn::Module {
674
AImpl(int x) : x_(x) {}
681
DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
684
ASSERT_FALSE(a.is_empty());
685
ASSERT_EQ(a->x_, 123);
690
ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
693
ASSERT_FALSE(a.is_empty());
697
// A nullptr-constructed holder is empty; dereferencing it throws.
TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  A a = nullptr;
  ASSERT_FALSE(a);
  ASSERT_TRUE(a.is_empty());
  ASSERT_THROWS_WITH(a->x_, "Accessing empty ModuleHolder");
}
struct TestModule : public torch::nn::Module {
705
TestModule(int64_t size) {
706
p1 = register_parameter("p1", torch::randn({size}));
707
p2 = register_parameter("p2", torch::randn({size}));
708
b1 = register_buffer("b1", torch::randn({size}));
709
b2 = register_buffer("b2", torch::randn({size}));
712
torch::Tensor forward(torch::Tensor input) {
716
torch::Tensor p1, p2, b1, b2;
719
// modules() on a flat Sequential returns self followed by each child.
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}
// modules(include_self=false) returns only the children.
TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules =
      model->modules(/*include_self=*/false);
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}
// named_modules() keys: empty string for self, then "0", "1", "2".
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].key(), i ? std::to_string(i - 1) : std::string());
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
// named_modules(include_self=false) returns only the children, keyed
// "0", "1", "2".
TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules(
          /*name_prefix=*/std::string(), /*include_self=*/false);
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
// children() on a flat model equals modules(include_self=false).
TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }

  // For this flat model, this should be true.
  ASSERT_EQ(modules, model->modules(/*include_self=*/false));
}
// named_children() returns the direct children keyed "0", "1", "2".
TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
// parameters() returns tensors aliasing the registered members.
TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  std::vector<torch::Tensor> parameters = module.parameters();
  ASSERT_EQ(parameters.size(), 2);
  ASSERT_EQ(parameters[0].data_ptr<float>(), module.p1.data_ptr<float>());
  ASSERT_EQ(parameters[1].data_ptr<float>(), module.p2.data_ptr<float>());
}
// named_parameters() preserves registration order and aliases members.
TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> parameters =
      module.named_parameters();
  ASSERT_EQ(parameters.size(), 2);
  ASSERT_EQ(parameters[0].key(), "p1");
  ASSERT_EQ(parameters[0]->data_ptr<float>(), module.p1.data_ptr<float>());
  ASSERT_EQ(parameters[1].key(), "p2");
  ASSERT_EQ(parameters[1]->data_ptr<float>(), module.p2.data_ptr<float>());
}
// buffers() returns tensors aliasing the registered members.
TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  std::vector<torch::Tensor> buffers = module.buffers();
  ASSERT_EQ(buffers.size(), 2);
  ASSERT_EQ(buffers[0].data_ptr<float>(), module.b1.data_ptr<float>());
  ASSERT_EQ(buffers[1].data_ptr<float>(), module.b2.data_ptr<float>());
}
// named_buffers() preserves registration order and aliases members.
TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> buffers =
      module.named_buffers();
  ASSERT_EQ(buffers.size(), 2);
  ASSERT_EQ(buffers[0].key(), "b1");
  ASSERT_EQ(buffers[0]->data_ptr<float>(), module.b1.data_ptr<float>());
  ASSERT_EQ(buffers[1].key(), "b2");
  ASSERT_EQ(buffers[1]->data_ptr<float>(), module.b2.data_ptr<float>());
}
struct TestContainer : torch::nn::Module {
841
TestContainer(int64_t number, std::vector<TestContainer> modules = {})
842
: tensor(torch::tensor(number)) {
843
for (const auto i : c10::irange(modules.size())) {
846
std::make_shared<TestContainer>(std::move(modules[i])));
849
torch::Tensor tensor;
852
// Extracts the identifying number stored in a TestContainer's tensor.
int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) {
  return std::dynamic_pointer_cast<TestContainer>(module)
      ->tensor.item<int64_t>();
}
// Builds the deeply nested container whose pre-order traversal yields the
// numbers 0..9 (matching make_key_value_pairs_for_deeply_nested_container).
std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
  return std::make_shared<TestContainer>(TestContainer(
      0,
      {TestContainer(1, {TestContainer(2), TestContainer(3)}),
       TestContainer(4),
       TestContainer(
           5,
           {TestContainer(6),
            TestContainer(7, {TestContainer(8), TestContainer(9)})})}));
}
// Expected (name, number) pairs for a pre-order traversal of the deeply
// nested container with name_prefix "test_prefix".
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  return {
      {"test_prefix", 0},
      {"test_prefix.0", 1},
      {"test_prefix.0.0", 2},
      {"test_prefix.0.1", 3},
      {"test_prefix.1", 4},
      {"test_prefix.2", 5},
      {"test_prefix.2.0", 6},
      {"test_prefix.2.1", 7},
      {"test_prefix.2.1.0", 8},
      {"test_prefix.2.1.1", 9}};
}
// modules() traverses the deep model in pre-order: numbers 0..9.
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();

  ASSERT_EQ(modules.size(), 10);
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(get_test_container_item(modules[i]), i);
  }
}
// named_modules() with a prefix yields the expected dotted-path keys.
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules(/*name_prefix=*/"test_prefix");
  auto expected = make_key_value_pairs_for_deeply_nested_container();

  ASSERT_EQ(modules.size(), expected.size());

  for (const auto i : c10::irange(expected.size())) {
    ASSERT_EQ(modules[i].key(), expected[i].first);
    ASSERT_EQ(get_test_container_item(modules[i].value()), expected[i].second);
  }
}
// children() returns only the three direct children (1, 4, 5).
TEST_F(ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();

  ASSERT_EQ(modules.size(), 3);
  ASSERT_EQ(get_test_container_item(modules[0]), 1);
  ASSERT_EQ(get_test_container_item(modules[1]), 4);
  ASSERT_EQ(get_test_container_item(modules[2]), 5);
}
// named_children() returns the three direct children keyed "0", "1", "2".
TEST_F(ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();

  ASSERT_EQ(modules.size(), 3);

  ASSERT_EQ(get_test_container_item(modules[0].value()), 1);
  ASSERT_EQ(modules[0].key(), "0");

  ASSERT_EQ(get_test_container_item(modules[1].value()), 4);
  ASSERT_EQ(modules[1].key(), "1");

  ASSERT_EQ(get_test_container_item(modules[2].value()), 5);
  ASSERT_EQ(modules[2].key(), "2");
}
// apply(Module&) visits all 10 containers in pre-order.
TEST_F(ModuleTest, ModuleApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  int64_t index = 0;
  model->apply([&index](torch::nn::Module& module) {
    ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++);
  });
  ASSERT_EQ(index, 10);
}
// The const overload of apply() visits all 10 containers in pre-order.
TEST_F(ModuleTest, ConstModuleApplyIteratesCorreclty) {
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  int64_t index = 0;
  model->apply([&index](const torch::nn::Module& module) {
    ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++);
  });
  ASSERT_EQ(index, 10);
}
// The named apply() overload supplies the dotted-path name along with
// each module, honoring the name_prefix.
TEST_F(ModuleTest, NamedModuleApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, expected](const std::string& name, torch::nn::Module& module) {
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(
            module.as<TestContainer>()->tensor.item<int64_t>(),
            expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}
// Const variant of the named apply() overload.
TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorreclty) {
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, expected](
          const std::string& name, const torch::nn::Module& module) {
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(
            module.as<const TestContainer>()->tensor.item<int64_t>(),
            expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}
// apply() over shared_ptr<Module> visits all 10 containers in pre-order.
TEST_F(ModuleTest, ModulePointerApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  int64_t index = 0;
  model->apply([&index](const std::shared_ptr<torch::nn::Module>& module) {
    ASSERT_EQ(get_test_container_item(module), index++);
  });
  ASSERT_EQ(index, 10);
}
// Named shared_ptr variant of apply(), honoring the name_prefix.
TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, expected](
          const std::string& name,
          const std::shared_ptr<torch::nn::Module>& module) {
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(get_test_container_item(module), expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}
// modules() must throw when the top-level module is stack-allocated (its
// shared_from_this() cannot be formed), unless include_self=false or the
// module is already owned by a shared_ptr.
TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
  {
    TestModule module(1);
    ASSERT_THROWS_WITH(
        module.modules(),
        "It looks like you attempted to retrieve "
        "your top-level module as a shared_ptr")
  }
  {
    TestModule module(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(module.modules(/*include_self=*/false));
  }
  {
    auto module = std::make_shared<TestModule>(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(module->modules());
  }
}
// Module with no members: its default pretty-print is just the type name.
struct EmptyModule : torch::nn::Module {};
// Streaming a module uses pretty_print(): the default prints the type
// name; an override controls the full representation.
TEST_F(ModuleTest, PrettyPrint) {
  struct TestModule : torch::nn::Module {
    TestModule(int x, float y) : x_(x), y_(y) {}

    void pretty_print(std::ostream& stream) const override {
      stream << "TestModule(x=" << x_ << ", y=" << y_ << ")";
    }

    int x_;
    float y_;
  };

  ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
  ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)");
}
struct ModuleWithNonTensorForwardImpl : torch::nn::Module {
1048
int64_t forward(torch::Tensor x) {
1052
TORCH_MODULE(ModuleWithNonTensorForward);
1054
// Calling the holder invokes Impl::forward(); a 123-element tensor has
// numel() == 123.
TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
  ModuleWithNonTensorForward m;
  ASSERT_EQ(m(torch::ones(123)), 123);
}