// pytorch
//
// Fork
// 0
// /
// transformer.cpp
// 1523 lines · 58.9 KB
1
#include <gtest/gtest.h>
2

3
#include <torch/torch.h>
4

5
#include <test/cpp/api/support.h>
6

7
using namespace torch::nn;
8

9
// Fixture shared by every test in this file. It adds nothing of its own; all
// behaviour comes from torch::test::SeedingFixture (presumably seeds the RNG
// for reproducible runs — see test/cpp/api/support.h to confirm).
struct TransformerTest : torch::test::SeedingFixture {};
10

11
// a generic function to set constants for parameters so we have fixed result
12
// for deterministic test
13
template <typename Model>
14
void set_parameter_to_constants(
15
    Model& model,
16
    const torch::TensorOptions& tensor_options) {
17
  torch::NoGradGuard guard;
18
  for (auto& p : model->parameters()) {
19
    auto sz = p.view(-1).size(0);
20
    p.copy_(torch::cos(torch::arange(0, sz, tensor_options).view(p.sizes())));
21
  }
22
}
23

24
// a generic function to provide consistent encoder/decoder layer for all the
25
// transformer tests
26
template <typename T_LAYER, typename T_OPTIONS>
27
T_LAYER get_a_test_layer(
28
    const torch::TensorOptions& tensor_options,
29
    bool use_callable_activation) {
30
  int64_t d_model = 4;
31
  int64_t nhead = 2;
32
  int64_t dim_feedforward = 16;
33
  double dropout = 0.0;
34

35
  // activation is always ReLU here and it can be adjusted later depending on
36
  // the usage
37
  T_LAYER layer(T_OPTIONS(d_model, nhead)
38
                    .dim_feedforward(dim_feedforward)
39
                    .dropout(dropout));
40
  if (tensor_options.device() == torch::kCUDA) {
41
    layer->to(torch::kCUDA);
42
  }
43
  if (use_callable_activation) {
44
    layer.get()->options.activation(
45
        [&](const torch::Tensor& t) { return torch::nn::functional::relu(t); });
46
  }
47

48
  // set constant weights of the model
49
  set_parameter_to_constants<T_LAYER>(layer, tensor_options);
50

51
  return layer;
52
}
53

