pytorch

Форк
0
/
sequential.cpp 
673 строки · 22.7 Кб
1
#include <gtest/gtest.h>
2

3
#include <c10/util/irange.h>
4
#include <torch/torch.h>
5

6
#include <algorithm>
7
#include <memory>
8
#include <vector>
9

10
#include <test/cpp/api/support.h>
11

12
using namespace torch::nn;
13
using namespace torch::test;
14

15
// Shared fixture for the Sequential tests in this file.
// NOTE(review): torch::test::SeedingFixture comes from
// test/cpp/api/support.h and presumably seeds the RNG before each test —
// confirm there.
struct SequentialTest : public torch::test::SeedingFixture {};

// A Sequential can hold a heterogeneous set of standard modules.
TEST_F(SequentialTest, CanContainThings) {
  Sequential sequential(Linear(3, 4), ReLU(), BatchNorm1d(3));
  // Previously the test only checked that construction did not throw; also
  // verify that every module was actually stored.
  ASSERT_EQ(sequential->size(), 3);
}

// Sequential can be built from shared_ptr<Module> arguments, both unnamed and
// as {name, shared_ptr} pairs.
TEST_F(SequentialTest, ConstructsFromSharedPointer) {
  // Minimal module whose forward() returns the value it was built with.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
    int forward() {
      return value;
    }
  };
  Sequential sequential(
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
  ASSERT_EQ(sequential->size(), 3);

  // Both `const char*` and `std::string` names must be accepted.
  Sequential sequential_named(
      {{"m1", std::make_shared<M>(1)},
       {std::string("m2"), std::make_shared<M>(2)},
       {"m3", std::make_shared<M>(3)}});
  // Check the named container itself (not `sequential` again).
  ASSERT_EQ(sequential_named->size(), 3);
  ASSERT_EQ(sequential_named->named_children()[0].key(), "m1");
  ASSERT_EQ(sequential_named->named_children()[1].key(), "m2");
  ASSERT_EQ(sequential_named->named_children()[2].key(), "m3");
}

// Sequential can be built from module values (not pointers); each value is
// copied exactly once into the container.
TEST_F(SequentialTest, ConstructsFromConcreteType) {
  // Number of M copy-constructions; reset before each scenario below.
  static int copy_count;

  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    // Counts copies. `value` is copied too so a copied module never carries
    // an uninitialized member.
    M(const M& other) : torch::nn::Module(other), value(other.value) {
      copy_count++;
    }
    int value;
    int forward() {
      return value;
    }
  };

  copy_count = 0;
  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);
  // NOTE: The current implementation expects each module to be copied exactly
  // once, which happens when the module is passed into `std::make_shared<T>()`.
  // TODO: Find a way to avoid copying, and then delete the copy constructor of
  // `M`.
  ASSERT_EQ(copy_count, 3);

  copy_count = 0;
  Sequential sequential_named(
      {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}});
  // Check the named container itself (not `sequential` again).
  ASSERT_EQ(sequential_named->size(), 3);
  ASSERT_EQ(sequential_named->named_children()[0].key(), "m1");
  ASSERT_EQ(sequential_named->named_children()[1].key(), "m2");
  ASSERT_EQ(sequential_named->named_children()[2].key(), "m3");
  ASSERT_EQ(copy_count, 3);
}

// Sequential can be built from user-defined ModuleHolder wrappers, both unnamed
// and as {name, holder} pairs.
TEST_F(SequentialTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };

  // Holder around MImpl that inherits ModuleHolder's constructors.
  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);

  Sequential sequential_named(
      {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}});
  // Check the named container itself (not `sequential` again).
  ASSERT_EQ(sequential_named->size(), 3);
  ASSERT_EQ(sequential_named->named_children()[0].key(), "m1");
  ASSERT_EQ(sequential_named->named_children()[1].key(), "m2");
  ASSERT_EQ(sequential_named->named_children()[2].key(), "m3");
}

