pytorch

Форк
0
/
module.cpp 
1057 строк · 33.9 Кб
1
#include <gtest/gtest.h>
2

3
#include <c10/util/irange.h>
4
#include <torch/torch.h>
5

6
#include <test/cpp/api/support.h>
7

8
using namespace torch::nn;
9
using namespace torch::test;
10

11
// Minimal module with no overrides; used to test Module::name() demangling
// and failed as<T>() casts below.
struct AGIUnit : torch::nn::Module {};
12

13
namespace test {
// Same class name as the global AGIUnit, but in a namespace — name() should
// report the qualified "test::AGIUnit".
struct AGIUnit : torch::nn::Module {};
// Passes an explicit name to the Module base constructor; name() should
// report "Foo" instead of the demangled type name.
struct AGIUnit2 : torch::nn::Module {
  AGIUnit2() : torch::nn::Module("Foo") {}
};
} // namespace test
19

20
// Test fixture that seeds the RNG before each test (see torch::test::SeedingFixture).
struct ModuleTest : torch::test::SeedingFixture {};
21

22
TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
  // Modules start out in training mode; eval() and train() toggle the flag.
  Linear linear(3, 4);
  ASSERT_TRUE(linear->is_training());

  linear->eval();
  ASSERT_FALSE(linear->is_training());

  linear->train();
  ASSERT_TRUE(linear->is_training());
}
32

33
TEST_F(ModuleTest, ZeroGrad) {
  // A backward pass populates every parameter's gradient; zero_grad()
  // (default set_to_none behavior) leaves the gradients undefined again.
  Linear module(3, 4);
  auto input = torch::ones({8, 3}, torch::requires_grad());
  module(input).sum().backward();

  for (auto& parameter : module->parameters()) {
    const auto& gradient = parameter.grad();
    ASSERT_TRUE(gradient.defined());
    ASSERT_NE(gradient.sum().item<float>(), 0);
  }

  module->zero_grad();
  for (auto& parameter : module->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }
}
51

52
// zero_grad() must cope with parameters whose .grad is still undefined
// (here `y`, which never participated in the backward pass), and must honor
// the set_to_none flag: false zeroes existing grads in place, the default
// (true) resets them to undefined.
TEST_F(ModuleTest, ZeroGradWithUndefined) {
  struct TestModule : torch::nn::Module {
    TestModule() {
      x = register_parameter("x", torch::ones(5, torch::requires_grad()));
      y = register_parameter("y", torch::ones(5, torch::requires_grad()));
    }
    torch::Tensor x, y;
  };

  TestModule module;
  // Only `x` is used, so only x.grad() gets populated.
  auto z = module.x * 2;
  z.sum().backward();

  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  module.zero_grad(false); // set_to_none = false

  // x's grad stays defined (but zeroed); y's stays undefined.
  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  ASSERT_EQ(module.x.grad().sum().item<float>(), 0);

  module.zero_grad();

  ASSERT_FALSE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());
}
80

81
// Submodule names are used as path components in named_modules(), so names
// containing '.' or empty names must be rejected at registration time.
TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("name.with.dot", torch::nn::Linear(3, 4)),
      "Submodule name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("", torch::nn::Linear(3, 4)),
      "Submodule name must not be empty");
}
90

91
// Registering two submodules under the same name must throw rather than
// silently replace the first one (replace_module exists for that).
TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  model.register_module("linear", torch::nn::Linear(3, 4));
  ASSERT_THROWS_WITH(
      model.register_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' already defined");
}
99

100
// replace_module() only replaces; it must not implicitly register a new name.
TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) {
  torch::nn::Module model;
  ASSERT_THROWS_WITH(
      model.replace_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' is not defined");
}
106

107
// replace_module() swaps the stored submodule and returns the new holder;
// the parent's named_parameters()/named_modules() must reflect the new module.
TEST_F(ModuleTest, ReplaceModule) {
  struct TestModel : public torch::nn::Module {
    torch::nn::Linear l1{nullptr};
    TestModel() {
      l1 = register_module("l1", torch::nn::Linear(3, 4));
    }
  };
  auto model = std::make_shared<TestModel>();
  model->l1 = model->replace_module("l1", torch::nn::Linear(5, 6));
  // New Linear(5, 6) means the weight's first dimension is now 6.
  ASSERT_EQ(model->named_parameters()["l1.weight"].size(0), 6);
  ASSERT_EQ(model->l1.get(), model->named_modules()["l1"]->as<Linear>());
}
119

120
// unregister_module() removes a registered submodule by name and throws on
// unknown names.
TEST_F(ModuleTest, UnregisterModule) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  ASSERT_THROWS_WITH(
      model.unregister_module("linear"),
      "No Module with name `linear` is registered");
  model.register_module("linear", torch::nn::Linear(3, 4));
  model.unregister_module("linear");
  ASSERT_TRUE(model.children().empty());
}
130

