1
#include <gtest/gtest.h>
3
#include <torch/torch.h>
5
#include <test/cpp/api/support.h>
7
using namespace torch::nn;
9
// Test fixture for all transformer tests in this file. SeedingFixture
// presumably seeds the RNG per test so the hard-coded reference outputs
// below are reproducible (the tests describe themselves as deterministic).
struct TransformerTest : torch::test::SeedingFixture {};
11
// a generic function to set constants for parameters so we have fixed result
12
// for deterministic test
13
template <typename Model>
14
void set_parameter_to_constants(
16
const torch::TensorOptions& tensor_options) {
17
torch::NoGradGuard guard;
18
for (auto& p : model->parameters()) {
19
auto sz = p.view(-1).size(0);
20
p.copy_(torch::cos(torch::arange(0, sz, tensor_options).view(p.sizes())));
24
// a generic function to provide consistent encoder/decoder layer for all the
26
template <typename T_LAYER, typename T_OPTIONS>
27
T_LAYER get_a_test_layer(
28
const torch::TensorOptions& tensor_options,
29
bool use_callable_activation) {
32
int64_t dim_feedforward = 16;
35
// activation is always ReLU here and it can be adjusted later depending on
37
T_LAYER layer(T_OPTIONS(d_model, nhead)
38
.dim_feedforward(dim_feedforward)
40
if (tensor_options.device() == torch::kCUDA) {
41
layer->to(torch::kCUDA);
43
if (use_callable_activation) {
44
layer.get()->options.activation(
45
[&](const torch::Tensor& t) { return torch::nn::functional::relu(t); });
48
// set constant weights of the model
49
set_parameter_to_constants<T_LAYER>(layer, tensor_options);
54
void transformer_encoder_layer_test_helper(
56
bool use_callable_activation) {
57
// this is a deterministic test for TransformerEncoderLayer
58
torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
59
torch::TensorOptions tensor_options =
60
torch::TensorOptions().dtype(torch::kFloat32).device(device);
62
TransformerEncoderLayer model =
63
get_a_test_layer<TransformerEncoderLayer, TransformerEncoderLayerOptions>(
64
tensor_options, use_callable_activation);
67
torch::Tensor encoder_input =
68
torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
69
torch::Tensor result = model(encoder_input).detach();
70
torch::Tensor ref_output = torch::tensor(
71
{{{2.258703, 0.127985, -0.697881, 0.170862}}}, tensor_options);
72
ASSERT_EQ(result.sizes(), ref_output.sizes());
74
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
76
// all 0 values are NOT masked. This shouldn't mask anything
77
torch::Tensor mask = torch::tensor({{0}}, tensor_options) == 1;
80
/*src_mask=*/torch::Tensor{},
81
/*src_key_padding_mask=*/mask)
83
ASSERT_EQ(result.sizes(), ref_output.sizes());
85
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
87
// all 1 values are masked. Since there is only 1 input embedding this will
89
mask = torch::tensor({{1}}, tensor_options) == 1;
92
/*src_mask=*/torch::Tensor{},
93
/*src_key_padding_mask=*/mask)
95
ASSERT_TRUE(torch::isnan(result).all().item().to<bool>());
99
torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
100
result = model(encoder_input).detach();
101
ref_output = torch::tensor(
102
{{{2.272644, 0.119035, -0.691669, 0.153486}},
103
{{2.272644, 0.119035, -0.691669, 0.153486}}},
105
ASSERT_EQ(result.sizes(), ref_output.sizes());
107
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
109
// all 0 values are NOT masked
110
mask = torch::tensor({{0, 0}}, tensor_options) == 1;
113
/*src_mask=*/torch::Tensor{},
114
/*src_key_padding_mask=*/mask)
116
ASSERT_EQ(result.sizes(), ref_output.sizes());
118
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
121
mask = torch::tensor({{1, 0}}, tensor_options) == 1;
124
/*src_mask=*/torch::Tensor{},
125
/*src_key_padding_mask=*/mask)
127
ref_output = torch::tensor(
128
{{{2.301516, 0.092249, -0.679101, 0.103088}},
129
{{2.301516, 0.092249, -0.679101, 0.103088}}},
131
ASSERT_EQ(result.sizes(), ref_output.sizes());
133
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
136
encoder_input = torch::tensor(
137
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
138
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
139
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
140
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
141
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
143
result = model(encoder_input).detach();
144
ref_output = torch::tensor(
145
{{{2.428589, 0.020835, -0.602055, -0.085249},
146
{2.427987, 0.021213, -0.602496, -0.084103}},
147
{{2.424689, 0.019155, -0.604793, -0.085672},
148
{2.413863, 0.022211, -0.612486, -0.072490}},
149
{{2.433774, 0.021598, -0.598343, -0.087548},
150
{2.425104, 0.019748, -0.604515, -0.084839}},
151
{{2.436185, 0.022682, -0.596625, -0.087261},
152
{2.433556, 0.021891, -0.598509, -0.086832}},
153
{{2.416246, 0.017512, -0.610712, -0.082961},
154
{2.422901, 0.024187, -0.606178, -0.074929}}},
156
ASSERT_EQ(result.sizes(), ref_output.sizes());
158
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
160
// all 0 values are NOT masked
161
mask = torch::zeros({2, 5}, tensor_options) == 1;
164
/*src_mask=*/torch::Tensor{},
165
/*src_key_padding_mask=*/mask)
167
ASSERT_EQ(result.sizes(), ref_output.sizes());
169
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
171
// mask with 0s and 1s
177
/*src_mask=*/torch::Tensor{},
178
/*src_key_padding_mask=*/mask)
180
ref_output = torch::tensor(
181
{{{2.429026, 0.020793, -0.601741, -0.085642},
182
{2.428811, 0.021445, -0.601912, -0.084252}},
183
{{2.425009, 0.019155, -0.604566, -0.085899},
184
{2.415408, 0.02249, -0.611415, -0.073}},
185
{{2.434199, 0.021682, -0.598039, -0.087699},
186
{2.42598, 0.019941, -0.603896, -0.085091}},
187
{{2.436457, 0.022736, -0.59643, -0.08736},
188
{2.434021, 0.022093, -0.598179, -0.08679}},
189
{{2.416531, 0.017498, -0.610513, -0.083181},
190
{2.4242, 0.024653, -0.605266, -0.074959}}},
192
ASSERT_EQ(result.sizes(), ref_output.sizes());
194
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
197
model.get()->options.activation(torch::kGELU);
198
encoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
199
result = model(encoder_input).detach();
200
ref_output = torch::tensor(
201
{{{2.249815, 0.131006, -0.702199, 0.177868}}}, tensor_options);
202
ASSERT_EQ(result.sizes(), ref_output.sizes());
204
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
207
encoder_input = torch::tensor(
208
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
209
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
210
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
211
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
212
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
214
result = model(encoder_input);
215
ref_output = torch::tensor(
216
{{{2.42163188, 0.03227153, -0.60714219, -0.05908082},
217
{2.42151276, 0.03302179, -0.60722523, -0.05762651}},
218
{{2.41926761, 0.02974034, -0.60879519, -0.0621269},
219
{2.41626395, 0.03539356, -0.61087842, -0.04978623}},
220
{{2.42382808, 0.03218872, -0.6055963, -0.06073591},
221
{2.41983477, 0.03085259, -0.60840145, -0.06046414}},
222
{{2.42500749, 0.03328855, -0.60476388, -0.0595334},
223
{2.4237977, 0.03290575, -0.60561789, -0.05940082}},
224
{{2.41383916, 0.02686345, -0.61256377, -0.06380707},
225
{2.42000277, 0.03800944, -0.60824798, -0.04754947}}},
227
ASSERT_EQ(result.sizes(), ref_output.sizes());
229
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
232
TEST_F(TransformerTest, TransformerEncoderLayer) {
233
transformer_encoder_layer_test_helper(
234
/*is_cuda=*/false, /*use_callable_activation=*/false);
235
transformer_encoder_layer_test_helper(
236
/*is_cuda=*/false, /*use_callable_activation=*/true);
239
TEST_F(TransformerTest, TransformerEncoderLayer_CUDA) {
240
transformer_encoder_layer_test_helper(
241
/*is_cuda=*/true, /*use_callable_activation=*/false);
242
transformer_encoder_layer_test_helper(
243
/*is_cuda=*/true, /*use_callable_activation=*/true);
246
void transformer_decoder_layer_test_helper(
248
bool use_callable_activation) {
249
torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
250
torch::TensorOptions tensor_options =
251
torch::TensorOptions().dtype(torch::kFloat32).device(device);
253
TransformerDecoderLayer model =
254
get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
255
tensor_options, use_callable_activation);
257
// deterministic input
258
torch::Tensor decoder_input =
259
torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
260
torch::Tensor memory_input =
261
torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
262
torch::Tensor result = model(decoder_input, memory_input).detach();
263
torch::Tensor ref_output = torch::tensor(
264
{{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options);
265
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
266
ASSERT_TRUE(torch::allclose(
271
/*equal_nan=*/true));
273
// deterministic input
275
torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
276
memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
277
result = model(decoder_input, memory_input).detach();
278
ref_output = torch::tensor(
279
{{{2.422245, 0.051716, -0.606338, -0.024756}},
280
{{2.422245, 0.051716, -0.606338, -0.024756}}},
282
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
283
ASSERT_TRUE(torch::allclose(
288
/*equal_nan=*/true));
290
// deterministic input
292
torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
294
torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
295
result = model(decoder_input, memory_input).detach();
296
ref_output = torch::tensor(
297
{{{2.343536, 0.085561, -0.654954, 0.074991}},
298
{{2.343536, 0.085561, -0.654954, 0.074991}}},
300
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
301
ASSERT_TRUE(torch::allclose(
306
/*equal_nan=*/true));
308
// deterministic input
309
decoder_input = torch::tensor(
310
{{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
311
{{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
312
{{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
314
memory_input = torch::tensor(
315
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
316
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
317
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
318
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
319
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
321
result = model(decoder_input, memory_input).detach();
322
ref_output = torch::tensor(
323
{{{2.430065, 0.027862, -0.601136, -0.073096},
324
{2.431935, 0.028907, -0.599809, -0.072488}},
325
{{2.428457, 0.027053, -0.602275, -0.073462},
326
{2.431970, 0.029387, -0.599789, -0.071621}},
327
{{2.431934, 0.028196, -0.599802, -0.073809},
328
{2.432306, 0.028858, -0.599542, -0.072846}}},
330
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
331
ASSERT_TRUE(torch::allclose(
336
/*equal_nan=*/true));
339
torch::Tensor t_mask = {};
340
torch::Tensor m_mask = {};
341
torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1;
342
result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
344
ref_output = torch::tensor(
345
{{{2.430065, 0.027862, -0.601136, -0.073096},
346
{2.431935, 0.028907, -0.599809, -0.072488}},
347
{{2.428457, 0.027053, -0.602275, -0.073462},
348
{2.431970, 0.029387, -0.599789, -0.071621}},
349
{{2.431934, 0.028196, -0.599802, -0.073809},
350
{2.432306, 0.028858, -0.599542, -0.072846}}},
352
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
353
ASSERT_TRUE(torch::allclose(
358
/*equal_nan=*/true));
361
key_padding_mask[0][2] = 1;
362
key_padding_mask[1][1] = 1;
363
key_padding_mask[1][2] = 1;
364
result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
366
ref_output = torch::tensor(
367
{{{2.430025, 0.027643, -0.601164, -0.073476},
368
{2.4323, 0.029375, -0.599553, -0.071881}},
369
{{2.428523, 0.026838, -0.602226, -0.07391},
370
{2.432634, 0.029842, -0.599318, -0.071253}},
371
{{2.432278, 0.028152, -0.599555, -0.074139},
372
{2.432659, 0.029244, -0.599294, -0.072382}}},
374
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
375
ASSERT_TRUE(torch::allclose(
380
/*equal_nan=*/true));
382
// memory_key_padding_mask
383
torch::Tensor t_key_padding_mask = {};
384
key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1;
393
ref_output = torch::tensor(
394
{{{2.430065, 0.027862, -0.601136, -0.073096},
395
{2.431935, 0.028907, -0.599809, -0.072488}},
396
{{2.428457, 0.027053, -0.602275, -0.073462},
397
{2.431970, 0.029387, -0.599789, -0.071621}},
398
{{2.431934, 0.028196, -0.599802, -0.073809},
399
{2.432306, 0.028858, -0.599542, -0.072846}}},
401
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
402
ASSERT_TRUE(torch::allclose(
407
/*equal_nan=*/true));
409
// memory_key_padding_mask
410
key_padding_mask[0][4] = 1;
411
key_padding_mask[1][3] = 1;
412
key_padding_mask[1][4] = 1;
421
ref_output = torch::tensor(
422
{{{2.429757, 0.027358, -0.601351, -0.073816},
423
{2.432692, 0.028583, -0.599263, -0.073634}},
424
{{2.428247, 0.02662, -0.602419, -0.074123},
425
{2.432657, 0.029055, -0.599293, -0.072732}},
426
{{2.431515, 0.027687, -0.600096, -0.074459},
427
{2.433075, 0.028543, -0.598987, -0.073985}}},
429
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
430
ASSERT_TRUE(torch::allclose(
435
/*equal_nan=*/true));
438
TEST_F(TransformerTest, TransformerDecoderLayer) {
439
transformer_decoder_layer_test_helper(
440
/*is_cuda=*/false, /*use_callable_activation=*/false);
441
transformer_decoder_layer_test_helper(
442
/*is_cuda=*/false, /*use_callable_activation=*/true);
445
TEST_F(TransformerTest, TransformerDecoderLayer_CUDA) {
446
transformer_decoder_layer_test_helper(
447
/*is_cuda=*/true, /*use_callable_activation=*/false);
448
transformer_decoder_layer_test_helper(
449
/*is_cuda=*/true, /*use_callable_activation=*/true);
452
void transformer_decoder_layer_test_helper_gelu(
454
bool use_callable_activation) {
455
torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
456
torch::TensorOptions tensor_options =
457
torch::TensorOptions().dtype(torch::kFloat32).device(device);
459
TransformerDecoderLayer model =
460
get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
461
tensor_options, use_callable_activation);
462
if (use_callable_activation) {
463
model.get()->options.activation(
464
[&](const torch::Tensor& t) { return torch::nn::functional::gelu(t); });
466
model.get()->options.activation(torch::kGELU);
469
// deterministic input
470
torch::Tensor decoder_input =
471
torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
472
torch::Tensor memory_input =
473
torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
474
torch::Tensor result = model(decoder_input, memory_input).detach();
475
torch::Tensor ref_output = torch::tensor(
476
{{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options);
477
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
478
ASSERT_TRUE(torch::allclose(
483
/*equal_nan=*/true));
485
// deterministic input
487
torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
488
memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
489
result = model(decoder_input, memory_input).detach();
490
ref_output = torch::tensor(
491
{{{2.415448, 0.054389, -0.610932, -0.0156613}},
492
{{2.415448, 0.054389, -0.610932, -0.0156613}}},
494
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
495
ASSERT_TRUE(torch::allclose(
500
/*equal_nan=*/true));
502
// deterministic input
504
torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
506
torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
507
result = model(decoder_input, memory_input).detach();
508
ref_output = torch::tensor(
509
{{{2.338531, 0.087709, -0.65776, 0.080646}},
510
{{2.338531, 0.087709, -0.65776, 0.080646}}},
512
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
513
ASSERT_TRUE(torch::allclose(
518
/*equal_nan=*/true));
520
// deterministic input
521
decoder_input = torch::tensor(
522
{{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
523
{{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
524
{{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
526
memory_input = torch::tensor(
527
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
528
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
529
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
530
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
531
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
533
result = model(decoder_input, memory_input).detach();
534
ref_output = torch::tensor(
535
{{{2.42049104, 0.03443088, -0.60793706, -0.05436271},
536
{2.42210631, 0.03546578, -0.60679895, -0.05357488}},
537
{{2.41907674, 0.0336104, -0.60892977, -0.05490462},
538
{2.42216881, 0.03586554, -0.6067524, -0.05289126}},
539
{{2.42205716, 0.03488046, -0.60683681, -0.05460596},
540
{2.42240309, 0.0354595, -0.60659063, -0.05378816}}},
542
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
543
ASSERT_TRUE(torch::allclose(
548
/*equal_nan=*/true));
551
TEST_F(TransformerTest, TransformerDecoderLayer_gelu) {
552
transformer_decoder_layer_test_helper_gelu(
553
/*is_cuda=*/false, /*use_callable_activation=*/false);
554
transformer_decoder_layer_test_helper_gelu(
555
/*is_cuda=*/false, /*use_callable_activation=*/true);
558
TEST_F(TransformerTest, TransformerDecoderLayer_gelu_CUDA) {
559
transformer_decoder_layer_test_helper_gelu(
560
/*is_cuda=*/true, /*use_callable_activation=*/false);
561
transformer_decoder_layer_test_helper_gelu(
562
/*is_cuda=*/true, /*use_callable_activation=*/true);
565
void transformer_encoder_test_helper(
567
bool use_callable_activation) {
568
// this is a deterministic test for TransformerEncoderLayer
569
torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
570
torch::TensorOptions tensor_options =
571
torch::TensorOptions().dtype(torch::kFloat32).device(device);
573
TransformerEncoderLayer encoder_layer =
574
get_a_test_layer<TransformerEncoderLayer, TransformerEncoderLayerOptions>(
575
tensor_options, use_callable_activation);
577
TransformerEncoder model(TransformerEncoderOptions(encoder_layer, 1));
579
model->to(torch::kCUDA);
582
torch::Tensor encoder_input = torch::tensor(
583
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
584
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
585
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
586
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
587
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
589
torch::Tensor result = model(encoder_input).detach();
590
torch::Tensor ref_output = torch::tensor(
591
{{{2.428589, 0.020835, -0.602055, -0.085249},
592
{2.427987, 0.021213, -0.602496, -0.084103}},
593
{{2.424689, 0.019155, -0.604793, -0.085672},
594
{2.413863, 0.022211, -0.612486, -0.072490}},
595
{{2.433774, 0.021598, -0.598343, -0.087548},
596
{2.425104, 0.019748, -0.604515, -0.084839}},
597
{{2.436185, 0.022682, -0.596625, -0.087261},
598
{2.433556, 0.021891, -0.598509, -0.086832}},
599
{{2.416246, 0.017512, -0.610712, -0.082961},
600
{2.422901, 0.024187, -0.606178, -0.074929}}},
602
ASSERT_EQ(result.sizes(), ref_output.sizes());
604
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
606
// all 0 values are NOT masked
607
torch::Tensor mask = torch::zeros({2, 5}, tensor_options) == 1;
610
/*src_mask=*/torch::Tensor{},
611
/*src_key_padding_mask=*/mask)
613
ASSERT_EQ(result.sizes(), ref_output.sizes());
615
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
617
// mask with 0s and 1s
623
/*src_mask=*/torch::Tensor{},
624
/*src_key_padding_mask=*/mask)
626
ref_output = torch::tensor(
627
{{{2.429026, 0.020793, -0.601741, -0.085642},
628
{2.428811, 0.021445, -0.601912, -0.084252}},
629
{{2.425009, 0.019155, -0.604566, -0.085899},
630
{2.415408, 0.02249, -0.611415, -0.073}},
631
{{2.434199, 0.021682, -0.598039, -0.087699},
632
{2.42598, 0.019941, -0.603896, -0.085091}},
633
{{2.436457, 0.022736, -0.59643, -0.08736},
634
{2.434021, 0.022093, -0.598179, -0.08679}},
635
{{2.416531, 0.017498, -0.610513, -0.083181},
636
{2.4242, 0.024653, -0.605266, -0.074959}}},
638
ASSERT_EQ(result.sizes(), ref_output.sizes());
640
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
642
// test case 2, multiple layers no norm
643
model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 2));
645
model->to(torch::kCUDA);
649
/*src_mask=*/torch::Tensor{},
650
/*src_key_padding_mask=*/mask)
652
ref_output = torch::tensor(
653
{{{2.419051, 0.017446, -0.608738, -0.085003},
654
{2.419102, 0.017452, -0.608703, -0.085026}},
655
{{2.419043, 0.017445, -0.608744, -0.084999},
656
{2.419052, 0.017446, -0.608738, -0.085004}},
657
{{2.419067, 0.017448, -0.608727, -0.085010},
658
{2.419098, 0.017452, -0.608706, -0.085024}},
659
{{2.419072, 0.017449, -0.608724, -0.085012},
660
{2.419119, 0.017455, -0.608691, -0.085034}},
661
{{2.419019, 0.017442, -0.608761, -0.084989},
662
{2.419075, 0.017449, -0.608722, -0.085014}}},
664
ASSERT_EQ(result.sizes(), ref_output.sizes());
666
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
668
model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 6));
670
model->to(torch::kCUDA);
674
/*src_mask=*/torch::Tensor{},
675
/*src_key_padding_mask=*/mask)
677
ref_output = torch::tensor(
678
{{{2.419101, 0.017453, -0.608703, -0.085025},
679
{2.419101, 0.017453, -0.608704, -0.085025}},
680
{{2.419101, 0.017453, -0.608703, -0.085025},
681
{2.419101, 0.017453, -0.608704, -0.085025}},
682
{{2.419101, 0.017453, -0.608703, -0.085025},
683
{2.419101, 0.017453, -0.608704, -0.085025}},
684
{{2.419101, 0.017453, -0.608703, -0.085025},
685
{2.419101, 0.017453, -0.608704, -0.085025}},
686
{{2.419101, 0.017453, -0.608703, -0.085025},
687
{2.419101, 0.017453, -0.608704, -0.085025}}},
689
ASSERT_EQ(result.sizes(), ref_output.sizes());
691
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
693
// test case 3, multiple layers with norm
694
LayerNorm norm(LayerNormOptions({encoder_layer.get()->options.d_model()}));
695
model = TransformerEncoder(
696
TransformerEncoderOptions(encoder_layer, 2).norm(AnyModule(norm)));
698
model->to(torch::kCUDA);
702
/*src_mask=*/torch::Tensor{},
703
/*src_key_padding_mask=*/mask)
705
ref_output = torch::tensor(
706
{{{1.695949, -0.357635, -0.893077, -0.445238},
707
{1.695955, -0.357639, -0.893050, -0.445266}},
708
{{1.695948, -0.357634, -0.893082, -0.445233},
709
{1.695950, -0.357635, -0.893077, -0.445238}},
710
{{1.695951, -0.357636, -0.893069, -0.445246},
711
{1.695955, -0.357639, -0.893052, -0.445264}},
712
{{1.695952, -0.357636, -0.893066, -0.445249},
713
{1.695957, -0.357641, -0.893041, -0.445276}},
714
{{1.695946, -0.357632, -0.893095, -0.445220},
715
{1.695952, -0.357637, -0.893065, -0.445251}}},
717
ASSERT_EQ(result.sizes(), ref_output.sizes());
719
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
721
model = TransformerEncoder(
722
TransformerEncoderOptions(encoder_layer, 6).norm(AnyModule(norm)));
724
model->to(torch::kCUDA);
728
/*src_mask=*/torch::Tensor{},
729
/*src_key_padding_mask=*/mask)
731
ref_output = torch::tensor(
732
{{{1.695955, -0.357639, -0.893051, -0.445265},
733
{1.695955, -0.357639, -0.893051, -0.445265}},
734
{{1.695955, -0.357639, -0.893051, -0.445265},
735
{1.695955, -0.357639, -0.893051, -0.445265}},
736
{{1.695955, -0.357639, -0.893051, -0.445265},
737
{1.695955, -0.357639, -0.893051, -0.445265}},
738
{{1.695955, -0.357639, -0.893051, -0.445265},
739
{1.695955, -0.357639, -0.893051, -0.445265}},
740
{{1.695955, -0.357639, -0.893051, -0.445265},
741
{1.695955, -0.357639, -0.893051, -0.445265}}},
743
ASSERT_EQ(result.sizes(), ref_output.sizes());
745
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
748
TEST_F(TransformerTest, TransformerEncoder) {
749
transformer_encoder_test_helper(
750
/*is_cuda=*/false, /*use_callable_activation=*/false);
751
transformer_encoder_test_helper(
752
/*is_cuda=*/false, /*use_callable_activation=*/true);
755
TEST_F(TransformerTest, TransformerEncoder_CUDA) {
756
transformer_encoder_test_helper(
757
/*is_cuda=*/true, /*use_callable_activation=*/false);
758
transformer_encoder_test_helper(
759
/*is_cuda=*/true, /*use_callable_activation=*/true);
762
TEST_F(TransformerTest, PrettyPrintTransformerEncoderLayer) {
764
c10::str(TransformerEncoderLayer(4, 2)),
765
"torch::nn::TransformerEncoderLayerImpl(\n"
766
" (self_attn): torch::nn::MultiheadAttention(\n"
767
" (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
769
" (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
770
" (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
771
" (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
772
" (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
773
" (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
774
" (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
775
" (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
779
TEST_F(TransformerTest, PrettyPrintTransformerEncoder) {
780
LayerNorm norm = LayerNorm(LayerNormOptions({4}));
781
TransformerEncoderOptions options(
782
TransformerEncoderOptions(TransformerEncoderLayerOptions(4, 2), 2)
783
.norm(AnyModule(norm)));
785
c10::str(TransformerEncoder(options)),
786
"torch::nn::TransformerEncoderImpl(\n"
787
" (layers): torch::nn::ModuleList(\n"
788
" (0): torch::nn::TransformerEncoderLayerImpl(\n"
789
" (self_attn): torch::nn::MultiheadAttention(\n"
790
" (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
792
" (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
793
" (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
794
" (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
795
" (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
796
" (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
797
" (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
798
" (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
800
" (1): torch::nn::TransformerEncoderLayerImpl(\n"
801
" (self_attn): torch::nn::MultiheadAttention(\n"
802
" (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
804
" (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
805
" (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
806
" (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
807
" (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
808
" (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
809
" (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
810
" (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
813
" (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
817
TEST_F(TransformerTest, PrettyPrintTransformerDecoderLayer) {
819
c10::str(TransformerDecoderLayer(4, 2)),
820
"torch::nn::TransformerDecoderLayerImpl(\n"
821
" (self_attn): torch::nn::MultiheadAttention(\n"
822
" (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
824
" (multihead_attn): torch::nn::MultiheadAttention(\n"
825
" (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
827
" (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
828
" (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
829
" (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
830
" (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
831
" (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
832
" (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
833
" (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
834
" (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
835
" (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
839
void transformer_decoder_test_helper(
841
bool use_callable_activation) {
842
// this is a deterministic test for TransformerDecoder
843
torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
844
torch::TensorOptions tensor_options =
845
torch::TensorOptions().dtype(torch::kFloat32).device(device);
847
TransformerDecoderLayer decoder_layer =
848
get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
849
tensor_options, use_callable_activation);
851
TransformerDecoder model(TransformerDecoderOptions(decoder_layer, 1));
853
model->to(torch::kCUDA);
856
torch::Tensor decoder_input =
857
torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
858
torch::Tensor memory_input =
859
torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
860
torch::Tensor result = model(decoder_input, memory_input).detach();
861
torch::Tensor ref_output = torch::tensor(
862
{{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options);
863
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
864
ASSERT_TRUE(torch::allclose(
869
/*equal_nan=*/true));
871
// deterministic input
873
torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
874
memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
875
result = model(decoder_input, memory_input).detach();
876
ref_output = torch::tensor(
877
{{{2.422245, 0.051716, -0.606338, -0.024756}},
878
{{2.422245, 0.051716, -0.606338, -0.024756}}},
880
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
881
ASSERT_TRUE(torch::allclose(
886
/*equal_nan=*/true));
888
// deterministic input
890
torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
892
torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
893
result = model(decoder_input, memory_input).detach();
894
ref_output = torch::tensor(
895
{{{2.343536, 0.085561, -0.654954, 0.074991}},
896
{{2.343536, 0.085561, -0.654954, 0.074991}}},
898
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
899
ASSERT_TRUE(torch::allclose(
904
/*equal_nan=*/true));
906
// deterministic input
907
decoder_input = torch::tensor(
908
{{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
909
{{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
910
{{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
912
memory_input = torch::tensor(
913
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
914
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
915
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
916
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
917
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
919
result = model(decoder_input, memory_input).detach();
920
ref_output = torch::tensor(
921
{{{2.430065, 0.027862, -0.601136, -0.073096},
922
{2.431935, 0.028907, -0.599809, -0.072488}},
923
{{2.428457, 0.027053, -0.602275, -0.073462},
924
{2.431970, 0.029387, -0.599789, -0.071621}},
925
{{2.431934, 0.028196, -0.599802, -0.073809},
926
{2.432306, 0.028858, -0.599542, -0.072846}}},
928
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
929
ASSERT_TRUE(torch::allclose(
934
/*equal_nan=*/true));
937
torch::Tensor t_mask = {};
938
torch::Tensor m_mask = {};
939
torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1;
940
result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
942
ref_output = torch::tensor(
943
{{{2.430065, 0.027862, -0.601136, -0.073096},
944
{2.431935, 0.028907, -0.599809, -0.072488}},
945
{{2.428457, 0.027053, -0.602275, -0.073462},
946
{2.431970, 0.029387, -0.599789, -0.071621}},
947
{{2.431934, 0.028196, -0.599802, -0.073809},
948
{2.432306, 0.028858, -0.599542, -0.072846}}},
950
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
951
ASSERT_TRUE(torch::allclose(
956
/*equal_nan=*/true));
959
key_padding_mask[0][2] = 1;
960
key_padding_mask[1][1] = 1;
961
key_padding_mask[1][2] = 1;
962
result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
964
ref_output = torch::tensor(
965
{{{2.430025, 0.027643, -0.601164, -0.073476},
966
{2.4323, 0.029375, -0.599553, -0.071881}},
967
{{2.428523, 0.026838, -0.602226, -0.07391},
968
{2.432634, 0.029842, -0.599318, -0.071253}},
969
{{2.432278, 0.028152, -0.599555, -0.074139},
970
{2.432659, 0.029244, -0.599294, -0.072382}}},
972
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
973
ASSERT_TRUE(torch::allclose(
978
/*equal_nan=*/true));
980
// memory_key_padding_mask
981
torch::Tensor t_key_padding_mask = {};
982
key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1;
991
ref_output = torch::tensor(
992
{{{2.430065, 0.027862, -0.601136, -0.073096},
993
{2.431935, 0.028907, -0.599809, -0.072488}},
994
{{2.428457, 0.027053, -0.602275, -0.073462},
995
{2.431970, 0.029387, -0.599789, -0.071621}},
996
{{2.431934, 0.028196, -0.599802, -0.073809},
997
{2.432306, 0.028858, -0.599542, -0.072846}}},
999
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1000
ASSERT_TRUE(torch::allclose(
1005
/*equal_nan=*/true));
1007
// memory_key_padding_mask
1008
key_padding_mask[0][4] = 1;
1009
key_padding_mask[1][3] = 1;
1010
key_padding_mask[1][4] = 1;
1019
ref_output = torch::tensor(
1020
{{{2.429757, 0.027358, -0.601351, -0.073816},
1021
{2.432692, 0.028583, -0.599263, -0.073634}},
1022
{{2.428247, 0.02662, -0.602419, -0.074123},
1023
{2.432657, 0.029055, -0.599293, -0.072732}},
1024
{{2.431515, 0.027687, -0.600096, -0.074459},
1025
{2.433075, 0.028543, -0.598987, -0.073985}}},
1027
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1028
ASSERT_TRUE(torch::allclose(
1033
/*equal_nan=*/true));
1035
// multiple layers no norm
1036
model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 2));
1038
model->to(torch::kCUDA);
1041
decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1042
memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1043
result = model(decoder_input, memory_input).detach();
1044
ref_output = torch::tensor(
1045
{{{2.31316, 0.0950293, -0.671995, 0.102802}}}, tensor_options);
1046
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1047
ASSERT_TRUE(torch::allclose(
1052
/*equal_nan=*/true));
1054
// multiple layers no norm
1055
model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6));
1057
model->to(torch::kCUDA);
1059
// deterministic input
1060
decoder_input = torch::tensor(
1061
{{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1062
{{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1063
{{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1065
memory_input = torch::tensor(
1066
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1067
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1068
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1069
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1070
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1072
result = model(decoder_input, memory_input).detach();
1073
ref_output = torch::tensor(
1074
{{{2.42794, 0.026164, -0.60263, -0.0747591},
1075
{2.43113, 0.0279516, -0.600376, -0.0736896}},
1076
{{2.42794, 0.026164, -0.60263, -0.0747591},
1077
{2.43113, 0.0279516, -0.600376, -0.0736896}},
1078
{{2.42794, 0.026164, -0.60263, -0.0747591},
1079
{2.43113, 0.0279516, -0.600376, -0.0736896}}},
1081
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1082
ASSERT_TRUE(torch::allclose(
1087
/*equal_nan=*/true));
1089
// multiple layers with norm
1090
LayerNorm norm(LayerNormOptions({decoder_layer.get()->options.d_model()}));
1091
model = TransformerDecoder(
1092
TransformerDecoderOptions(decoder_layer, 2).norm(AnyModule(norm)));
1094
model->to(torch::kCUDA);
1097
decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1098
memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1099
result = model(decoder_input, memory_input).detach();
1100
ref_output = torch::tensor(
1101
{{{1.66166, -0.326986, -1.01466, -0.320017}}}, tensor_options);
1102
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1103
ASSERT_TRUE(torch::allclose(
1108
/*equal_nan=*/true));
1110
// multiple layers with norm
1111
model = TransformerDecoder(
1112
TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm)));
1114
model->to(torch::kCUDA);
1116
// deterministic input
1117
decoder_input = torch::tensor(
1118
{{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1119
{{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1120
{{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1122
memory_input = torch::tensor(
1123
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1124
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1125
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1126
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1127
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1129
result = model(decoder_input, memory_input).detach();
1130
ref_output = torch::tensor(
1131
{{{1.69559, -0.357291, -0.894741, -0.443553},
1132
{1.69571, -0.357363, -0.894154, -0.444196}},
1133
{{1.69559, -0.357291, -0.894741, -0.443553},
1134
{1.69571, -0.357363, -0.894154, -0.444196}},
1135
{{1.69559, -0.357291, -0.894741, -0.443553},
1136
{1.69571, -0.357363, -0.894154, -0.444196}}},
1138
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1139
ASSERT_TRUE(torch::allclose(
1144
/*equal_nan=*/true));
1146
// gelu activation test cases
1147
decoder_layer.get()->options.activation(torch::kGELU);
1148
model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 1));
1150
model->to(torch::kCUDA);
1153
// deterministic input
1154
decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1155
memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1156
result = model(decoder_input, memory_input).detach();
1157
ref_output = torch::tensor(
1158
{{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options);
1159
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1160
ASSERT_TRUE(torch::allclose(
1165
/*equal_nan=*/true));
1167
// deterministic input
1169
torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
1170
memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
1171
result = model(decoder_input, memory_input).detach();
1172
ref_output = torch::tensor(
1173
{{{2.415448, 0.054389, -0.610932, -0.0156613}},
1174
{{2.415448, 0.054389, -0.610932, -0.0156613}}},
1176
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1177
ASSERT_TRUE(torch::allclose(
1182
/*equal_nan=*/true));
1184
// deterministic input
1186
torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
1188
torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
1189
result = model(decoder_input, memory_input).detach();
1190
ref_output = torch::tensor(
1191
{{{2.338531, 0.087709, -0.65776, 0.080646}},
1192
{{2.338531, 0.087709, -0.65776, 0.080646}}},
1194
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1195
ASSERT_TRUE(torch::allclose(
1200
/*equal_nan=*/true));
1202
// deterministic input
1203
decoder_input = torch::tensor(
1204
{{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1205
{{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1206
{{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1208
memory_input = torch::tensor(
1209
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1210
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1211
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1212
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1213
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1215
result = model(decoder_input, memory_input).detach();
1216
ref_output = torch::tensor(
1217
{{{2.42049104, 0.03443088, -0.60793706, -0.05436271},
1218
{2.42210631, 0.03546578, -0.60679895, -0.05357488}},
1219
{{2.41907674, 0.0336104, -0.60892977, -0.05490462},
1220
{2.42216881, 0.03586554, -0.6067524, -0.05289126}},
1221
{{2.42205716, 0.03488046, -0.60683681, -0.05460596},
1222
{2.42240309, 0.0354595, -0.60659063, -0.05378816}}},
1224
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1225
ASSERT_TRUE(torch::allclose(
1230
/*equal_nan=*/true));
1232
// Multiple layers no norm
1233
model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6));
1235
model->to(torch::kCUDA);
1237
decoder_input = torch::tensor(
1238
{{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1239
{{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1240
{{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1242
memory_input = torch::tensor(
1243
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1244
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1245
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1246
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1247
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1249
result = model(decoder_input, memory_input).detach();
1250
ref_output = torch::tensor(
1251
{{{2.41859, 0.0328114, -0.609269, -0.0560386},
1252
{2.42138, 0.034598, -0.607316, -0.0546574}},
1253
{{2.41859, 0.0328114, -0.609269, -0.0560386},
1254
{2.42138, 0.034598, -0.607316, -0.0546574}},
1255
{{2.41859, 0.0328114, -0.609269, -0.0560386},
1256
{2.42138, 0.034598, -0.607316, -0.0546574}}},
1258
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1259
ASSERT_TRUE(torch::allclose(
1264
/*equal_nan=*/true));
1266
// Multiple layers with norm
1267
norm = LayerNorm(LayerNormOptions({decoder_layer.get()->options.d_model()}));
1268
model = TransformerDecoder(
1269
TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm)));
1271
model->to(torch::kCUDA);
1274
decoder_input = torch::tensor(
1275
{{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1276
{{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1277
{{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1279
memory_input = torch::tensor(
1280
{{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1281
{{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1282
{{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1283
{{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1284
{{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1286
result = model(decoder_input, memory_input).detach();
1287
ref_output = torch::tensor(
1288
{{{1.69298, -0.355163, -0.906375, -0.431439},
1289
{1.69305, -0.355195, -0.906062, -0.431791}},
1290
{{1.69298, -0.355163, -0.906375, -0.431439},
1291
{1.69305, -0.355195, -0.906062, -0.431791}},
1292
{{1.69298, -0.355163, -0.906375, -0.431439},
1293
{1.69305, -0.355195, -0.906062, -0.431791}}},
1295
ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1296
ASSERT_TRUE(torch::allclose(
1301
/*equal_nan=*/true));
1304
// CPU deterministic tests for TransformerDecoder, covering both the enum
// (torch::kReLU) and the callable-activation configuration of the layer.
TEST_F(TransformerTest, TransformerDecoder) {
  transformer_decoder_test_helper(
      /*is_cuda=*/false, /*use_callable_activation=*/false);
  transformer_decoder_test_helper(
      /*is_cuda=*/false, /*use_callable_activation=*/true);
}
// Same deterministic TransformerDecoder checks as above, executed on CUDA.
TEST_F(TransformerTest, TransformerDecoder_CUDA) {
  transformer_decoder_test_helper(
      /*is_cuda=*/true, /*use_callable_activation=*/false);
  transformer_decoder_test_helper(
      /*is_cuda=*/true, /*use_callable_activation=*/true);
}
// Verifies the pretty-printed module hierarchy of a 2-layer
// TransformerDecoder (d_model=4, nhead=2) with a final LayerNorm.
// NOTE(review): the expected string below was reconstructed from a garbled
// extraction — the nested-module indentation (2 spaces per level) and the
// ")\n" closer lines follow the standard torch::nn pretty-print format;
// confirm against a fresh run of c10::str on this module.
TEST_F(TransformerTest, PrettyPrintTransformerDecoder) {
  LayerNorm norm = LayerNorm(LayerNormOptions({4}));
  TransformerDecoderOptions options(
      TransformerDecoderOptions(TransformerDecoderLayerOptions(4, 2), 2)
          .norm(AnyModule(norm)));
  ASSERT_EQ(
      c10::str(TransformerDecoder(options)),
      "torch::nn::TransformerDecoderImpl(\n"
      "  (layers): torch::nn::ModuleList(\n"
      "    (0): torch::nn::TransformerDecoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (multihead_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "    (1): torch::nn::TransformerDecoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (multihead_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "  )\n"
      "  (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      ")");
}
void transformer_test_helper(bool is_cuda, bool use_callable_activation) {
1367
// this is a deterministic test for Transformere
1368
torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
1369
torch::TensorOptions tensor_options =
1370
torch::TensorOptions().dtype(torch::kFloat32).device(device);
1372
// transformer created encoder/decoder
1373
auto options = TransformerOptions()
1376
.num_encoder_layers(2)
1377
.num_decoder_layers(1)
1378
.dim_feedforward(16)
1380
.activation(torch::kReLU);
1381
if (use_callable_activation) {
1383
[&](const torch::Tensor& t) { return torch::nn::functional::relu(t); });
1385
Transformer model(options);
1387
set_parameter_to_constants<Transformer>(model, tensor_options);
1388
if (tensor_options.device() == torch::kCUDA) {
1389
model->to(torch::kCUDA);
1392
// transformer with customized encoder/decoder
1393
LayerNorm enorm(LayerNormOptions({4}));
1394
TransformerEncoder encoder(
1395
TransformerEncoderOptions(
1396
TransformerEncoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0),
1398
.norm(AnyModule(enorm)));
1400
LayerNorm dnorm(LayerNormOptions({4}));
1401
TransformerDecoder decoder(
1402
TransformerDecoderOptions(
1403
TransformerDecoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0),
1405
.norm(AnyModule(dnorm)));
1407
Transformer model_cus(TransformerOptions()
1410
.custom_encoder(AnyModule(encoder))
1411
.custom_decoder(AnyModule(decoder)));
1413
set_parameter_to_constants<Transformer>(model_cus, tensor_options);
1414
if (tensor_options.device() == torch::kCUDA) {
1415
model_cus->to(torch::kCUDA);
1419
torch::Tensor src = torch::tensor(
1420
{{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
1421
{{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}},
1422
{{17.0, 18.0, 19.0, 20.0}, {21.0, 22.0, 23.0, 24.0}}},
1425
torch::Tensor tgt = torch::tensor(
1426
{{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
1427
{{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}},
1430
torch::Tensor ref_output = torch::tensor(
1431
{{{2.695875, 0.347114, -0.044355, -0.549541},
1432
{2.696091, 0.347015, -0.044770, -0.548522}},
1433
{{2.695875, 0.347114, -0.044355, -0.549541},
1434
{2.696091, 0.347015, -0.044770, -0.548522}}},
1436
torch::Tensor result = model(src, tgt);
1437
torch::Tensor result_cus = model_cus(src, tgt);
1438
ASSERT_EQ(result.sizes(), ref_output.sizes());
1439
ASSERT_TRUE(result.equal(result_cus));
1441
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
1443
torch::Tensor src_mask =
1444
Transformer::Impl::generate_square_subsequent_mask(src.size(0))
1445
.to(tensor_options);
1446
ref_output = torch::tensor(
1447
{{{2.695875, 0.347114, -0.044355, -0.549541},
1448
{2.696091, 0.347015, -0.044770, -0.548522}},
1449
{{2.695875, 0.347114, -0.044355, -0.549541},
1450
{2.696091, 0.347015, -0.044770, -0.548522}}},
1452
result = model(src, tgt, src_mask);
1453
result_cus = model_cus(src, tgt, src_mask);
1454
ASSERT_EQ(result.sizes(), ref_output.sizes());
1455
ASSERT_TRUE(result.equal(result_cus));
1457
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
1459
torch::Tensor tgt_key_padding_mask =
1460
torch::zeros({tgt.size(1), tgt.size(0)}, tensor_options) == 1;
1461
tgt_key_padding_mask[0][0] = 1;
1462
tgt_key_padding_mask[1][1] = 1;
1463
ref_output = torch::tensor(
1464
{{{2.696114, 0.347004, -0.044813, -0.548417},
1465
{2.696091, 0.347015, -0.044770, -0.548522}},
1466
{{2.696114, 0.347004, -0.044813, -0.548417},
1467
{2.696091, 0.347015, -0.044770, -0.548522}}},
1476
tgt_key_padding_mask);
1477
result_cus = model_cus(
1484
tgt_key_padding_mask);
1485
ASSERT_EQ(result.sizes(), ref_output.sizes());
1486
ASSERT_TRUE(result.equal(result_cus));
1488
torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
1491
// CPU deterministic tests for the full Transformer module, with both the
// enum activation and the callable activation.
TEST_F(TransformerTest, Transformer) {
  transformer_test_helper(/*is_cuda=*/false, /*use_callable_activation=*/false);
  transformer_test_helper(/*is_cuda=*/false, /*use_callable_activation=*/true);
}
// Same deterministic Transformer checks as above, executed on CUDA.
TEST_F(TransformerTest, Transformer_CUDA) {
  transformer_test_helper(/*is_cuda=*/true, /*use_callable_activation=*/false);
  transformer_test_helper(/*is_cuda=*/true, /*use_callable_activation=*/true);
}
TEST_F(TransformerTest, TransformerArgsCorrectness) {
1502
Transformer model(TransformerOptions()
1505
.num_encoder_layers(2)
1506
.num_decoder_layers(1)
1507
.dim_feedforward(16)
1509
.activation(torch::kReLU));
1511
torch::Tensor src = torch::randn({2, 3, 4});
1512
torch::Tensor tgt = torch::randn({3, 2, 4});
1515
model(src, tgt), "src and tgt should have equal batch size");
1517
tgt = torch::randn({2, 3, 3});
1519
model(src, tgt), "src and tgt should have same feature size as d_model");
1521
src = torch::randn({2, 3});
1522
ASSERT_THROWS_WITH(model(src, tgt), "src and tgt should have 3 dimensions");