// push_back() grows a Sequential by one element per call. It accepts module
// holders, shared_ptrs, module values and AnyModules, each with or without an
// explicit name.
TEST_F(SequentialTest, PushBackAddsAnElement) {
  // Trivial module used for the shared_ptr and by-value overloads.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };

  // --- Unnamed submodules only ---
  Sequential unnamed;
  ASSERT_EQ(unnamed->size(), 0);
  ASSERT_TRUE(unnamed->is_empty());
  unnamed->push_back(Linear(3, 4));
  ASSERT_EQ(unnamed->size(), 1);
  unnamed->push_back(std::make_shared<M>(1));
  ASSERT_EQ(unnamed->size(), 2);
  unnamed->push_back(M(2));
  ASSERT_EQ(unnamed->size(), 3);

  // --- Named and unnamed submodules interleaved ---
  // Unnamed entries are keyed by their position in the container.
  Sequential mixed;
  ASSERT_EQ(mixed->size(), 0);
  ASSERT_TRUE(mixed->is_empty());

  mixed->push_back(Linear(3, 4));
  ASSERT_EQ(mixed->size(), 1);
  ASSERT_EQ(mixed->named_children()[0].key(), "0");
  mixed->push_back(std::string("linear2"), Linear(3, 4));
  ASSERT_EQ(mixed->size(), 2);
  ASSERT_EQ(mixed->named_children()[1].key(), "linear2");

  mixed->push_back("shared_m1", std::make_shared<M>(1));
  ASSERT_EQ(mixed->size(), 3);
  ASSERT_EQ(mixed->named_children()[2].key(), "shared_m1");
  mixed->push_back(std::make_shared<M>(1));
  ASSERT_EQ(mixed->size(), 4);
  ASSERT_EQ(mixed->named_children()[3].key(), "3");

  mixed->push_back(M(1));
  ASSERT_EQ(mixed->size(), 5);
  ASSERT_EQ(mixed->named_children()[4].key(), "4");
  mixed->push_back(std::string("m2"), M(1));
  ASSERT_EQ(mixed->size(), 6);
  ASSERT_EQ(mixed->named_children()[5].key(), "m2");

  // --- AnyModule, unnamed then named ---
  Sequential of_any;
  auto linear_any = torch::nn::AnyModule(torch::nn::Linear(1, 2));
  ASSERT_EQ(of_any->size(), 0);
  ASSERT_TRUE(of_any->is_empty());
  of_any->push_back(linear_any);
  ASSERT_EQ(of_any->size(), 1);
  ASSERT_EQ(of_any->named_children()[0].key(), "0");
  of_any->push_back("fc", linear_any);
  ASSERT_EQ(of_any->size(), 2);
  ASSERT_EQ(of_any->named_children()[1].key(), "fc");
}

// at<T>(i) returns a reference to the i-th module and rejects any index that
// is not < size().
TEST_F(SequentialTest, AccessWithAt) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (const auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(&sequential->at<M>(i), modules[i].get());
  }

  // throws for a bad index, starting with the first one past the end
  ASSERT_THROWS_WITH(sequential->at<M>(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1000000), "Index out of range");
}

// ptr(i), operator[] and ptr<T>(i) return the i-th module's pointer; ptr()
// rejects any index that is not < size().
TEST_F(SequentialTest, AccessWithPtr) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (const auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(sequential->ptr(i).get(), modules[i].get());
    ASSERT_EQ(sequential[i].get(), modules[i].get());
    ASSERT_EQ(sequential->ptr<M>(i).get(), modules[i].get());
  }

  // throws for a bad index, starting with the first one past the end
  ASSERT_THROWS_WITH(sequential->ptr(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(sequential->ptr(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->ptr(modules.size() + 1000000), "Index out of range");
}

// forward() on a Sequential with no modules must fail loudly.
TEST_F(SequentialTest, CallingForwardOnEmptySequentialIsDisallowed) {
  Sequential no_modules;
  ASSERT_THROWS_WITH(
      no_modules->forward<int>(),
      "Cannot call forward() on an empty Sequential");
}

// Sequential::forward feeds each module's output into the next module.
TEST_F(SequentialTest, CallingForwardChainsCorrectly) {
  // Each module checks that it received the value it was constructed with,
  // then returns value + 1 for the next module.
  struct MockModule : torch::nn::Module {
    explicit MockModule(int value) : expected(value) {}
    int expected;
    int forward(int value) {
      // EXPECT_EQ rather than assert(): assert() compiles away under NDEBUG,
      // so a broken chain would go unnoticed in release builds. ASSERT_*
      // cannot be used here because it returns void from a non-void function.
      EXPECT_EQ(value, expected);
      return value + 1;
    }
  };

  Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});

  ASSERT_EQ(sequential->forward<int>(1), 4);
}