131
// Parameter names, like submodule names, are path components — dots and
// empty names are rejected.
TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("name.with.dot", torch::ones(5)),
      "Parameter name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("", torch::ones(5)),
      "Parameter name must not be empty");
}
140

141
// Duplicate parameter names must throw instead of overwriting.
TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  model.register_parameter("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_parameter("p", torch::ones(5)),
      "Parameter 'p' already defined");
}
149

150
// Registering an undefined tensor as a parameter is allowed but results in
// no visible parameter; with the default requires_grad=true a warning is
// emitted because the flag cannot apply to an undefined tensor.
TEST_F(ModuleTest, RegisterParameterUndefinedTensor) {
  struct TestModel : public torch::nn::Module {};
  {
    // Explicit requires_grad=false: no warning expected.
    TestModel model;
    model.register_parameter(
        "undefined_tensor", torch::Tensor(), /*requires_grad=*/false);
    ASSERT_EQ(model.parameters().size(), 0);
  }
  {
    WarningCapture warnings;

    // Default requires_grad=true: exactly one warning expected.
    TestModel model;
    model.register_parameter("undefined_tensor", torch::Tensor());
    ASSERT_EQ(model.parameters().size(), 0);

    ASSERT_EQ(
        count_substr_occurrences(
            warnings.str(),
            "Ignoring the `requires_grad=true` function parameter"),
        1);
  }
}
172

173
// Buffer names follow the same naming rules as parameters and submodules.
TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("name.with.dot", torch::ones(5)),
      "Buffer name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("", torch::ones(5)),
      "Buffer name must not be empty");
}
182

183
// Duplicate buffer names must throw instead of overwriting.
TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  model.register_buffer("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_buffer("p", torch::ones(5)), "Buffer 'p' already defined");
}
190

191
// name() is derived lazily from the demangled type name unless an explicit
// name was passed to the Module constructor (AGIUnit2 -> "Foo").
TEST_F(ModuleTest, CanGetName) {
  // CHECK instead of REQUIRE because demangling may fail.
  AGIUnit agi;
  // Call it twice just to make sure there are no bugs in the lazy
  // initialization semantics.
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit");
  EXPECT_EQ(test::AGIUnit2().name(), "Foo");
}
201

202
// as<T>() accepts either the holder type (Linear) or the impl type
// (LinearImpl), returns the impl pointer on success and nullptr on a type
// mismatch — via a module holder, a shared_ptr<Module>, a Module&, and a
// concrete module instance.
TEST_F(ModuleTest, AsCastsModulesCorrectly) {
  Linear module(3, 4);
  ASSERT_EQ(module->as<Linear>(), module.get());
  ASSERT_EQ(module->as<LinearImpl>(), module.get());
  ASSERT_EQ(module->as<Module>(), module.get());
  ASSERT_EQ(module->as<AGIUnit>(), nullptr);

  std::shared_ptr<Module> raw = module.ptr();
  ASSERT_EQ(raw->as<Linear>(), module.get());
  ASSERT_EQ(raw->as<LinearImpl>(), module.get());
  ASSERT_EQ(raw->as<Module>(), module.get());
  ASSERT_EQ(raw->as<AGIUnit>(), nullptr);

  Module& raw_ref = *raw.get();
  ASSERT_EQ(raw_ref.as<Linear>(), module.get());
  ASSERT_EQ(raw_ref.as<LinearImpl>(), module.get());
  ASSERT_EQ(raw_ref.as<Module>(), module.get());
  ASSERT_EQ(raw_ref.as<AGIUnit>(), nullptr);
  // A successful as<> yields a usable impl pointer.
  if (auto* linear = raw_ref.as<Linear>()) {
    ASSERT_EQ(linear->weight.ndimension(), 2);
  }

  AGIUnit unit;
  ASSERT_EQ(unit.as<Linear>(), nullptr);
  ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
  ASSERT_EQ(unit.as<AGIUnit>(), &unit);
}
229