54
void transformer_encoder_layer_test_helper(
55
    bool is_cuda,
56
    bool use_callable_activation) {
57
  // this is a deterministic test for TransformerEncoderLayer
58
  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
59
  torch::TensorOptions tensor_options =
60
      torch::TensorOptions().dtype(torch::kFloat32).device(device);
61

62
  TransformerEncoderLayer model =
63
      get_a_test_layer<TransformerEncoderLayer, TransformerEncoderLayerOptions>(
64
          tensor_options, use_callable_activation);
65

66
  // relu test case 1
67
  torch::Tensor encoder_input =
68
      torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
69
  torch::Tensor result = model(encoder_input).detach();
70
  torch::Tensor ref_output = torch::tensor(
71
      {{{2.258703, 0.127985, -0.697881, 0.170862}}}, tensor_options);
72
  ASSERT_EQ(result.sizes(), ref_output.sizes());
73
  ASSERT_TRUE(
74
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
75

76
  // all 0 values are NOT masked. This should't mask anything
77
  torch::Tensor mask = torch::tensor({{0}}, tensor_options) == 1;
78
  result = model(
79
               encoder_input,
80
               /*src_mask=*/torch::Tensor{},
81
               /*src_key_padding_mask=*/mask)
82
               .detach();
83
  ASSERT_EQ(result.sizes(), ref_output.sizes());
84
  ASSERT_TRUE(
85
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
86

87
  // all 1 values are masked. Since there is only 1 input embedding this will
88
  // result in nan.
89
  mask = torch::tensor({{1}}, tensor_options) == 1;
90
  result = model(
91
               encoder_input,
92
               /*src_mask=*/torch::Tensor{},
93
               /*src_key_padding_mask=*/mask)
94
               .detach();
95
  ASSERT_TRUE(torch::isnan(result).all().item().to<bool>());
96

97
  // relu test case 2
98
  encoder_input =
99
      torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
100
  result = model(encoder_input).detach();
101
  ref_output = torch::tensor(
102
      {{{2.272644, 0.119035, -0.691669, 0.153486}},
103
       {{2.272644, 0.119035, -0.691669, 0.153486}}},
104
      tensor_options);
105
  ASSERT_EQ(result.sizes(), ref_output.sizes());
106
  ASSERT_TRUE(
107
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
108

109
  // all 0 values are NOT masked
110
  mask = torch::tensor({{0, 0}}, tensor_options) == 1;
111
  result = model(
112
               encoder_input,
113
               /*src_mask=*/torch::Tensor{},
114
               /*src_key_padding_mask=*/mask)
115
               .detach();
116
  ASSERT_EQ(result.sizes(), ref_output.sizes());
117
  ASSERT_TRUE(
118
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
119

120
  // mask with 1 and 0
121
  mask = torch::tensor({{1, 0}}, tensor_options) == 1;
122
  result = model(
123
               encoder_input,
124
               /*src_mask=*/torch::Tensor{},
125
               /*src_key_padding_mask=*/mask)
126
               .detach();
127
  ref_output = torch::tensor(
128
      {{{2.301516, 0.092249, -0.679101, 0.103088}},
129
       {{2.301516, 0.092249, -0.679101, 0.103088}}},
130
      tensor_options);
131
  ASSERT_EQ(result.sizes(), ref_output.sizes());
132
  ASSERT_TRUE(
133
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
134

135
  // relu test case 3
136
  encoder_input = torch::tensor(
137
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
138
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
139
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
140
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
141
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
142
      tensor_options);
143
  result = model(encoder_input).detach();
144
  ref_output = torch::tensor(
145
      {{{2.428589, 0.020835, -0.602055, -0.085249},
146
        {2.427987, 0.021213, -0.602496, -0.084103}},
147
       {{2.424689, 0.019155, -0.604793, -0.085672},
148
        {2.413863, 0.022211, -0.612486, -0.072490}},
149
       {{2.433774, 0.021598, -0.598343, -0.087548},
150
        {2.425104, 0.019748, -0.604515, -0.084839}},
151
       {{2.436185, 0.022682, -0.596625, -0.087261},
152
        {2.433556, 0.021891, -0.598509, -0.086832}},
153
       {{2.416246, 0.017512, -0.610712, -0.082961},
154
        {2.422901, 0.024187, -0.606178, -0.074929}}},
155
      tensor_options);
156
  ASSERT_EQ(result.sizes(), ref_output.sizes());
157
  ASSERT_TRUE(
158
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
159

160
  // all 0 values are NOT masked
161
  mask = torch::zeros({2, 5}, tensor_options) == 1;
162
  result = model(
163
               encoder_input,
164
               /*src_mask=*/torch::Tensor{},
165
               /*src_key_padding_mask=*/mask)
166
               .detach();
167
  ASSERT_EQ(result.sizes(), ref_output.sizes());
168
  ASSERT_TRUE(
169
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
170

171
  // mask with 0s and 1s
172
  mask[0][1] = 1;
173
  mask[1][3] = 1;
174
  mask[1][4] = 1;
175
  result = model(
176
               encoder_input,
177
               /*src_mask=*/torch::Tensor{},
178
               /*src_key_padding_mask=*/mask)
179
               .detach();
180
  ref_output = torch::tensor(
181
      {{{2.429026, 0.020793, -0.601741, -0.085642},
182
        {2.428811, 0.021445, -0.601912, -0.084252}},
183
       {{2.425009, 0.019155, -0.604566, -0.085899},
184
        {2.415408, 0.02249, -0.611415, -0.073}},
185
       {{2.434199, 0.021682, -0.598039, -0.087699},
186
        {2.42598, 0.019941, -0.603896, -0.085091}},
187
       {{2.436457, 0.022736, -0.59643, -0.08736},
188
        {2.434021, 0.022093, -0.598179, -0.08679}},
189
       {{2.416531, 0.017498, -0.610513, -0.083181},
190
        {2.4242, 0.024653, -0.605266, -0.074959}}},
191
      tensor_options);
192
  ASSERT_EQ(result.sizes(), ref_output.sizes());
193
  ASSERT_TRUE(
194
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
195

196
  // gelu test case 1
197
  model.get()->options.activation(torch::kGELU);
198
  encoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
199
  result = model(encoder_input).detach();
200
  ref_output = torch::tensor(
201
      {{{2.249815, 0.131006, -0.702199, 0.177868}}}, tensor_options);
202
  ASSERT_EQ(result.sizes(), ref_output.sizes());
203
  ASSERT_TRUE(
204
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
205

206
  // gelu test case 2
207
  encoder_input = torch::tensor(
208
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
209
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
210
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
211
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
212
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
213
      tensor_options);
214
  result = model(encoder_input);
215
  ref_output = torch::tensor(
216
      {{{2.42163188, 0.03227153, -0.60714219, -0.05908082},
217
        {2.42151276, 0.03302179, -0.60722523, -0.05762651}},
218
       {{2.41926761, 0.02974034, -0.60879519, -0.0621269},
219
        {2.41626395, 0.03539356, -0.61087842, -0.04978623}},
220
       {{2.42382808, 0.03218872, -0.6055963, -0.06073591},
221
        {2.41983477, 0.03085259, -0.60840145, -0.06046414}},
222
       {{2.42500749, 0.03328855, -0.60476388, -0.0595334},
223
        {2.4237977, 0.03290575, -0.60561789, -0.05940082}},
224
       {{2.41383916, 0.02686345, -0.61256377, -0.06380707},
225
        {2.42000277, 0.03800944, -0.60824798, -0.04754947}}},
226
      tensor_options);
227
  ASSERT_EQ(result.sizes(), ref_output.sizes());
228
  ASSERT_TRUE(
229
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
230
}
231

232
TEST_F(TransformerTest, TransformerEncoderLayer) {
  // Run both activation variants; the trace tags failures raised inside the
  // shared helper with the variant that produced them.
  for (const bool use_callable_activation : {false, true}) {
    SCOPED_TRACE(
        use_callable_activation ? "callable activation" : "enum activation");
    transformer_encoder_layer_test_helper(
        /*is_cuda=*/false, use_callable_activation);
  }
}
238

239
TEST_F(TransformerTest, TransformerEncoderLayer_CUDA) {
  // Run both activation variants; the trace tags failures raised inside the
  // shared helper with the variant that produced them.
  for (const bool use_callable_activation : {false, true}) {
    SCOPED_TRACE(
        use_callable_activation ? "callable activation" : "enum activation");
    transformer_encoder_layer_test_helper(
        /*is_cuda=*/true, use_callable_activation);
  }
}
245

246
void transformer_decoder_layer_test_helper(
247
    bool is_cuda,
248
    bool use_callable_activation) {
249
  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
250
  torch::TensorOptions tensor_options =
251
      torch::TensorOptions().dtype(torch::kFloat32).device(device);
252

253
  TransformerDecoderLayer model =
254
      get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
255
          tensor_options, use_callable_activation);
256

257
  // deterministic input
258
  torch::Tensor decoder_input =
259
      torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
260
  torch::Tensor memory_input =
261
      torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
262
  torch::Tensor result = model(decoder_input, memory_input).detach();
263
  torch::Tensor ref_output = torch::tensor(
264
      {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options);
265
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
266
  ASSERT_TRUE(torch::allclose(
267
      result,
268
      ref_output,
269
      1e-7,
270
      1e-5,
271
      /*equal_nan=*/true));
272

273
  // deterministic input
274
  decoder_input =
275
      torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
276
  memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
277
  result = model(decoder_input, memory_input).detach();
278
  ref_output = torch::tensor(
279
      {{{2.422245, 0.051716, -0.606338, -0.024756}},
280
       {{2.422245, 0.051716, -0.606338, -0.024756}}},
281
      tensor_options);
282
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
283
  ASSERT_TRUE(torch::allclose(
284
      result,
285
      ref_output,
286
      1e-7,
287
      1e-5,
288
      /*equal_nan=*/true));
289

290
  // deterministic input
291
  decoder_input =
292
      torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
293
  memory_input =
294
      torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
295
  result = model(decoder_input, memory_input).detach();
296
  ref_output = torch::tensor(
297
      {{{2.343536, 0.085561, -0.654954, 0.074991}},
298
       {{2.343536, 0.085561, -0.654954, 0.074991}}},
299
      tensor_options);
300
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
301
  ASSERT_TRUE(torch::allclose(
302
      result,
303
      ref_output,
304
      1e-7,
305
      1e-5,
306
      /*equal_nan=*/true));
307

308
  // deterministic input
309
  decoder_input = torch::tensor(
310
      {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
311
       {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
312
       {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
313
      tensor_options);
314
  memory_input = torch::tensor(
315
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
316
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
317
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
318
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
319
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
320
      tensor_options);
321
  result = model(decoder_input, memory_input).detach();
322
  ref_output = torch::tensor(
323
      {{{2.430065, 0.027862, -0.601136, -0.073096},
324
        {2.431935, 0.028907, -0.599809, -0.072488}},
325
       {{2.428457, 0.027053, -0.602275, -0.073462},
326
        {2.431970, 0.029387, -0.599789, -0.071621}},
327
       {{2.431934, 0.028196, -0.599802, -0.073809},
328
        {2.432306, 0.028858, -0.599542, -0.072846}}},
329
      tensor_options);
330
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
331
  ASSERT_TRUE(torch::allclose(
332
      result,
333
      ref_output,
334
      1e-7,
335
      1e-5,
336
      /*equal_nan=*/true));
337

338
  // key_padding_mask
339
  torch::Tensor t_mask = {};
340
  torch::Tensor m_mask = {};
341
  torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1;
342
  result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
343
               .detach();
344
  ref_output = torch::tensor(
345
      {{{2.430065, 0.027862, -0.601136, -0.073096},
346
        {2.431935, 0.028907, -0.599809, -0.072488}},
347
       {{2.428457, 0.027053, -0.602275, -0.073462},
348
        {2.431970, 0.029387, -0.599789, -0.071621}},
349
       {{2.431934, 0.028196, -0.599802, -0.073809},
350
        {2.432306, 0.028858, -0.599542, -0.072846}}},
351
      tensor_options);
352
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
353
  ASSERT_TRUE(torch::allclose(
354
      result,
355
      ref_output,
356
      1e-7,
357
      1e-5,
358
      /*equal_nan=*/true));
359

360
  // key_padding_mask
361
  key_padding_mask[0][2] = 1;
362
  key_padding_mask[1][1] = 1;
363
  key_padding_mask[1][2] = 1;
364
  result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
365
               .detach();
366
  ref_output = torch::tensor(
367
      {{{2.430025, 0.027643, -0.601164, -0.073476},
368
        {2.4323, 0.029375, -0.599553, -0.071881}},
369
       {{2.428523, 0.026838, -0.602226, -0.07391},
370
        {2.432634, 0.029842, -0.599318, -0.071253}},
371
       {{2.432278, 0.028152, -0.599555, -0.074139},
372
        {2.432659, 0.029244, -0.599294, -0.072382}}},
373
      tensor_options);
374
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
375
  ASSERT_TRUE(torch::allclose(
376
      result,
377
      ref_output,
378
      1e-7,
379
      1e-5,
380
      /*equal_nan=*/true));
381

382
  // memory_key_padding_mask
383
  torch::Tensor t_key_padding_mask = {};
384
  key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1;
385
  result = model(
386
               decoder_input,
387
               memory_input,
388
               t_mask,
389
               m_mask,
390
               t_key_padding_mask,
391
               key_padding_mask)
392
               .detach();
393
  ref_output = torch::tensor(
394
      {{{2.430065, 0.027862, -0.601136, -0.073096},
395
        {2.431935, 0.028907, -0.599809, -0.072488}},
396
       {{2.428457, 0.027053, -0.602275, -0.073462},
397
        {2.431970, 0.029387, -0.599789, -0.071621}},
398
       {{2.431934, 0.028196, -0.599802, -0.073809},
399
        {2.432306, 0.028858, -0.599542, -0.072846}}},
400
      tensor_options);
401
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
402
  ASSERT_TRUE(torch::allclose(
403
      result,
404
      ref_output,
405
      1e-7,
406
      1e-5,
407
      /*equal_nan=*/true));
408

409
  // memory_key_padding_mask
410
  key_padding_mask[0][4] = 1;
411
  key_padding_mask[1][3] = 1;
412
  key_padding_mask[1][4] = 1;
413
  result = model(
414
               decoder_input,
415
               memory_input,
416
               t_mask,
417
               m_mask,
418
               t_key_padding_mask,
419
               key_padding_mask)
420
               .detach();
421
  ref_output = torch::tensor(
422
      {{{2.429757, 0.027358, -0.601351, -0.073816},
423
        {2.432692, 0.028583, -0.599263, -0.073634}},
424
       {{2.428247, 0.02662, -0.602419, -0.074123},
425
        {2.432657, 0.029055, -0.599293, -0.072732}},
426
       {{2.431515, 0.027687, -0.600096, -0.074459},
427
        {2.433075, 0.028543, -0.598987, -0.073985}}},
428
      tensor_options);
429
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
430
  ASSERT_TRUE(torch::allclose(
431
      result,
432
      ref_output,
433
      1e-7,
434
      1e-5,
435
      /*equal_nan=*/true));
436
}
437

438
TEST_F(TransformerTest, TransformerDecoderLayer) {
  // Run both activation variants; the trace tags failures raised inside the
  // shared helper with the variant that produced them.
  for (const bool use_callable_activation : {false, true}) {
    SCOPED_TRACE(
        use_callable_activation ? "callable activation" : "enum activation");
    transformer_decoder_layer_test_helper(
        /*is_cuda=*/false, use_callable_activation);
  }
}
444

445
TEST_F(TransformerTest, TransformerDecoderLayer_CUDA) {
  // Run both activation variants; the trace tags failures raised inside the
  // shared helper with the variant that produced them.
  for (const bool use_callable_activation : {false, true}) {
    SCOPED_TRACE(
        use_callable_activation ? "callable activation" : "enum activation");
    transformer_decoder_layer_test_helper(
        /*is_cuda=*/true, use_callable_activation);
  }
}
451

452
void transformer_decoder_layer_test_helper_gelu(
453
    bool is_cuda,
454
    bool use_callable_activation) {
455
  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
456
  torch::TensorOptions tensor_options =
457
      torch::TensorOptions().dtype(torch::kFloat32).device(device);
458

459
  TransformerDecoderLayer model =
460
      get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
461
          tensor_options, use_callable_activation);
462
  if (use_callable_activation) {
463
    model.get()->options.activation(
464
        [&](const torch::Tensor& t) { return torch::nn::functional::gelu(t); });
465
  } else {
466
    model.get()->options.activation(torch::kGELU);
467
  }
468

469
  // deterministic input
470
  torch::Tensor decoder_input =
471
      torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
472
  torch::Tensor memory_input =
473
      torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
474
  torch::Tensor result = model(decoder_input, memory_input).detach();
475
  torch::Tensor ref_output = torch::tensor(
476
      {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options);
477
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
478
  ASSERT_TRUE(torch::allclose(
479
      result,
480
      ref_output,
481
      1e-7,
482
      1e-5,
483
      /*equal_nan=*/true));
484

485
  // deterministic input
486
  decoder_input =
487
      torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
488
  memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
489
  result = model(decoder_input, memory_input).detach();
490
  ref_output = torch::tensor(
491
      {{{2.415448, 0.054389, -0.610932, -0.0156613}},
492
       {{2.415448, 0.054389, -0.610932, -0.0156613}}},
493
      tensor_options);
494
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
495
  ASSERT_TRUE(torch::allclose(
496
      result,
497
      ref_output,
498
      1e-7,
499
      1e-5,
500
      /*equal_nan=*/true));
501

502
  // deterministic input
503
  decoder_input =
504
      torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
505
  memory_input =
506
      torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
507
  result = model(decoder_input, memory_input).detach();
508
  ref_output = torch::tensor(
509
      {{{2.338531, 0.087709, -0.65776, 0.080646}},
510
       {{2.338531, 0.087709, -0.65776, 0.080646}}},
511
      tensor_options);
512
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
513
  ASSERT_TRUE(torch::allclose(
514
      result,
515
      ref_output,
516
      1e-7,
517
      1e-5,
518
      /*equal_nan=*/true));
519

520
  // deterministic input
521
  decoder_input = torch::tensor(
522
      {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
523
       {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
524
       {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
525
      tensor_options);
526
  memory_input = torch::tensor(
527
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
528
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
529
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
530
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
531
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
532
      tensor_options);
533
  result = model(decoder_input, memory_input).detach();
534
  ref_output = torch::tensor(
535
      {{{2.42049104, 0.03443088, -0.60793706, -0.05436271},
536
        {2.42210631, 0.03546578, -0.60679895, -0.05357488}},
537
       {{2.41907674, 0.0336104, -0.60892977, -0.05490462},
538
        {2.42216881, 0.03586554, -0.6067524, -0.05289126}},
539
       {{2.42205716, 0.03488046, -0.60683681, -0.05460596},
540
        {2.42240309, 0.0354595, -0.60659063, -0.05378816}}},
541
      tensor_options);
542
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
543
  ASSERT_TRUE(torch::allclose(
544
      result,
545
      ref_output,
546
      1e-7,
547
      1e-5,
548
      /*equal_nan=*/true));
549
}
550

551
TEST_F(TransformerTest, TransformerDecoderLayer_gelu) {
  // Run both activation variants; the trace tags failures raised inside the
  // shared helper with the variant that produced them.
  for (const bool use_callable_activation : {false, true}) {
    SCOPED_TRACE(
        use_callable_activation ? "callable activation" : "enum activation");
    transformer_decoder_layer_test_helper_gelu(
        /*is_cuda=*/false, use_callable_activation);
  }
}
557

558
TEST_F(TransformerTest, TransformerDecoderLayer_gelu_CUDA) {
  // Run both activation variants; the trace tags failures raised inside the
  // shared helper with the variant that produced them.
  for (const bool use_callable_activation : {false, true}) {
    SCOPED_TRACE(
        use_callable_activation ? "callable activation" : "enum activation");
    transformer_decoder_layer_test_helper_gelu(
        /*is_cuda=*/true, use_callable_activation);
  }
}
564

565
void transformer_encoder_test_helper(
566
    bool is_cuda,
567
    bool use_callable_activation) {
568
  // this is a deterministic test for TransformerEncoderLayer
569
  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
570
  torch::TensorOptions tensor_options =
571
      torch::TensorOptions().dtype(torch::kFloat32).device(device);
572

573
  TransformerEncoderLayer encoder_layer =
574
      get_a_test_layer<TransformerEncoderLayer, TransformerEncoderLayerOptions>(
575
          tensor_options, use_callable_activation);
576

577
  TransformerEncoder model(TransformerEncoderOptions(encoder_layer, 1));
578
  if (is_cuda) {
579
    model->to(torch::kCUDA);
580
  }
581

582
  torch::Tensor encoder_input = torch::tensor(
583
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
584
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
585
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
586
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
587
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
588
      tensor_options);
589
  torch::Tensor result = model(encoder_input).detach();
590
  torch::Tensor ref_output = torch::tensor(
591
      {{{2.428589, 0.020835, -0.602055, -0.085249},
592
        {2.427987, 0.021213, -0.602496, -0.084103}},
593
       {{2.424689, 0.019155, -0.604793, -0.085672},
594
        {2.413863, 0.022211, -0.612486, -0.072490}},
595
       {{2.433774, 0.021598, -0.598343, -0.087548},
596
        {2.425104, 0.019748, -0.604515, -0.084839}},
597
       {{2.436185, 0.022682, -0.596625, -0.087261},
598
        {2.433556, 0.021891, -0.598509, -0.086832}},
599
       {{2.416246, 0.017512, -0.610712, -0.082961},
600
        {2.422901, 0.024187, -0.606178, -0.074929}}},
601
      tensor_options);
602
  ASSERT_EQ(result.sizes(), ref_output.sizes());
603
  ASSERT_TRUE(
604
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
605

606
  // all 0 values are NOT masked
607
  torch::Tensor mask = torch::zeros({2, 5}, tensor_options) == 1;
608
  result = model(
609
               encoder_input,
610
               /*src_mask=*/torch::Tensor{},
611
               /*src_key_padding_mask=*/mask)
612
               .detach();
613
  ASSERT_EQ(result.sizes(), ref_output.sizes());
614
  ASSERT_TRUE(
615
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
616

617
  // mask with 0s and 1s
618
  mask[0][1] = 1;
619
  mask[1][3] = 1;
620
  mask[1][4] = 1;
621
  result = model(
622
               encoder_input,
623
               /*src_mask=*/torch::Tensor{},
624
               /*src_key_padding_mask=*/mask)
625
               .detach();
626
  ref_output = torch::tensor(
627
      {{{2.429026, 0.020793, -0.601741, -0.085642},
628
        {2.428811, 0.021445, -0.601912, -0.084252}},
629
       {{2.425009, 0.019155, -0.604566, -0.085899},
630
        {2.415408, 0.02249, -0.611415, -0.073}},
631
       {{2.434199, 0.021682, -0.598039, -0.087699},
632
        {2.42598, 0.019941, -0.603896, -0.085091}},
633
       {{2.436457, 0.022736, -0.59643, -0.08736},
634
        {2.434021, 0.022093, -0.598179, -0.08679}},
635
       {{2.416531, 0.017498, -0.610513, -0.083181},
636
        {2.4242, 0.024653, -0.605266, -0.074959}}},
637
      tensor_options);
638
  ASSERT_EQ(result.sizes(), ref_output.sizes());
639
  ASSERT_TRUE(
640
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
641

642
  // test case 2, multiple layers no norm
643
  model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 2));
644
  if (is_cuda) {
645
    model->to(torch::kCUDA);
646
  }
647
  result = model(
648
               encoder_input,
649
               /*src_mask=*/torch::Tensor{},
650
               /*src_key_padding_mask=*/mask)
651
               .detach();
652
  ref_output = torch::tensor(
653
      {{{2.419051, 0.017446, -0.608738, -0.085003},
654
        {2.419102, 0.017452, -0.608703, -0.085026}},
655
       {{2.419043, 0.017445, -0.608744, -0.084999},
656
        {2.419052, 0.017446, -0.608738, -0.085004}},
657
       {{2.419067, 0.017448, -0.608727, -0.085010},
658
        {2.419098, 0.017452, -0.608706, -0.085024}},
659
       {{2.419072, 0.017449, -0.608724, -0.085012},
660
        {2.419119, 0.017455, -0.608691, -0.085034}},
661
       {{2.419019, 0.017442, -0.608761, -0.084989},
662
        {2.419075, 0.017449, -0.608722, -0.085014}}},
663
      tensor_options);
664
  ASSERT_EQ(result.sizes(), ref_output.sizes());
665
  ASSERT_TRUE(
666
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
667

668
  model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 6));
669
  if (is_cuda) {
670
    model->to(torch::kCUDA);
671
  }
672
  result = model(
673
               encoder_input,
674
               /*src_mask=*/torch::Tensor{},
675
               /*src_key_padding_mask=*/mask)
676
               .detach();
677
  ref_output = torch::tensor(
678
      {{{2.419101, 0.017453, -0.608703, -0.085025},
679
        {2.419101, 0.017453, -0.608704, -0.085025}},
680
       {{2.419101, 0.017453, -0.608703, -0.085025},
681
        {2.419101, 0.017453, -0.608704, -0.085025}},
682
       {{2.419101, 0.017453, -0.608703, -0.085025},
683
        {2.419101, 0.017453, -0.608704, -0.085025}},
684
       {{2.419101, 0.017453, -0.608703, -0.085025},
685
        {2.419101, 0.017453, -0.608704, -0.085025}},
686
       {{2.419101, 0.017453, -0.608703, -0.085025},
687
        {2.419101, 0.017453, -0.608704, -0.085025}}},
688
      tensor_options);
689
  ASSERT_EQ(result.sizes(), ref_output.sizes());
690
  ASSERT_TRUE(
691
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
692

693
  // test case 3, multiple layers with norm
694
  LayerNorm norm(LayerNormOptions({encoder_layer.get()->options.d_model()}));
695
  model = TransformerEncoder(
696
      TransformerEncoderOptions(encoder_layer, 2).norm(AnyModule(norm)));
697
  if (is_cuda) {
698
    model->to(torch::kCUDA);
699
  }
700
  result = model(
701
               encoder_input,
702
               /*src_mask=*/torch::Tensor{},
703
               /*src_key_padding_mask=*/mask)
704
               .detach();
705
  ref_output = torch::tensor(
706
      {{{1.695949, -0.357635, -0.893077, -0.445238},
707
        {1.695955, -0.357639, -0.893050, -0.445266}},
708
       {{1.695948, -0.357634, -0.893082, -0.445233},
709
        {1.695950, -0.357635, -0.893077, -0.445238}},
710
       {{1.695951, -0.357636, -0.893069, -0.445246},
711
        {1.695955, -0.357639, -0.893052, -0.445264}},
712
       {{1.695952, -0.357636, -0.893066, -0.445249},
713
        {1.695957, -0.357641, -0.893041, -0.445276}},
714
       {{1.695946, -0.357632, -0.893095, -0.445220},
715
        {1.695952, -0.357637, -0.893065, -0.445251}}},
716
      tensor_options);
717
  ASSERT_EQ(result.sizes(), ref_output.sizes());
718
  ASSERT_TRUE(
719
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
720

721
  model = TransformerEncoder(
722
      TransformerEncoderOptions(encoder_layer, 6).norm(AnyModule(norm)));
723
  if (is_cuda) {
724
    model->to(torch::kCUDA);
725
  }
726
  result = model(
727
               encoder_input,
728
               /*src_mask=*/torch::Tensor{},
729
               /*src_key_padding_mask=*/mask)
730
               .detach();
731
  ref_output = torch::tensor(
732
      {{{1.695955, -0.357639, -0.893051, -0.445265},
733
        {1.695955, -0.357639, -0.893051, -0.445265}},
734
       {{1.695955, -0.357639, -0.893051, -0.445265},
735
        {1.695955, -0.357639, -0.893051, -0.445265}},
736
       {{1.695955, -0.357639, -0.893051, -0.445265},
737
        {1.695955, -0.357639, -0.893051, -0.445265}},
738
       {{1.695955, -0.357639, -0.893051, -0.445265},
739
        {1.695955, -0.357639, -0.893051, -0.445265}},
740
       {{1.695955, -0.357639, -0.893051, -0.445265},
741
        {1.695955, -0.357639, -0.893051, -0.445265}}},
742
      tensor_options);
743
  ASSERT_EQ(result.sizes(), ref_output.sizes());
744
  ASSERT_TRUE(
745
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
746
}
747

748
TEST_F(TransformerTest, TransformerEncoder) {
  // Run both activation variants; the trace tags failures raised inside the
  // shared helper with the variant that produced them.
  for (const bool use_callable_activation : {false, true}) {
    SCOPED_TRACE(
        use_callable_activation ? "callable activation" : "enum activation");
    transformer_encoder_test_helper(
        /*is_cuda=*/false, use_callable_activation);
  }
}
754

755
TEST_F(TransformerTest, TransformerEncoder_CUDA) {
  // Run both activation variants; the trace tags failures raised inside the
  // shared helper with the variant that produced them.
  for (const bool use_callable_activation : {false, true}) {
    SCOPED_TRACE(
        use_callable_activation ? "callable activation" : "enum activation");
    transformer_encoder_test_helper(
        /*is_cuda=*/true, use_callable_activation);
  }
}
761

762
TEST_F(TransformerTest, PrettyPrintTransformerEncoderLayer) {
  // Only d_model and nhead are given, so the printed submodules show the
  // layer's default feed-forward width (2048) and dropout (0.1).
  TransformerEncoderLayer layer(/*d_model=*/4, /*nhead=*/2);

  const std::string expected =
      "torch::nn::TransformerEncoderLayerImpl(\n"
      "  (self_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
      "  (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "  (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      ")";

  ASSERT_EQ(c10::str(layer), expected);
}
778

779
TEST_F(TransformerTest, PrettyPrintTransformerEncoder) {
  // Two default-configured encoder layers (d_model=4, nhead=2) followed by a
  // final LayerNorm over the 4 features.
  LayerNorm final_norm(LayerNormOptions({4}));
  auto encoder_options =
      TransformerEncoderOptions(TransformerEncoderLayerOptions(4, 2), 2)
          .norm(AnyModule(final_norm));
  TransformerEncoder encoder(encoder_options);

  const std::string expected =
      "torch::nn::TransformerEncoderImpl(\n"
      "  (layers): torch::nn::ModuleList(\n"
      "    (0): torch::nn::TransformerEncoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "    (1): torch::nn::TransformerEncoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "  )\n"
      "  (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      ")";

  ASSERT_EQ(c10::str(encoder), expected);
}
816

817
TEST_F(TransformerTest, PrettyPrintTransformerDecoderLayer) {
  // Only d_model and nhead are given, so the printed submodules show the
  // layer's default feed-forward width (2048) and dropout (0.1). Compared to
  // the encoder layer there is an extra cross-attention block, norm and
  // dropout.
  TransformerDecoderLayer layer(/*d_model=*/4, /*nhead=*/2);

  const std::string expected =
      "torch::nn::TransformerDecoderLayerImpl(\n"
      "  (self_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
      "  (multihead_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
      "  (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "  (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
      ")";

  ASSERT_EQ(c10::str(layer), expected);
}
838

839
void transformer_decoder_test_helper(
840
    bool is_cuda,
841
    bool use_callable_activation) {
842
  // this is a deterministic test for TransformerDecoder
843
  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
844
  torch::TensorOptions tensor_options =
845
      torch::TensorOptions().dtype(torch::kFloat32).device(device);
846

847
  TransformerDecoderLayer decoder_layer =
848
      get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
849
          tensor_options, use_callable_activation);
850

851
  TransformerDecoder model(TransformerDecoderOptions(decoder_layer, 1));
852
  if (is_cuda) {
853
    model->to(torch::kCUDA);
854
  }
855

856
  torch::Tensor decoder_input =
857
      torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
858
  torch::Tensor memory_input =
859
      torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
860
  torch::Tensor result = model(decoder_input, memory_input).detach();
861
  torch::Tensor ref_output = torch::tensor(
862
      {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options);
863
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
864
  ASSERT_TRUE(torch::allclose(
865
      result,
866
      ref_output,
867
      1e-7,
868
      1e-5,
869
      /*equal_nan=*/true));
870

871
  // deterministic input
872
  decoder_input =
873
      torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
874
  memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
875
  result = model(decoder_input, memory_input).detach();
876
  ref_output = torch::tensor(
877
      {{{2.422245, 0.051716, -0.606338, -0.024756}},
878
       {{2.422245, 0.051716, -0.606338, -0.024756}}},
879
      tensor_options);
880
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
881
  ASSERT_TRUE(torch::allclose(
882
      result,
883
      ref_output,
884
      1e-7,
885
      1e-5,
886
      /*equal_nan=*/true));
887

888
  // deterministic input
889
  decoder_input =
890
      torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
891
  memory_input =
892
      torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
893
  result = model(decoder_input, memory_input).detach();
894
  ref_output = torch::tensor(
895
      {{{2.343536, 0.085561, -0.654954, 0.074991}},
896
       {{2.343536, 0.085561, -0.654954, 0.074991}}},
897
      tensor_options);
898
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
899
  ASSERT_TRUE(torch::allclose(
900
      result,
901
      ref_output,
902
      1e-7,
903
      1e-5,
904
      /*equal_nan=*/true));
905

906
  // deterministic input
907
  decoder_input = torch::tensor(
908
      {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
909
       {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
910
       {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
911
      tensor_options);
912
  memory_input = torch::tensor(
913
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
914
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
915
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
916
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
917
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
918
      tensor_options);
919
  result = model(decoder_input, memory_input).detach();
920
  ref_output = torch::tensor(
921
      {{{2.430065, 0.027862, -0.601136, -0.073096},
922
        {2.431935, 0.028907, -0.599809, -0.072488}},
923
       {{2.428457, 0.027053, -0.602275, -0.073462},
924
        {2.431970, 0.029387, -0.599789, -0.071621}},
925
       {{2.431934, 0.028196, -0.599802, -0.073809},
926
        {2.432306, 0.028858, -0.599542, -0.072846}}},
927
      tensor_options);
928
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
929
  ASSERT_TRUE(torch::allclose(
930
      result,
931
      ref_output,
932
      1e-7,
933
      1e-5,
934
      /*equal_nan=*/true));
935

936
  // key_padding_mask
937
  torch::Tensor t_mask = {};
938
  torch::Tensor m_mask = {};
939
  torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1;
940
  result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
941
               .detach();
942
  ref_output = torch::tensor(
943
      {{{2.430065, 0.027862, -0.601136, -0.073096},
944
        {2.431935, 0.028907, -0.599809, -0.072488}},
945
       {{2.428457, 0.027053, -0.602275, -0.073462},
946
        {2.431970, 0.029387, -0.599789, -0.071621}},
947
       {{2.431934, 0.028196, -0.599802, -0.073809},
948
        {2.432306, 0.028858, -0.599542, -0.072846}}},
949
      tensor_options);
950
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
951
  ASSERT_TRUE(torch::allclose(
952
      result,
953
      ref_output,
954
      1e-7,
955
      1e-5,
956
      /*equal_nan=*/true));
957

958
  // key_padding_mask
959
  key_padding_mask[0][2] = 1;
960
  key_padding_mask[1][1] = 1;
961
  key_padding_mask[1][2] = 1;
962
  result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
963
               .detach();
964
  ref_output = torch::tensor(
965
      {{{2.430025, 0.027643, -0.601164, -0.073476},
966
        {2.4323, 0.029375, -0.599553, -0.071881}},
967
       {{2.428523, 0.026838, -0.602226, -0.07391},
968
        {2.432634, 0.029842, -0.599318, -0.071253}},
969
       {{2.432278, 0.028152, -0.599555, -0.074139},
970
        {2.432659, 0.029244, -0.599294, -0.072382}}},
971
      tensor_options);
972
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
973
  ASSERT_TRUE(torch::allclose(
974
      result,
975
      ref_output,
976
      1e-7,
977
      1e-5,
978
      /*equal_nan=*/true));
979

980
  // memory_key_padding_mask
981
  torch::Tensor t_key_padding_mask = {};
982
  key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1;
983
  result = model(
984
               decoder_input,
985
               memory_input,
986
               t_mask,
987
               m_mask,
988
               t_key_padding_mask,
989
               key_padding_mask)
990
               .detach();
991
  ref_output = torch::tensor(
992
      {{{2.430065, 0.027862, -0.601136, -0.073096},
993
        {2.431935, 0.028907, -0.599809, -0.072488}},
994
       {{2.428457, 0.027053, -0.602275, -0.073462},
995
        {2.431970, 0.029387, -0.599789, -0.071621}},
996
       {{2.431934, 0.028196, -0.599802, -0.073809},
997
        {2.432306, 0.028858, -0.599542, -0.072846}}},
998
      tensor_options);
999
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1000
  ASSERT_TRUE(torch::allclose(
1001
      result,
1002
      ref_output,
1003
      1e-7,
1004
      1e-5,
1005
      /*equal_nan=*/true));
1006

1007
  // memory_key_padding_mask
1008
  key_padding_mask[0][4] = 1;
1009
  key_padding_mask[1][3] = 1;
1010
  key_padding_mask[1][4] = 1;
1011
  result = model(
1012
               decoder_input,
1013
               memory_input,
1014
               t_mask,
1015
               m_mask,
1016
               t_key_padding_mask,
1017
               key_padding_mask)
1018
               .detach();
1019
  ref_output = torch::tensor(
1020
      {{{2.429757, 0.027358, -0.601351, -0.073816},
1021
        {2.432692, 0.028583, -0.599263, -0.073634}},
1022
       {{2.428247, 0.02662, -0.602419, -0.074123},
1023
        {2.432657, 0.029055, -0.599293, -0.072732}},
1024
       {{2.431515, 0.027687, -0.600096, -0.074459},
1025
        {2.433075, 0.028543, -0.598987, -0.073985}}},
1026
      tensor_options);
1027
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1028
  ASSERT_TRUE(torch::allclose(
1029
      result,
1030
      ref_output,
1031
      1e-7,
1032
      1e-5,
1033
      /*equal_nan=*/true));
1034

1035
  // multiple layers no norm
1036
  model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 2));
1037
  if (is_cuda) {
1038
    model->to(torch::kCUDA);
1039
  }
1040

1041
  decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1042
  memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1043
  result = model(decoder_input, memory_input).detach();
1044
  ref_output = torch::tensor(
1045
      {{{2.31316, 0.0950293, -0.671995, 0.102802}}}, tensor_options);
1046
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1047
  ASSERT_TRUE(torch::allclose(
1048
      result,
1049
      ref_output,
1050
      1e-7,
1051
      1e-5,
1052
      /*equal_nan=*/true));
1053

1054
  // multiple layers no norm
1055
  model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6));
1056
  if (is_cuda) {
1057
    model->to(torch::kCUDA);
1058
  }
1059
  // deterministic input
1060
  decoder_input = torch::tensor(
1061
      {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1062
       {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1063
       {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1064
      tensor_options);
1065
  memory_input = torch::tensor(
1066
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1067
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1068
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1069
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1070
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1071
      tensor_options);
1072
  result = model(decoder_input, memory_input).detach();
1073
  ref_output = torch::tensor(
1074
      {{{2.42794, 0.026164, -0.60263, -0.0747591},
1075
        {2.43113, 0.0279516, -0.600376, -0.0736896}},
1076
       {{2.42794, 0.026164, -0.60263, -0.0747591},
1077
        {2.43113, 0.0279516, -0.600376, -0.0736896}},
1078
       {{2.42794, 0.026164, -0.60263, -0.0747591},
1079
        {2.43113, 0.0279516, -0.600376, -0.0736896}}},
1080
      tensor_options);
1081
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1082
  ASSERT_TRUE(torch::allclose(
1083
      result,
1084
      ref_output,
1085
      1e-7,
1086
      1e-5,
1087
      /*equal_nan=*/true));
1088

1089
  // multiple layers with norm
1090
  LayerNorm norm(LayerNormOptions({decoder_layer.get()->options.d_model()}));
1091
  model = TransformerDecoder(
1092
      TransformerDecoderOptions(decoder_layer, 2).norm(AnyModule(norm)));
1093
  if (is_cuda) {
1094
    model->to(torch::kCUDA);
1095
  }
1096

1097
  decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1098
  memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1099
  result = model(decoder_input, memory_input).detach();
1100
  ref_output = torch::tensor(
1101
      {{{1.66166, -0.326986, -1.01466, -0.320017}}}, tensor_options);
1102
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1103
  ASSERT_TRUE(torch::allclose(
1104
      result,
1105
      ref_output,
1106
      1e-7,
1107
      1e-5,
1108
      /*equal_nan=*/true));
1109

1110
  // multiple layers with norm
1111
  model = TransformerDecoder(
1112
      TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm)));
1113
  if (is_cuda) {
1114
    model->to(torch::kCUDA);
1115
  }
1116
  // deterministic input
1117
  decoder_input = torch::tensor(
1118
      {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1119
       {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1120
       {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1121
      tensor_options);
1122
  memory_input = torch::tensor(
1123
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1124
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1125
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1126
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1127
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1128
      tensor_options);
1129
  result = model(decoder_input, memory_input).detach();
1130
  ref_output = torch::tensor(
1131
      {{{1.69559, -0.357291, -0.894741, -0.443553},
1132
        {1.69571, -0.357363, -0.894154, -0.444196}},
1133
       {{1.69559, -0.357291, -0.894741, -0.443553},
1134
        {1.69571, -0.357363, -0.894154, -0.444196}},
1135
       {{1.69559, -0.357291, -0.894741, -0.443553},
1136
        {1.69571, -0.357363, -0.894154, -0.444196}}},
1137
      tensor_options);
1138
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1139
  ASSERT_TRUE(torch::allclose(
1140
      result,
1141
      ref_output,
1142
      1e-7,
1143
      1e-5,
1144
      /*equal_nan=*/true));
1145

1146
  // gelu activation test cases
1147
  decoder_layer.get()->options.activation(torch::kGELU);
1148
  model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 1));
1149
  if (is_cuda) {
1150
    model->to(torch::kCUDA);
1151
  }
1152

1153
  // deterministic input
1154
  decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1155
  memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1156
  result = model(decoder_input, memory_input).detach();
1157
  ref_output = torch::tensor(
1158
      {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options);
1159
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1160
  ASSERT_TRUE(torch::allclose(
1161
      result,
1162
      ref_output,
1163
      1e-7,
1164
      1e-5,
1165
      /*equal_nan=*/true));
1166

1167
  // deterministic input
1168
  decoder_input =
1169
      torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
1170
  memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
1171
  result = model(decoder_input, memory_input).detach();
1172
  ref_output = torch::tensor(
1173
      {{{2.415448, 0.054389, -0.610932, -0.0156613}},
1174
       {{2.415448, 0.054389, -0.610932, -0.0156613}}},
1175
      tensor_options);
1176
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1177
  ASSERT_TRUE(torch::allclose(
1178
      result,
1179
      ref_output,
1180
      1e-7,
1181
      1e-5,
1182
      /*equal_nan=*/true));
1183

1184
  // deterministic input
1185
  decoder_input =
1186
      torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
1187
  memory_input =
1188
      torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
1189
  result = model(decoder_input, memory_input).detach();
1190
  ref_output = torch::tensor(
1191
      {{{2.338531, 0.087709, -0.65776, 0.080646}},
1192
       {{2.338531, 0.087709, -0.65776, 0.080646}}},
1193
      tensor_options);
1194
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1195
  ASSERT_TRUE(torch::allclose(
1196
      result,
1197
      ref_output,
1198
      1e-7,
1199
      1e-5,
1200
      /*equal_nan=*/true));
1201

1202
  // deterministic input
1203
  decoder_input = torch::tensor(
1204
      {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1205
       {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1206
       {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1207
      tensor_options);
1208
  memory_input = torch::tensor(
1209
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1210
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1211
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1212
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1213
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1214
      tensor_options);
1215
  result = model(decoder_input, memory_input).detach();
1216
  ref_output = torch::tensor(
1217
      {{{2.42049104, 0.03443088, -0.60793706, -0.05436271},
1218
        {2.42210631, 0.03546578, -0.60679895, -0.05357488}},
1219
       {{2.41907674, 0.0336104, -0.60892977, -0.05490462},
1220
        {2.42216881, 0.03586554, -0.6067524, -0.05289126}},
1221
       {{2.42205716, 0.03488046, -0.60683681, -0.05460596},
1222
        {2.42240309, 0.0354595, -0.60659063, -0.05378816}}},
1223
      tensor_options);
1224
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1225
  ASSERT_TRUE(torch::allclose(
1226
      result,
1227
      ref_output,
1228
      1e-7,
1229
      1e-5,
1230
      /*equal_nan=*/true));
1231

1232
  // Multiple layers no norm
1233
  model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6));
1234
  if (is_cuda) {
1235
    model->to(torch::kCUDA);
1236
  }
1237
  decoder_input = torch::tensor(
1238
      {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1239
       {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1240
       {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1241
      tensor_options);
1242
  memory_input = torch::tensor(
1243
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1244
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1245
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1246
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1247
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1248
      tensor_options);
1249
  result = model(decoder_input, memory_input).detach();
1250
  ref_output = torch::tensor(
1251
      {{{2.41859, 0.0328114, -0.609269, -0.0560386},
1252
        {2.42138, 0.034598, -0.607316, -0.0546574}},
1253
       {{2.41859, 0.0328114, -0.609269, -0.0560386},
1254
        {2.42138, 0.034598, -0.607316, -0.0546574}},
1255
       {{2.41859, 0.0328114, -0.609269, -0.0560386},
1256
        {2.42138, 0.034598, -0.607316, -0.0546574}}},
1257
      tensor_options);
1258
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1259
  ASSERT_TRUE(torch::allclose(
1260
      result,
1261
      ref_output,
1262
      1e-7,
1263
      1e-5,
1264
      /*equal_nan=*/true));
1265

1266
  // Multiple layers with norm
1267
  norm = LayerNorm(LayerNormOptions({decoder_layer.get()->options.d_model()}));
1268
  model = TransformerDecoder(
1269
      TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm)));
1270
  if (is_cuda) {
1271
    model->to(torch::kCUDA);
1272
  }
1273

1274
  decoder_input = torch::tensor(
1275
      {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1276
       {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1277
       {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1278
      tensor_options);
1279
  memory_input = torch::tensor(
1280
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1281
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1282
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1283
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1284
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1285
      tensor_options);
1286
  result = model(decoder_input, memory_input).detach();
1287
  ref_output = torch::tensor(
1288
      {{{1.69298, -0.355163, -0.906375, -0.431439},
1289
        {1.69305, -0.355195, -0.906062, -0.431791}},
1290
       {{1.69298, -0.355163, -0.906375, -0.431439},
1291
        {1.69305, -0.355195, -0.906062, -0.431791}},
1292
       {{1.69298, -0.355163, -0.906375, -0.431439},
1293
        {1.69305, -0.355195, -0.906062, -0.431791}}},
1294
      tensor_options);
1295
  ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1296
  ASSERT_TRUE(torch::allclose(
1297
      result,
1298
      ref_output,
1299
      1e-7,
1300
      1e-5,
1301
      /*equal_nan=*/true));
1302
}
1303

1304
TEST_F(TransformerTest, TransformerDecoder) {
  // CPU run: first with the enum ReLU activation, then with the equivalent
  // callable activation. A fatal failure inside the helper only ends that
  // helper call, so both variants are always attempted.
  for (const bool use_callable_activation : {false, true}) {
    transformer_decoder_test_helper(
        /*is_cuda=*/false, use_callable_activation);
  }
}
1310

1311
TEST_F(TransformerTest, TransformerDecoder_CUDA) {
  // CUDA run: first with the enum ReLU activation, then with the equivalent
  // callable activation.
  for (const bool use_callable_activation : {false, true}) {
    transformer_decoder_test_helper(
        /*is_cuda=*/true, use_callable_activation);
  }
}
1317

1318
TEST_F(TransformerTest, PrettyPrintTransformerDecoder) {
  // Two default-configured decoder layers (d_model=4, nhead=2) followed by a
  // final LayerNorm over the 4 features.
  LayerNorm final_norm(LayerNormOptions({4}));
  auto decoder_options =
      TransformerDecoderOptions(TransformerDecoderLayerOptions(4, 2), 2)
          .norm(AnyModule(final_norm));
  TransformerDecoder decoder(decoder_options);

  const std::string expected =
      "torch::nn::TransformerDecoderImpl(\n"
      "  (layers): torch::nn::ModuleList(\n"
      "    (0): torch::nn::TransformerDecoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (multihead_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "    (1): torch::nn::TransformerDecoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (multihead_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "  )\n"
      "  (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      ")";

  ASSERT_EQ(c10::str(decoder), expected);
}
1365

1366
void transformer_test_helper(bool is_cuda, bool use_callable_activation) {
1367
  // this is a deterministic test for Transformere
1368
  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
1369
  torch::TensorOptions tensor_options =
1370
      torch::TensorOptions().dtype(torch::kFloat32).device(device);
1371

1372
  // transformer created encoder/decoder
1373
  auto options = TransformerOptions()
1374
                     .d_model(4)
1375
                     .nhead(2)
1376
                     .num_encoder_layers(2)
1377
                     .num_decoder_layers(1)
1378
                     .dim_feedforward(16)
1379
                     .dropout(0.0)
1380
                     .activation(torch::kReLU);
1381
  if (use_callable_activation) {
1382
    options.activation(
1383
        [&](const torch::Tensor& t) { return torch::nn::functional::relu(t); });
1384
  }
1385
  Transformer model(options);
1386

1387
  set_parameter_to_constants<Transformer>(model, tensor_options);
1388
  if (tensor_options.device() == torch::kCUDA) {
1389
    model->to(torch::kCUDA);
1390
  }
1391

1392
  // transformer with customized encoder/decoder
1393
  LayerNorm enorm(LayerNormOptions({4}));
1394
  TransformerEncoder encoder(
1395
      TransformerEncoderOptions(
1396
          TransformerEncoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0),
1397
          2)
1398
          .norm(AnyModule(enorm)));
1399

1400
  LayerNorm dnorm(LayerNormOptions({4}));
1401
  TransformerDecoder decoder(
1402
      TransformerDecoderOptions(
1403
          TransformerDecoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0),
1404
          1)
1405
          .norm(AnyModule(dnorm)));
1406

1407
  Transformer model_cus(TransformerOptions()
1408
                            .d_model(4)
1409
                            .nhead(2)
1410
                            .custom_encoder(AnyModule(encoder))
1411
                            .custom_decoder(AnyModule(decoder)));
1412

1413
  set_parameter_to_constants<Transformer>(model_cus, tensor_options);
1414
  if (tensor_options.device() == torch::kCUDA) {
1415
    model_cus->to(torch::kCUDA);
1416
  }
1417

1418
  // test cases
1419
  torch::Tensor src = torch::tensor(
1420
      {{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
1421
       {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}},
1422
       {{17.0, 18.0, 19.0, 20.0}, {21.0, 22.0, 23.0, 24.0}}},
1423
      tensor_options);
1424

1425
  torch::Tensor tgt = torch::tensor(
1426
      {{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
1427
       {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}},
1428
      tensor_options);
1429

1430
  torch::Tensor ref_output = torch::tensor(
1431
      {{{2.695875, 0.347114, -0.044355, -0.549541},
1432
        {2.696091, 0.347015, -0.044770, -0.548522}},
1433
       {{2.695875, 0.347114, -0.044355, -0.549541},
1434
        {2.696091, 0.347015, -0.044770, -0.548522}}},
1435
      tensor_options);
1436
  torch::Tensor result = model(src, tgt);
1437
  torch::Tensor result_cus = model_cus(src, tgt);
1438
  ASSERT_EQ(result.sizes(), ref_output.sizes());
1439
  ASSERT_TRUE(result.equal(result_cus));
1440
  ASSERT_TRUE(
1441
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
1442

1443
  torch::Tensor src_mask =
1444
      Transformer::Impl::generate_square_subsequent_mask(src.size(0))
1445
          .to(tensor_options);
1446
  ref_output = torch::tensor(
1447
      {{{2.695875, 0.347114, -0.044355, -0.549541},
1448
        {2.696091, 0.347015, -0.044770, -0.548522}},
1449
       {{2.695875, 0.347114, -0.044355, -0.549541},
1450
        {2.696091, 0.347015, -0.044770, -0.548522}}},
1451
      tensor_options);
1452
  result = model(src, tgt, src_mask);
1453
  result_cus = model_cus(src, tgt, src_mask);
1454
  ASSERT_EQ(result.sizes(), ref_output.sizes());
1455
  ASSERT_TRUE(result.equal(result_cus));
1456
  ASSERT_TRUE(
1457
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
1458

1459
  torch::Tensor tgt_key_padding_mask =
1460
      torch::zeros({tgt.size(1), tgt.size(0)}, tensor_options) == 1;
1461
  tgt_key_padding_mask[0][0] = 1;
1462
  tgt_key_padding_mask[1][1] = 1;
1463
  ref_output = torch::tensor(
1464
      {{{2.696114, 0.347004, -0.044813, -0.548417},
1465
        {2.696091, 0.347015, -0.044770, -0.548522}},
1466
       {{2.696114, 0.347004, -0.044813, -0.548417},
1467
        {2.696091, 0.347015, -0.044770, -0.548522}}},
1468
      tensor_options);
1469
  result = model(
1470
      src,
1471
      tgt,
1472
      src_mask,
1473
      torch::Tensor(),
1474
      torch::Tensor(),
1475
      torch::Tensor(),
1476
      tgt_key_padding_mask);
1477
  result_cus = model_cus(
1478
      src,
1479
      tgt,
1480
      src_mask,
1481
      torch::Tensor(),
1482
      torch::Tensor(),
1483
      torch::Tensor(),
1484
      tgt_key_padding_mask);
1485
  ASSERT_EQ(result.sizes(), ref_output.sizes());
1486
  ASSERT_TRUE(result.equal(result_cus));
1487
  ASSERT_TRUE(
1488
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
1489
}
1490

1491
TEST_F(TransformerTest, Transformer) {
  // Run the deterministic Transformer check on CPU. The first pass uses the
  // kReLU enum activation and the second a callable relu activation.
  for (const bool callable_activation : {false, true}) {
    transformer_test_helper(
        /*is_cuda=*/false, /*use_callable_activation=*/callable_activation);
  }
}
1495

1496
TEST_F(TransformerTest, Transformer_CUDA) {
  // Run the deterministic Transformer check on CUDA. The first pass uses the
  // kReLU enum activation and the second a callable relu activation.
  for (const bool callable_activation : {false, true}) {
    transformer_test_helper(
        /*is_cuda=*/true, /*use_callable_activation=*/callable_activation);
  }
}
1500

1501
TEST_F(TransformerTest, TransformerArgsCorrectness) {
  // Verifies that calling the Transformer with mis-shaped src/tgt tensors
  // throws with a descriptive message. The model is configured with d_model=4.
  auto opts = TransformerOptions()
                  .d_model(4)
                  .nhead(2)
                  .num_encoder_layers(2)
                  .num_decoder_layers(1)
                  .dim_feedforward(16)
                  .dropout(0.0)
                  .activation(torch::kReLU);
  Transformer model(opts);

  // src is {2, 3, 4} and tgt is {3, 2, 4}: the batch sizes disagree.
  torch::Tensor src = torch::randn({2, 3, 4});
  torch::Tensor tgt_batch_mismatch = torch::randn({3, 2, 4});
  ASSERT_THROWS_WITH(
      model(src, tgt_batch_mismatch),
      "src and tgt should have equal batch size");

  // tgt is {2, 3, 3}: its last dimension (3) differs from d_model (4).
  torch::Tensor tgt_feature_mismatch = torch::randn({2, 3, 3});
  ASSERT_THROWS_WITH(
      model(src, tgt_feature_mismatch),
      "src and tgt should have same feature size as d_model");

  // src has only two dimensions.
  torch::Tensor src_2d = torch::randn({2, 3});
  ASSERT_THROWS_WITH(
      model(src_2d, tgt_feature_mismatch),
      "src and tgt should have 3 dimensions");
}
1524

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.