// Requesting a forward() return type that does not match the module's actual
// return type throws with a descriptive message.
TEST_F(SequentialTest, CallingForwardWithTheWrongReturnTypeThrows) {
  struct ReturnsFive : public torch::nn::Module {
    int forward() {
      return 5;
    }
  };

  Sequential seq(ReturnsFive{});
  // The matching type works ...
  ASSERT_EQ(seq->forward<int>(), 5);
  // ... a mismatching one is rejected.
  ASSERT_THROWS_WITH(
      seq->forward<float>(),
      "The type of the return value is int, but you asked for type float");
}

// forward() called without an explicit template argument yields a
// torch::Tensor.
TEST_F(SequentialTest, TheReturnTypeOfForwardDefaultsToTensor) {
  // Identity-like module that hands its input straight back.
  struct Passthrough : public torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };

  Sequential seq(Passthrough{});
  auto ones = torch::ones({3, 3}, torch::requires_grad());
  ASSERT_TRUE(seq->forward(ones).equal(ones));
}

// The result of Sequential::forward is the output of the final module.
TEST_F(SequentialTest, ForwardReturnsTheLastValue) {
  torch::manual_seed(0);
  // Feature sizes 10 -> 3 -> 5 -> 100; the last layer fixes the output width.
  Sequential chain(Linear(10, 3), Linear(3, 5), Linear(5, 100));

  auto input = torch::randn({1000, 10}, torch::requires_grad());
  auto output = chain->forward(input);
  ASSERT_EQ(output.ndimension(), 2);
  ASSERT_EQ(output.size(0), 1000);
  ASSERT_EQ(output.size(1), 100);
}

// A Sequential can hold a wide range of built-in torch::nn modules.
TEST_F(SequentialTest, SanityCheckForHoldingStandardModules) {
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
  // Previously only construction was exercised; also check that all six
  // modules were stored.
  ASSERT_EQ(sequential->size(), 6);
}

// extend() appends the modules of another Sequential (or a container of
// shared_ptrs) in order, leaving the source untouched.
TEST_F(SequentialTest, ExtendPushesModulesFromOtherSequential) {
  // Four distinct identity modules. Only their types matter: as<T>() is used
  // to verify element order.
  struct A : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct B : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct C : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct D : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  Sequential head(A{}, B{});
  Sequential tail(C{}, D{});
  head->extend(*tail);

  // head now holds its own modules followed by tail's ...
  ASSERT_EQ(head->size(), 4);
  ASSERT_TRUE(head[0]->as<A>());
  ASSERT_TRUE(head[1]->as<B>());
  ASSERT_TRUE(head[2]->as<C>());
  ASSERT_TRUE(head[3]->as<D>());

  // ... while tail is unchanged.
  ASSERT_EQ(tail->size(), 2);
  ASSERT_TRUE(tail[0]->as<C>());
  ASSERT_TRUE(tail[1]->as<D>());

  // extend() also accepts a plain vector of shared_ptrs.
  std::vector<std::shared_ptr<A>> extra_modules = {
      std::make_shared<A>(), std::make_shared<A>()};
  tail->extend(extra_modules);

  ASSERT_EQ(tail->size(), 4);
  ASSERT_TRUE(tail[0]->as<C>());
  ASSERT_TRUE(tail[1]->as<D>());
  ASSERT_TRUE(tail[2]->as<A>());
  ASSERT_TRUE(tail[3]->as<A>());
}

// Copying a Sequential holder shares the underlying SequentialImpl rather than
// duplicating it.
TEST_F(SequentialTest, HasReferenceSemantics) {
  Sequential first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
  Sequential second(first);

  ASSERT_EQ(first.get(), second.get());
  ASSERT_EQ(first->size(), second->size());

  // Element-wise identity: both holders expose the very same AnyModule
  // objects, not equal copies.
  const auto same_object = [](const AnyModule& lhs, const AnyModule& rhs) {
    return &lhs == &rhs;
  };
  ASSERT_TRUE(
      std::equal(first->begin(), first->end(), second->begin(), same_object));
}