230
// Shared body for the CPU and CUDA variants below: to(device) and to(dtype)
// must convert defined parameters/buffers and leave undefined ones untouched
// (and still undefined) instead of crashing on them.
void test_DeviceOrDtypeConversionSkipsUndefinedTensor(
    torch::Device to_device,
    torch::Dtype to_dtype) {
  {
    // Case 1: Undefined tensors as parameters (Linear without bias keeps an
    // undefined `bias` parameter).
    Linear module(LinearOptions(10, 20).bias(false));
    ASSERT_TRUE(module->weight.defined());
    ASSERT_FALSE(module->bias.defined());

    module->to(to_device);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.device().type(), to_device.type());
    ASSERT_FALSE(module->bias.defined());

    module->to(to_dtype);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.dtype(), to_dtype);
    ASSERT_FALSE(module->bias.defined());
  }
  {
    // Case 2: Undefined tensors as buffers (BatchNorm1d without running
    // stats keeps an undefined `running_mean` buffer).
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(false).affine(true));
    ASSERT_TRUE(module->weight.defined());
    ASSERT_FALSE(module->running_mean.defined());

    module->to(to_device);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.device().type(), to_device.type());
    ASSERT_FALSE(module->running_mean.defined());

    module->to(to_dtype);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.dtype(), to_dtype);
    ASSERT_FALSE(module->running_mean.defined());
  }
}
267

268
// CPU variant of the shared conversion test above.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(torch::kCPU, torch::kDouble);
}
271

272
// CUDA variant (the _CUDA suffix gates it on CUDA availability in this suite).
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(
      torch::kCUDA, torch::kDouble);
}
276

277
// parameters()/buffers() and their named_ variants must omit undefined
// tensors entirely, and the entries they do return must alias the module's
// member tensors (pointer equality, not copies).
TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) {
  {
    // bias(false) leaves `bias` undefined, so only `weight` is reported.
    Linear module(LinearOptions(10, 20).bias(false));

    auto params = module->parameters();
    ASSERT_EQ(params.size(), 1);
    auto named_params = module->named_parameters();
    ASSERT_EQ(named_params.size(), 1);

    ASSERT_TRUE(pointer_equal(params[0], named_params["weight"]));
    ASSERT_TRUE(pointer_equal(named_params["weight"], module->weight));
  }
  {
    // No running stats and no affine params: nothing to report.
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(false).affine(false));

    auto buffers = module->buffers();
    ASSERT_EQ(buffers.size(), 0);
    auto named_buffers = module->named_buffers();
    ASSERT_EQ(named_buffers.size(), 0);
  }
  {
    // Tracking running stats registers three buffers.
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(true).affine(false));

    auto buffers = module->buffers();
    ASSERT_EQ(buffers.size(), 3);
    auto named_buffers = module->named_buffers();
    ASSERT_EQ(named_buffers.size(), 3);

    ASSERT_TRUE(pointer_equal(buffers[0], named_buffers["running_mean"]));
    ASSERT_TRUE(
        pointer_equal(named_buffers["running_mean"], module->running_mean));
    ASSERT_TRUE(pointer_equal(buffers[1], named_buffers["running_var"]));
    ASSERT_TRUE(
        pointer_equal(named_buffers["running_var"], module->running_var));
    ASSERT_TRUE(
        pointer_equal(buffers[2], named_buffers["num_batches_tracked"]));
    ASSERT_TRUE(pointer_equal(
        named_buffers["num_batches_tracked"], module->num_batches_tracked));
  }
}
319

320
// to() must move all parameters between CPU and specific CUDA devices and
// convert dtypes; requires at least two CUDA devices (_MultiCUDA suffix).
TEST_F(ModuleTest, Conversion_MultiCUDA) {
  Linear module(128, 64);
  // Fresh modules start on CPU with float32 parameters.
  for (auto& parameter : module->parameters()) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCPU));
    ASSERT_EQ(parameter.dtype(), torch::kFloat32);
  }
  {
    // CPU -> cuda:0 -> cuda:1.
    module->to({torch::kCUDA, 0});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 0);
    }
    module->to({torch::kCUDA, 1});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 1);
    }
  }
  {
    // And back to CPU.
    module->to(torch::Device(torch::kCPU));
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CPU);
    }
  }
  {
    // Dtype-only conversion.
    module->to(torch::kFloat64);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.dtype(), torch::kFloat64);
    }
  }
}
351

352
TEST_F(ModuleTest, Conversion_NoGrad_MultiCUDA) {
  // to() must also convert parameters that have requires_grad disabled.
  Linear module(128, 64);
  for (auto& parameter : module->parameters()) {
    parameter.requires_grad_(false);
  }

  // Dtype-only conversion.
  module->to(torch::kInt32);
  for (auto& parameter : module->parameters()) {
    ASSERT_EQ(parameter.dtype(), torch::kInt32);
  }

  // Combined device + dtype conversion.
  module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
  for (auto& parameter : module->parameters()) {
    ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(parameter.device().index(), 1);
    ASSERT_EQ(parameter.dtype(), torch::kUInt8);
  }
}
374

375
// The base Module::clone() is not implemented; subclasses must override it
// (e.g. via Cloneable) or calling it throws.
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
  struct UnCloneable : Module {};
  UnCloneable module;
  ASSERT_THROWS_WITH(module.clone(), "clone() has not been implemented");
}
380

381
// A module that provides its own clone() override can be cloned without
// throwing (the override's return value is irrelevant here).
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  struct Cloneable : Module {
    std::shared_ptr<Module> clone(
        const torch::optional<torch::Device>& device =
            torch::nullopt) const override {
      return nullptr;
    }
  };
  Cloneable module;
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW({ module.clone(); });
}
393

394
// Cloneable module with three Linear submodules (6 parameters total) and one
// buffer; used by the clone tests to check that clones get distinct storage.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TestDistinctParametersModule
    : public Cloneable<TestDistinctParametersModule> {
  TestDistinctParametersModule() {
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    reset();
  }
  // Cloneable calls reset() on the clone to rebuild submodules/buffers.
  void reset() override {
    l1 = register_module("l1", Linear(10, 3));
    l2 = register_module("l2", Linear(3, 5));
    l3 = register_module("l3", Linear(5, 100));
    buffer = register_buffer("buf", torch::ones({2, 2}));
  }

  Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
  torch::Tensor buffer;
};
411

412
void testDistinctParameters(
413
    std::shared_ptr<Module> m1,
414
    std::shared_ptr<Module> m2) {
415
  auto params1 = m1->named_parameters();
416
  auto params2 = m2->named_parameters();
417
  ASSERT_EQ(params1.size(), 6);
418
  ASSERT_EQ(params2.size(), 6);
419
  for (auto& param : params1) {
420
    ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
421
    ASSERT_TRUE(param->allclose(params2[param.key()]));
422
    param->add_(2);
423
  }
424
  for (auto& param : params1) {
425
    ASSERT_FALSE(param->allclose(params2[param.key()]));
426
  }
427

428
  auto buffers1 = m1->named_buffers();
429
  auto buffers2 = m2->named_buffers();
430
  ASSERT_EQ(buffers1.size(), 1);
431
  ASSERT_EQ(buffers2.size(), 1);
432
  for (auto& buffer : buffers1) {
433
    ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()]));
434
    ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()]));
435
    buffer->add_(2);
436
  }
437
  for (auto& buffer : buffers1) {
438
    ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()]));
439
  }
440
}
441

442
// clone() without a device argument: clone must get distinct storage.
TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  auto module2 = module->clone();
  testDistinctParameters(module, module2);
}
448

449
// clone(device) where source and clone share the same CUDA device.
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_CUDA) {
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  torch::Device device(torch::kCUDA, 0);
  module->to(device);
  auto module2 = module->clone(device);
  testDistinctParameters(module, module2);
}
457

458
// clone(d1) of a module living on d0 must leave the original on d0 and put
// the clone's parameters on d1.
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_MultiCUDA) {
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  torch::Device d0(torch::kCUDA, 0);
  torch::Device d1(torch::kCUDA, 1);
  module->to(d0);
  auto module2 = module->clone(d1);

  for (auto& param : module->parameters()) {
    ASSERT_EQ(param.device(), d0);
  }

  for (auto& param : module2->parameters()) {
    ASSERT_EQ(param.device(), d1);
  }

  // need to move the module back to d0 as allclose expects two tensors on
  // the same device.
  module2->to(d0);
  testDistinctParameters(module, module2);
}
479

480
// After clone(), the clone's member tensor (`weight`) must alias the clone's
// own registered parameter — not the original's — while holding equal values.
TEST_F(ModuleTest, ClonePreservesExternalReferences) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
  };
  auto module = std::make_shared<TestModule>();
  {
    torch::NoGradGuard no_grad;
    module->weight += 1;
  }
  // The member and the registered parameter alias the same storage.
  ASSERT_TRUE(
      pointer_equal(module->weight, module->named_parameters()["weight"]));
  ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight"]));

  auto module2 = std::dynamic_pointer_cast<TestModule>(
      std::shared_ptr<Module>(module->clone()));
  // Distinct storage between original and clone...
  ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
  // ...but the clone's member still aliases the clone's own parameter.
  ASSERT_TRUE(
      pointer_equal(module2->weight, module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module->weight));
  ASSERT_FALSE(
      pointer_equal(module2->weight, module->named_parameters()["weight"]));
}
511