// clone() produces a deep copy: same module kinds, distinct module objects and
// independent parameter storage.
TEST_F(SequentialTest, IsCloneable) {
  Sequential source(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  Sequential cloned =
      std::dynamic_pointer_cast<SequentialImpl>(source->clone());
  ASSERT_EQ(source->size(), cloned->size());

  for (const auto i : c10::irange(source->size())) {
    // Same kind (type) of module at each position ...
    ASSERT_EQ(source[i]->name(), cloned[i]->name());
    // ... but a distinct object.
    ASSERT_NE(source[i], cloned[i]);
  }

  // The clone must be deep: parameters are copied, not shared.
  torch::NoGradGuard no_grad;

  auto source_params = source->named_parameters();
  auto cloned_params = cloned->named_parameters();
  ASSERT_EQ(source_params.size(), cloned_params.size());
  for (auto& item : source_params) {
    const auto& twin = cloned_params[item.key()];
    ASSERT_FALSE(pointer_equal(item.value(), twin));
    ASSERT_EQ(item->device(), twin.device());
    ASSERT_TRUE(item->allclose(twin));
    // Mutate the source in place; the clone must not follow.
    item->add_(2);
  }
  for (auto& item : source_params) {
    ASSERT_FALSE(item->allclose(cloned_params[item.key()]));
  }
}

// Every element of a Sequential is registered as a child module, in order.
TEST_F(SequentialTest, RegistersElementsAsSubmodules) {
  Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), Dropout2d(0.5));

  auto modules = sequential->children();
  // Guard the indexing below: without this, a missing registration would be
  // an out-of-bounds vector access instead of a clean test failure.
  ASSERT_EQ(modules.size(), 3);
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<Dropout2d>());
}

// clone(device) places every parameter and buffer of the copy on that device.
TEST_F(SequentialTest, CloneToDevice_CUDA) {
  Sequential source(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  const torch::Device cuda0(torch::kCUDA, 0);
  Sequential cloned =
      std::dynamic_pointer_cast<SequentialImpl>(source->clone(cuda0));
  for (const auto& parameter : cloned->parameters()) {
    ASSERT_EQ(parameter.device(), cuda0);
  }
  for (const auto& buffer : cloned->buffers()) {
    ASSERT_EQ(buffer.device(), cuda0);
  }
}

// Printing a Sequential lists each child with its key (position or name) and
// its own pretty-printed form.
TEST_F(SequentialTest, PrettyPrintSequential) {
  // Unnamed children are keyed by their index.
  Sequential by_index(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
  const std::string expected_by_index =
      "torch::nn::Sequential(\n"
      "  (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")";
  ASSERT_EQ(c10::str(by_index), expected_by_index);

  // Named children are keyed by their given name.
  Sequential by_name(
      {{"linear", Linear(10, 3)},
       {"conv2d", Conv2d(1, 2, 3)},
       {"dropout", Dropout(0.5)},
       {"batchnorm2d", BatchNorm2d(5)},
       {"embedding", Embedding(4, 10)},
       {"lstm", LSTM(4, 5)}});
  const std::string expected_by_name =
      "torch::nn::Sequential(\n"
      "  (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (conv2d): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")";
  ASSERT_EQ(c10::str(by_name), expected_by_name);
}