512
// clone() must recurse into submodules: the nested module's parameter values
// and plain (non-tensor) data members are copied, with distinct tensor
// storage between original and clone.
TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }

    torch::Tensor weight;
    // Plain data member — copied by Cloneable's copy-construction of the impl.
    int value = 0;
  };
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct NestedModule : public Cloneable<NestedModule> {
    NestedModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      module = register_module("module", std::make_shared<TestModule>());
    }
    std::shared_ptr<TestModule> module;
  };

  auto a = std::make_shared<NestedModule>();
  {
    torch::NoGradGuard no_grad;
    a->module->weight += 1;
    a->module->value = 123;
  }

  auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());

  ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
  ASSERT_TRUE(pointer_equal(
      b->module->weight, b->module->named_parameters()["weight"]));
  ASSERT_TRUE(
      b->module->named_parameters()["weight"].allclose(a->module->weight));
  ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
  ASSERT_EQ(b->module->value, a->module->value);
}
555

556
// clone() with no device argument must keep parameters and buffers on the
// device the original currently lives on (here cuda:0).
TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }

    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 0);

  m.to(device);

  auto clone = m.clone();
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
589

590
// clone(device) of a CPU-resident module must place every parameter and
// buffer of the clone on the requested device (cuda:1).
TEST_F(
    ModuleTest,
    CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }

    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 1);
  // everything is on CPU here
  auto clone = m.clone(device);
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
623

624
// Module with three named parameters ("a", "b", "c") used by the parameter
// accessor tests below.
struct ParameterTestModule : Module {
  ParameterTestModule() {
    a = register_parameter("a", torch::zeros({2, 2}));
    b = register_parameter("b", torch::ones({2, 2}));
    c = register_parameter("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a, b, c;
};
633

634
// Both accessors must report all three registered parameters.
TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
  ParameterTestModule module;
  ASSERT_EQ(module.parameters().size(), 3);
  ASSERT_EQ(module.named_parameters().size(), 3);
}
639

640
// named_parameters() must be keyed by the names used at registration.
TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
  ParameterTestModule module;
  auto parameters = module.named_parameters();
  ASSERT_TRUE(parameters.contains("a"));
  ASSERT_TRUE(parameters.contains("b"));
  ASSERT_TRUE(parameters.contains("c"));
}
647

648
// Module with three named buffers ("a", "b", "c") used by the buffer
// accessor tests below; mirrors ParameterTestModule.
struct BufferTestModule : Module {
  BufferTestModule() {
    a = register_buffer("a", torch::zeros({2, 2}));
    b = register_buffer("b", torch::ones({2, 2}));
    c = register_buffer("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a, b, c;
};
657

658
// Both accessors must report all three registered buffers.
TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
  BufferTestModule module;
  ASSERT_EQ(module.buffers().size(), 3);
  ASSERT_EQ(module.named_buffers().size(), 3);
}
663

664
// named_buffers() must be keyed by the names used at registration.
TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
  BufferTestModule module;
  auto buffers = module.named_buffers();
  ASSERT_TRUE(buffers.contains("a"));
  ASSERT_TRUE(buffers.contains("b"));
  ASSERT_TRUE(buffers.contains("c"));
}
671

672
// Impl with a default constructor (x_ = 123) and a value constructor, used
// to exercise ModuleHolder construction semantics. TORCH_MODULE(A) generates
// the `A` holder wrapping AImpl.
struct AImpl : torch::nn::Module {
  AImpl() : x_(123) {}
  AImpl(int x) : x_(x) {}
  int x_;
};
TORCH_MODULE(A);
678

679
// A default-constructed holder must default-construct its impl (x_ == 123),
// i.e. it is non-empty.
TEST_F(
    ModuleTest,
    DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
  A a;
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 123);
}
687

688
// Holder constructor arguments must be forwarded to the matching impl
// constructor (AImpl(int)).
TEST_F(
    ModuleTest,
    ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
  A a(5);
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 5);
}
696

697
// Constructing a holder from nullptr yields an empty holder; dereferencing
// it must throw.
TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  A a = nullptr;
  ASSERT_FALSE(a);
  ASSERT_TRUE(a.is_empty());
  ASSERT_THROWS_WITH(a->x_, "Accessing empty ModuleHolder");
}
703