// Runs a series of built-in modules through Sequential::forward and compares
// against hard-coded reference outputs.
// NOTE(review): judging by the test name, these modules presumably have
// forward() overloads with defaulted trailing arguments that Sequential must
// fill in — confirm against each module's forward() signature.
TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
  // Tolerances shared by the seeded (RNN-family / attention) comparisons.
  const double rtol = 1e-05;
  const double atol = 2e-04;

  // ConvTranspose1d with a deterministic weight (0..17).
  {
    Sequential model(
        Identity(),
        ConvTranspose1d(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)));
    auto deconv = std::dynamic_pointer_cast<ConvTranspose1dImpl>(model[1]);
    deconv->weight.set_data(torch::arange(18.).reshape({3, 2, 3}));
    auto input = torch::arange(30.).reshape({2, 3, 5});
    auto output = model->forward(input);
    auto expected = torch::tensor(
        {{{150., 333., 552., 615., 678., 501., 276.},
          {195., 432., 714., 804., 894., 654., 357.}},
         {{420., 918., 1497., 1560., 1623., 1176., 636.},
          {600., 1287., 2064., 2154., 2244., 1599., 852.}}});
    ASSERT_TRUE(torch::allclose(output, expected));
  }
  // ConvTranspose2d with a deterministic weight (0..53).
  {
    Sequential model(
        Identity(),
        ConvTranspose2d(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false)));
    auto deconv = std::dynamic_pointer_cast<ConvTranspose2dImpl>(model[1]);
    deconv->weight.set_data(torch::arange(54.).reshape({3, 2, 3, 3}));
    auto input = torch::arange(75.).reshape({1, 3, 5, 5});
    auto output = model->forward(input);
    auto expected = torch::tensor(
        {{{{2250., 4629., 7140., 7311., 7482., 5133., 2640.},
           {4995., 10272., 15837., 16206., 16575., 11364., 5841.},
           {8280., 17019., 26226., 26820., 27414., 18783., 9648.},
           {9225., 18954., 29196., 29790., 30384., 20808., 10683.},
           {10170., 20889., 32166., 32760., 33354., 22833., 11718.},
           {7515., 15420., 23721., 24144., 24567., 16800., 8613.},
           {4140., 8487., 13044., 13269., 13494., 9219., 4722.}},
          {{2925., 6006., 9246., 9498., 9750., 6672., 3423.},
           {6480., 13296., 20454., 20985., 21516., 14712., 7542.},
           {10710., 21960., 33759., 34596., 35433., 24210., 12402.},
           {12060., 24705., 37944., 38781., 39618., 27045., 13842.},
           {13410., 27450., 42129., 42966., 43803., 29880., 15282.},
           {9810., 20064., 30768., 31353., 31938., 21768., 11124.},
           {5355., 10944., 16770., 17076., 17382., 11838., 6045.}}}});
    ASSERT_TRUE(torch::allclose(output, expected));
  }
  // ConvTranspose3d with a deterministic weight (0..31).
  {
    Sequential model(
        Identity(),
        ConvTranspose3d(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false)));
    auto deconv = std::dynamic_pointer_cast<ConvTranspose3dImpl>(model[1]);
    deconv->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
    auto input = torch::arange(16.).reshape({1, 2, 2, 2, 2});
    auto output = model->forward(input);
    auto expected = torch::tensor(
        {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
           {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
           {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
          {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
           {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
           {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
    ASSERT_TRUE(torch::allclose(output, expected));
  }
  // EmbeddingBag built from a fixed pretrained weight; the bag {1, 0} is
  // expected to produce the element-wise mean of both weight rows.
  {
    auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
    Sequential model(Identity(), EmbeddingBag::from_pretrained(weight));
    auto input = torch::tensor({{1, 0}}, torch::kLong);
    auto output = model->forward(input);
    auto expected = torch::tensor({2.5000, 3.7000, 4.6500});
    ASSERT_TRUE(torch::allclose(output, expected));
  }
  // MultiheadAttention; the seed must be set before the module is constructed
  // for the reference values below to hold.
  {
    torch::manual_seed(0);

    int64_t embed_dim = 8;
    int64_t num_heads = 4;
    int64_t batch_size = 8;
    int64_t src_len = 3;
    int64_t tgt_len = 1;

    auto query = torch::ones({batch_size, tgt_len, embed_dim});
    auto key = torch::ones({batch_size, src_len, embed_dim});
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto value = key;

    Sequential model(MultiheadAttention(embed_dim, num_heads));
    // Inputs are transposed so the sequence dimension comes first.
    auto attention = model->forward<std::tuple<torch::Tensor, torch::Tensor>>(
        query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1));

    auto attn_output = std::get<0>(attention);
    auto attn_output_expected = torch::tensor(
        {{{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}}});
    ASSERT_TRUE(torch::allclose(attn_output, attn_output_expected, rtol, atol));

    // All keys are identical, so attention is spread evenly over src_len = 3.
    auto attn_output_weights = std::get<1>(attention);
    auto attn_output_weights_expected = torch::tensor(
        {{{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}}});
    ASSERT_TRUE(torch::allclose(
        attn_output_weights, attn_output_weights_expected, rtol, atol));
  }
  // MaxUnpool1d: forward takes (input, indices).
  {
    auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    auto input = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat));
    Sequential model(MaxUnpool1d(3));
    auto output = model->forward(input, indices);
    auto expected =
        torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat);
    ASSERT_TRUE(torch::allclose(output, expected));
  }
  // MaxUnpool2d with stride and padding: forward takes (input, indices).
  {
    auto indices = torch::tensor(
        {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
         {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
        torch::kLong);
    auto input = torch::tensor(
        {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
         {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
        torch::dtype(torch::kFloat));
    Sequential model(
        MaxUnpool2d(MaxUnpool2dOptions(3).stride(2).padding(1)));
    auto output = model->forward(input, indices);
    auto expected = torch::tensor(
        {{{{0, 0, 0, 0, 0},
           {0, 6, 0, 8, 9},
           {0, 0, 0, 0, 0},
           {0, 16, 0, 18, 19},
           {0, 21, 0, 23, 24}}},
         {{{0, 0, 0, 0, 0},
           {0, 31, 0, 33, 34},
           {0, 0, 0, 0, 0},
           {0, 41, 0, 43, 44},
           {0, 46, 0, 48, 49}}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(output, expected));
  }
  // MaxUnpool3d: a single value scattered to flat index 26 of a 3x3x3 output.
  {
    auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
    auto input = torch::tensor(
        {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
    Sequential model(MaxUnpool3d(3));
    auto output = model->forward(input, indices);
    auto expected = torch::tensor(
        {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(output, expected));
  }
  // Recurrent layers and cells. Each case reseeds before building its module so
  // the reference values are reproducible.
  {
    torch::manual_seed(0);
    Sequential model(Identity(), RNN(2, 3));
    auto input = torch::ones({2, 3, 2});
    auto result =
        model->forward<std::tuple<torch::Tensor, torch::Tensor>>(input);
    auto expected_output = torch::tensor(
        {{{-0.0645, -0.7274, 0.4531},
          {-0.0645, -0.7274, 0.4531},
          {-0.0645, -0.7274, 0.4531}},
         {{-0.3970, -0.6950, 0.6009},
          {-0.3970, -0.6950, 0.6009},
          {-0.3970, -0.6950, 0.6009}}});
    ASSERT_TRUE(
        torch::allclose(std::get<0>(result), expected_output, rtol, atol));
  }
  {
    torch::manual_seed(0);
    Sequential model(Identity(), LSTM(2, 3));
    auto input = torch::ones({2, 3, 2});
    auto result = model->forward<
        std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>>(
        input);
    auto expected_output = torch::tensor(
        {{{-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744}},
         {{-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183}}});
    ASSERT_TRUE(
        torch::allclose(std::get<0>(result), expected_output, rtol, atol));
  }
  {
    torch::manual_seed(0);
    Sequential model(Identity(), GRU(2, 3));
    auto input = torch::ones({2, 3, 2});
    auto result =
        model->forward<std::tuple<torch::Tensor, torch::Tensor>>(input);
    auto expected_output = torch::tensor(
        {{{-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336}},
         {{-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960}}});
    ASSERT_TRUE(
        torch::allclose(std::get<0>(result), expected_output, rtol, atol));
  }
  {
    torch::manual_seed(0);
    Sequential model(Identity(), RNNCell(2, 3));
    auto input = torch::ones({2, 2});
    auto result = model->forward<torch::Tensor>(input);
    auto expected_output =
        torch::tensor({{-0.0645, -0.7274, 0.4531}, {-0.0645, -0.7274, 0.4531}});
    ASSERT_TRUE(torch::allclose(result, expected_output, rtol, atol));
  }
  {
    torch::manual_seed(0);
    Sequential model(Identity(), LSTMCell(2, 3));
    auto input = torch::ones({2, 2});
    auto result =
        model->forward<std::tuple<torch::Tensor, torch::Tensor>>(input);
    auto expected_output =
        torch::tensor({{-0.2693, -0.1240, 0.0744}, {-0.2693, -0.1240, 0.0744}});
    ASSERT_TRUE(
        torch::allclose(std::get<0>(result), expected_output, rtol, atol));
  }
  {
    torch::manual_seed(0);
    Sequential model(Identity(), GRUCell(2, 3));
    auto input = torch::ones({2, 2});
    auto result = model->forward<torch::Tensor>(input);
    auto expected_output =
        torch::tensor({{-0.1134, 0.0467, 0.2336}, {-0.1134, 0.0467, 0.2336}});
    ASSERT_TRUE(torch::allclose(result, expected_output, rtol, atol));
  }
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.