704
// Module with two parameters (p1, p2) and two buffers (b1, b2) of a given
// size, and an identity forward(); used by the flat-model accessor tests.
struct TestModule : public torch::nn::Module {
  TestModule(int64_t size) {
    p1 = register_parameter("p1", torch::randn({size}));
    p2 = register_parameter("p2", torch::randn({size}));
    b1 = register_buffer("b1", torch::randn({size}));
    b2 = register_buffer("b2", torch::randn({size}));
  }

  // Identity forward so the module can live inside a Sequential.
  torch::Tensor forward(torch::Tensor input) {
    return input;
  }

  torch::Tensor p1, p2, b1, b2;
};
718

719
// modules() defaults to including the root itself, followed by its children
// in registration order.
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}
730

731
// modules(include_self=false) must return only the children.
TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules =
      model->modules(/*include_self=*/false);
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}
743

744
// named_modules() keys the root with the empty string and children with
// their Sequential index ("0", "1", "2").
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality. Entry 0 is the root (empty key); entry i>0 is
    // child i-1 keyed by its index.
    ASSERT_EQ(modules[i].key(), i ? std::to_string(i - 1) : std::string());
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
757

758
// named_modules(name_prefix="", include_self=false) returns only the
// children, keyed by index.
TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules(
          /*name_prefix=*/std::string(), /*include_self=*/false);
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
772

773
// children() returns only direct submodules; for a flat model that matches
// modules(include_self=false).
TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }

  // For this flat model, this should be true.
  ASSERT_EQ(modules, model->modules(/*include_self=*/false));
}
787

788
// named_children() returns direct submodules keyed by their Sequential index.
TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
801

802
// parameters() must return tensors aliasing the module's members, in
// registration order (data_ptr equality checks aliasing).
TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  std::vector<torch::Tensor> parameters = module.parameters();
  ASSERT_EQ(parameters.size(), 2);
  ASSERT_EQ(parameters[0].data_ptr<float>(), module.p1.data_ptr<float>());
  ASSERT_EQ(parameters[1].data_ptr<float>(), module.p2.data_ptr<float>());
}
809

810
// named_parameters() must pair registration names with aliases of the
// module's member tensors, in registration order.
TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> parameters =
      module.named_parameters();
  ASSERT_EQ(parameters.size(), 2);
  ASSERT_EQ(parameters[0].key(), "p1");
  ASSERT_EQ(parameters[0]->data_ptr<float>(), module.p1.data_ptr<float>());
  ASSERT_EQ(parameters[1].key(), "p2");
  ASSERT_EQ(parameters[1]->data_ptr<float>(), module.p2.data_ptr<float>());
}
820

821
// buffers() must return tensors aliasing the module's members, in
// registration order.
TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  std::vector<torch::Tensor> buffers = module.buffers();
  ASSERT_EQ(buffers.size(), 2);
  ASSERT_EQ(buffers[0].data_ptr<float>(), module.b1.data_ptr<float>());
  ASSERT_EQ(buffers[1].data_ptr<float>(), module.b2.data_ptr<float>());
}
828

829
// named_buffers() must pair registration names with aliases of the module's
// member tensors, in registration order.
TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> buffers =
      module.named_buffers();
  ASSERT_EQ(buffers.size(), 2);
  ASSERT_EQ(buffers[0].key(), "b1");
  ASSERT_EQ(buffers[0]->data_ptr<float>(), module.b1.data_ptr<float>());
  ASSERT_EQ(buffers[1].key(), "b2");
  ASSERT_EQ(buffers[1]->data_ptr<float>(), module.b2.data_ptr<float>());
}
839

840
struct TestContainer : torch::nn::Module {
841
  TestContainer(int64_t number, std::vector<TestContainer> modules = {})
842
      : tensor(torch::tensor(number)) {
843
    for (const auto i : c10::irange(modules.size())) {
844
      register_module(
845
          std::to_string(i),
846
          std::make_shared<TestContainer>(std::move(modules[i])));
847
    }
848
  }
849
  torch::Tensor tensor;
850
};
851

852
// Extracts the integer payload of a module that is known to be a
// TestContainer. Takes the shared_ptr by const reference: the helper only
// observes the module, so copying the pointer (an atomic refcount bump per
// call) is unnecessary (clang-tidy: performance-unnecessary-value-param).
int64_t get_test_container_item(
    const std::shared_ptr<torch::nn::Module>& module) {
  return std::dynamic_pointer_cast<TestContainer>(module)
      ->tensor.item<int64_t>();
}
856

857
// Builds the fixed test tree used by the traversal tests below. Node payloads
// follow a depth-first (pre-order) numbering of the tree:
//   0 ─┬─ 1 ─┬─ 2
//      │     └─ 3
//      ├─ 4
//      └─ 5 ─┬─ 6
//            └─ 7 ─┬─ 8
//                  └─ 9
std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
  TestContainer first_child(1, {TestContainer(2), TestContainer(3)});
  TestContainer second_child(4);
  TestContainer third_child(
      5,
      {TestContainer(6),
       TestContainer(7, {TestContainer(8), TestContainer(9)})});
  return std::make_shared<TestContainer>(
      TestContainer(0, {first_child, second_child, third_child}));
}
867

868
// Expected (qualified name, payload) pairs for a depth-first walk of the
// container returned by make_deeply_nested_test_container(), with the walk
// rooted at the name prefix "test_prefix".
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  std::vector<std::pair<std::string, int64_t>> pairs;
  pairs.emplace_back("test_prefix", 0);
  pairs.emplace_back("test_prefix.0", 1);
  pairs.emplace_back("test_prefix.0.0", 2);
  pairs.emplace_back("test_prefix.0.1", 3);
  pairs.emplace_back("test_prefix.1", 4);
  pairs.emplace_back("test_prefix.2", 5);
  pairs.emplace_back("test_prefix.2.0", 6);
  pairs.emplace_back("test_prefix.2.1", 7);
  pairs.emplace_back("test_prefix.2.1.0", 8);
  pairs.emplace_back("test_prefix.2.1.1", 9);
  return pairs;
}
882

883
// modules() must perform a depth-first traversal that includes the root, so
// the payloads come out in exactly the order 0..9.
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
  const auto model = make_deeply_nested_test_container();
  const auto all_modules = model->modules();

  ASSERT_EQ(all_modules.size(), 10);
  int64_t expected_value = 0;
  for (const auto& submodule : all_modules) {
    ASSERT_EQ(get_test_container_item(submodule), expected_value++);
  }
}
892

893
// named_modules() with a name prefix must produce the same depth-first order
// as modules(), with fully-qualified dotted names rooted at the prefix.
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
  const auto model = make_deeply_nested_test_container();
  const auto named = model->named_modules(/*name_prefix=*/"test_prefix");
  const auto expected = make_key_value_pairs_for_deeply_nested_container();

  ASSERT_EQ(named.size(), expected.size());

  for (const auto i : c10::irange(expected.size())) {
    const auto& [expected_name, expected_value] = expected[i];
    ASSERT_EQ(named[i].key(), expected_name);
    ASSERT_EQ(get_test_container_item(named[i].value()), expected_value);
  }
}
906

907
// children() must return only the immediate (depth-one) submodules of the
// root — payloads 1, 4, 5 — not the full recursive tree.
// Renamed from the misspelled "ChildrensReturns..." to "ChildrenReturns...".
TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();

  ASSERT_EQ(modules.size(), 3);
  ASSERT_EQ(get_test_container_item(modules[0]), 1);
  ASSERT_EQ(get_test_container_item(modules[1]), 4);
  ASSERT_EQ(get_test_container_item(modules[2]), 5);
}
916

917
// named_children() must pair each immediate child with the key it was
// registered under ("0", "1", "2").
// Renamed from the misspelled "NamedChildrensReturns...".
TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();

  ASSERT_EQ(modules.size(), 3);

  ASSERT_EQ(get_test_container_item(modules[0].value()), 1);
  ASSERT_EQ(modules[0].key(), "0");

  ASSERT_EQ(get_test_container_item(modules[1].value()), 4);
  ASSERT_EQ(modules[1].key(), "1");

  ASSERT_EQ(get_test_container_item(modules[2].value()), 5);
  ASSERT_EQ(modules[2].key(), "2");
}
933

934
// apply() must visit the root first and then all descendants depth-first, so
// the payloads are observed in increasing order 0..9.
// Renamed from the misspelled "...IteratesCorreclty".
TEST_F(ModuleTest, ModuleApplyIteratesCorrectly) {
  auto model = make_deeply_nested_test_container();
  int64_t index = 0;
  model->apply([&index](torch::nn::Module& module) {
    ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++);
  });
  // All ten nodes must have been visited exactly once.
  ASSERT_EQ(index, 10);
}
942

943
// The const overload of apply() must traverse in the same depth-first order
// as the non-const one.
// Renamed from the misspelled "...IteratesCorreclty".
TEST_F(ModuleTest, ConstModuleApplyIteratesCorrectly) {
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  int64_t index = 0;
  model->apply([&index](const torch::nn::Module& module) {
    ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++);
  });
  // All ten nodes must have been visited exactly once.
  ASSERT_EQ(index, 10);
}
952

953
// The named apply() overload must supply prefixed, dotted module names that
// match the depth-first traversal order.
// Renamed from the misspelled "...IteratesCorreclty"; `expected` is now
// captured by reference — the lambda does not outlive it, so copying the
// whole vector into the closure was an unnecessary expense.
TEST_F(ModuleTest, NamedModuleApplyIteratesCorrectly) {
  auto model = make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, &expected](const std::string& name, torch::nn::Module& module) {
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(
            module.as<TestContainer>()->tensor.item<int64_t>(),
            expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  // All ten nodes must have been visited exactly once.
  ASSERT_EQ(index, 10);
}
967

968
// The const overload of the named apply() must produce the same prefixed
// names and depth-first order as the non-const overload.
// Renamed from the misspelled "...IteratesCorreclty".
TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorrectly) {
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, &expected](
          const std::string& name, const torch::nn::Module& module) {
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(
            module.as<const TestContainer>()->tensor.item<int64_t>(),
            expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  // All ten nodes must have been visited exactly once.
  ASSERT_EQ(index, 10);
}
984

985
// The shared_ptr overload of apply() must hand each submodule to the callback
// as an owning pointer, in the same depth-first order.
// Renamed from the misspelled "...IteratesCorreclty".
TEST_F(ModuleTest, ModulePointerApplyIteratesCorrectly) {
  auto model = make_deeply_nested_test_container();
  int64_t index = 0;
  model->apply([&index](const std::shared_ptr<torch::nn::Module>& module) {
    ASSERT_EQ(get_test_container_item(module), index++);
  });
  // All ten nodes must have been visited exactly once.
  ASSERT_EQ(index, 10);
}
993

994
// The named shared_ptr overload of apply() must pair each owning pointer with
// its prefixed dotted name, in depth-first order.
// Renamed from the misspelled "...IteratesCorreclty".
TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorrectly) {
  auto model = make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, &expected](
          const std::string& name,
          const std::shared_ptr<torch::nn::Module>& module) {
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(get_test_container_item(module), expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  // All ten nodes must have been visited exactly once.
  ASSERT_EQ(index, 10);
}
1008

1009
// Retrieving the root via modules() requires the module to be owned by a
// shared_ptr (it relies on shared_from_this). A stack-allocated module must
// throw a helpful message unless the root itself is excluded from the result.
// Renamed from "...AttemptingtoGet..." to fix the missing capital.
TEST_F(ModuleTest, ThrowsWhenAttemptingToGetTopLevelModuleAsSharedPtr) {
  {
    // Stack-allocated root: including self must throw.
    TestModule module(1);
    ASSERT_THROWS_WITH(
        module.modules(),
        "It looks like you attempted to retrieve "
        "your top-level module as a shared_ptr")
  }
  {
    // Excluding self avoids the shared_from_this requirement.
    TestModule module(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(module.modules(/*include_self=*/false));
  }
  {
    // shared_ptr-owned root: including self is fine.
    auto module = std::make_shared<TestModule>(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(module->modules());
  }
}
1028

1029
struct EmptyModule : torch::nn::Module {};
1030

1031
// A pretty_print() override must replace the default rendering, while a
// module without one (EmptyModule) prints only its class name.
TEST_F(ModuleTest, PrettyPrint) {
  // Local type; intentionally shadows the file-level TestModule fixture.
  struct TestModule : torch::nn::Module {
    TestModule(int x, float y) : x_value_(x), y_value_(y) {}

    // Custom rendering exercised via c10::str below.
    void pretty_print(std::ostream& stream) const override {
      stream << "TestModule(x=" << x_value_ << ", y=" << y_value_ << ")";
    }

    int x_value_;
    float y_value_;
  };

  ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
  ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)");
}
1046

1047
struct ModuleWithNonTensorForwardImpl : torch::nn::Module {
1048
  int64_t forward(torch::Tensor x) {
1049
    return x.numel();
1050
  }
1051
};
1052
TORCH_MODULE(ModuleWithNonTensorForward);
1053

1054
// A ModuleHolder's call operator must forward to Impl::forward even when the
// return type is not a Tensor (here: an int64_t element count).
TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
  ModuleWithNonTensorForward module;
  ASSERT_EQ(module(torch::ones(123)), 123);
}
1058

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.