pytorch

Форк
0
/
rnn_cell_test.py 
1771 строка · 58.3 Кб
1

2

3

4

5

6
from caffe2.python import (
7
    core, gradient_checker, rnn_cell, workspace, scope, utils
8
)
9
from caffe2.python.attention import AttentionType
10
from caffe2.python.model_helper import ModelHelper, ExtractPredictorNet
11
from caffe2.python.rnn.rnn_cell_test_util import sigmoid, tanh, _prepare_rnn
12
from caffe2.proto import caffe2_pb2
13
import caffe2.python.hypothesis_test_util as hu
14

15
from functools import partial
16
from hypothesis import assume, given
17
from hypothesis import settings as ht_settings
18
import hypothesis.strategies as st
19
import numpy as np
20
import unittest
21

22

23
def lstm_unit(*args, **kwargs):
    """NumPy reference for a single LSTM time step.

    Positional arguments are ``(hidden_t_prev, cell_t_prev, gates,
    seq_lengths, timestep)``; when ``sequence_lengths=False`` is passed the
    ``seq_lengths`` entry is omitted: ``(hidden_t_prev, cell_t_prev, gates,
    timestep)``.

    ``gates`` must be reshapeable to ``(N, 4, D)`` with ``gates.shape[1] == N``
    and ``gates.shape[2] == 4 * D``; the four ``D``-wide slices are the input,
    forget, output and candidate pre-activations, in that order.  Batch rows
    whose sequence length is not greater than ``timestep`` are "invalid": they
    carry the previous hidden/cell state forward, or zero it when
    ``drop_states`` is truthy.

    Keyword args:
        forget_bias: added to the forget-gate pre-activation (default 0.0).
        drop_states: if truthy, invalid rows get zeros instead of the
            previous state (default False).
        sequence_lengths: whether ``seq_lengths`` is among the positional
            arguments (default True).

    Returns:
        ``(hidden_t, cell_t)``, each reshaped to ``(1, N, D)``.
    """
    forget_bias = kwargs.get('forget_bias', 0.0)
    drop_states = kwargs.get('drop_states', False)
    has_lengths = kwargs.get('sequence_lengths', True)

    if not has_lengths:
        hidden_t_prev, cell_t_prev, gates, timestep = args
    else:
        hidden_t_prev, cell_t_prev, gates, seq_lengths, timestep = args

    D = cell_t_prev.shape[2]
    G = gates.shape[2]
    N = gates.shape[1]
    step = (timestep * np.ones(shape=(N, D))).astype(np.int32)
    assert step.shape == (N, D)
    assert G == 4 * D

    # Explicit (N, D) reshapes avoid NumPy broadcasting surprises.
    by_gate = gates.reshape(N, 4, D)
    prev_cell = cell_t_prev.reshape(N, D)
    input_gate = sigmoid(by_gate[:, 0, :].reshape(N, D))
    forget_gate = sigmoid(by_gate[:, 1, :].reshape(N, D) + forget_bias)
    output_gate = sigmoid(by_gate[:, 2, :].reshape(N, D))
    candidate = tanh(by_gate[:, 3, :].reshape(N, D))

    if has_lengths:
        lengths = (np.ones(shape=(N, D)) *
                   seq_lengths.reshape(N, 1)).astype(np.int32)
        assert lengths.shape == (N, D)
        # 1 where this step is inside the row's sequence, 0 otherwise.
        valid = (step < lengths).astype(np.int32)
    else:
        valid = np.ones(shape=(N, D))
    assert valid.shape == (N, D)

    # 1 keeps the previous state for invalid rows, 0 zeroes it.
    carry = 1 - drop_states
    cell_t = (forget_gate * prev_cell + input_gate * candidate) * valid + \
        (1 - valid) * prev_cell * carry
    assert cell_t.shape == (N, D)
    hidden_t = (output_gate * tanh(cell_t)) * valid + \
        hidden_t_prev * (1 - valid) * carry
    return hidden_t.reshape(1, N, D), cell_t.reshape(1, N, D)
65

66

67
def layer_norm_with_scale_and_bias_ref(X, scale, bias, axis=-1, epsilon=1e-4):
    """Reference layer normalization followed by an affine transform.

    All dimensions before ``axis`` are flattened into rows.  Each row is
    shifted to zero mean and divided by ``sqrt(variance + epsilon)``; the
    result is reshaped back to ``X.shape`` and then scaled and shifted.

    Args:
        X: input array.
        scale: multiplier broadcast against the normalized array.
        bias: offset broadcast against the normalized array.
        axis: first axis of the normalized group; ``axis=0`` normalizes the
            whole tensor as a single row.
        epsilon: added to the variance before the square root.

    Returns:
        ``scale * normalized(X) + bias``.
    """
    # int(): np.prod of an empty shape (axis=0) is the float 1.0, which
    # NumPy rejects as a reshape dimension.
    left = int(np.prod(X.shape[:axis]))
    reshaped = np.reshape(X, [left, -1])
    mean = np.mean(reshaped, axis=1, keepdims=True)
    centered = reshaped - mean
    # Variance from centred values rather than E[x^2] - E[x]^2: the latter
    # cancels catastrophically (and can go negative) for large offsets.
    variance = np.mean(np.square(centered), axis=1, keepdims=True)
    stdev = np.sqrt(variance + epsilon)
    norm = np.reshape(centered / stdev, X.shape)
    return scale * norm + bias
80

81

82
def layer_norm_lstm_reference(
    input,
    hidden_input,
    cell_input,
    gates_w,
    gates_b,
    gates_t_norm_scale,
    gates_t_norm_bias,
    seq_lengths,
    forget_bias,
    drop_states=False
):
    """NumPy reference for an LSTM with layer-normalized gate pre-activations.

    For each step ``t`` the gates are ``hidden[t] @ gates_w.T + gates_b +
    input[t]``.  They are passed through ``layer_norm_with_scale_and_bias_ref``
    with ``gates_t_norm_scale`` / ``gates_t_norm_bias`` and then through
    ``lstm_unit`` to produce the next hidden and cell states.

    Args:
        input: array indexed as ``input[t]`` with ``input.shape[:3] == (T, N, G)``.
        hidden_input, cell_input: initial states; their last dimension is ``D``.
        gates_w, gates_b: recurrent gate projection weight and bias.
        gates_t_norm_scale, gates_t_norm_bias: layer-norm affine parameters.
        seq_lengths: per-row sequence lengths forwarded to ``lstm_unit``.
        forget_bias: forwarded to ``lstm_unit``.
        drop_states: forwarded to ``lstm_unit``.

    Returns:
        ``(hidden_all, hidden_last, cell_all, cell_last)``; the ``*_all``
        arrays hold the ``T`` computed steps and the ``*_last`` arrays hold
        the final step reshaped to ``(1, N, D)``.
    """
    T = input.shape[0]
    N = input.shape[1]
    G = input.shape[2]
    D = hidden_input.shape[hidden_input.ndim - 1]
    # Row 0 holds the initial states; row t + 1 holds the output of step t.
    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    cell[0, :, :] = cell_input
    hidden[0, :, :] = hidden_input
    for t in range(T):
        # (Removed a leftover debug print that fired on every step.)
        input_t = input[t].reshape(1, N, G)
        hidden_t_prev = hidden[t].reshape(1, N, D)
        cell_t_prev = cell[t].reshape(1, N, D)
        gates = np.dot(hidden_t_prev, gates_w.T) + gates_b
        gates = gates + input_t

        gates = layer_norm_with_scale_and_bias_ref(
            gates, gates_t_norm_scale, gates_t_norm_bias
        )

        hidden_t, cell_t = lstm_unit(
            hidden_t_prev,
            cell_t_prev,
            gates,
            seq_lengths,
            t,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )
        hidden[t + 1] = hidden_t
        cell[t + 1] = cell_t
    return (
        hidden[1:],
        hidden[-1].reshape(1, N, D),
        cell[1:],
        cell[-1].reshape(1, N, D)
    )
135

136

137
def lstm_reference(input, hidden_input, cell_input,
                   gates_w, gates_b, seq_lengths, forget_bias,
                   drop_states=False):
    """NumPy reference for a single-layer LSTM unrolled over time.

    Step ``t`` computes gates ``hidden[t] @ gates_w.T + gates_b + input[t]``
    and feeds them to ``lstm_unit`` along with ``seq_lengths``,
    ``forget_bias`` and ``drop_states``.

    Returns:
        ``(hidden_all, hidden_last, cell_all, cell_last)``; the ``*_all``
        arrays hold the ``T`` computed steps and the ``*_last`` arrays hold
        the final step reshaped to ``(1, N, D)``.
    """
    T, N, G = input.shape[0], input.shape[1], input.shape[2]
    D = hidden_input.shape[-1]

    # Slot 0 holds the initial state; slot k + 1 holds the result of step k.
    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    cell[0] = cell_input
    hidden[0] = hidden_input

    for step in range(T):
        x_t = input[step].reshape(1, N, G)
        h_prev = hidden[step].reshape(1, N, D)
        c_prev = cell[step].reshape(1, N, D)
        gates = np.dot(h_prev, gates_w.T) + gates_b + x_t
        hidden[step + 1], cell[step + 1] = lstm_unit(
            h_prev,
            c_prev,
            gates,
            seq_lengths,
            step,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )

    final_hidden = hidden[-1].reshape(1, N, D)
    final_cell = cell[-1].reshape(1, N, D)
    return hidden[1:], final_hidden, cell[1:], final_cell
175

176

177
def multi_lstm_reference(input, hidden_input_list, cell_input_list,
                         i2h_w_list, i2h_b_list, gates_w_list, gates_b_list,
                         seq_lengths, forget_bias, drop_states=False):
    """NumPy reference for a stack of LSTM layers.

    Each layer first projects its input with its own ``i2h`` weight/bias and
    runs ``lstm_reference`` on the result.  The per-step hidden states of one
    layer become the input of the next.

    Returns:
        The ``(hidden_all, hidden_last, cell_all, cell_last)`` tuple of the
        last layer.
    """
    num_layers = len(hidden_input_list)
    for per_layer in (cell_input_list, i2h_w_list, i2h_b_list,
                      gates_w_list, gates_b_list):
        assert len(per_layer) == num_layers

    layer_in = input
    layer_params = zip(hidden_input_list, cell_input_list, i2h_w_list,
                       i2h_b_list, gates_w_list, gates_b_list)
    for h0, c0, i2h_w, i2h_b, g_w, g_b in layer_params:
        projected = np.dot(layer_in, i2h_w.T) + i2h_b
        h_all, h_last, c_all, c_last = lstm_reference(
            projected,
            h0,
            c0,
            g_w,
            g_b,
            seq_lengths,
            forget_bias,
            drop_states=drop_states,
        )
        layer_in = h_all
    return h_all, h_last, c_all, c_last
201

202

203
def compute_regular_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    """Additive attention logits ``sum(v * tanh(enc_w + h @ W.T + b), axis=2)``.

    Only ``hidden_t``, the decoder-state projection, ``attention_v`` and
    ``weighted_encoder_outputs`` are used.  The remaining parameters exist so
    that every ``compute_*_attention_logits`` function shares one signature.
    """
    projected_hidden = np.dot(hidden_t, weighted_decoder_hidden_state_t_w.T)
    projected_hidden = projected_hidden + weighted_decoder_hidden_state_t_b
    v = attention_v.reshape(-1)
    energy = np.tanh(weighted_encoder_outputs + projected_hidden)
    return (v * energy).sum(axis=2)
225

226

227
def compute_recurrent_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    """Additive attention logits that also condition on the previous context.

    Computes ``sum(v * tanh(enc_w + h @ W.T + b + ctx_prev @ Wp.T + bp),
    axis=2)``.  ``encoder_outputs_for_dot_product``, ``coverage_prev`` and
    ``coverage_weights`` are unused; they keep the shared signature.
    """
    projected_hidden = np.dot(
        hidden_t, weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    projected_context = np.dot(
        attention_weighted_encoder_context_t_prev,
        weighted_prev_attention_context_w.T,
    ) + weighted_prev_attention_context_b
    v = attention_v.reshape(-1)
    energy = np.tanh(
        weighted_encoder_outputs + projected_hidden + projected_context
    )
    return (v * energy).sum(axis=2)
256

257

258
def compute_dot_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    """Dot-product attention logits between encoder outputs and the decoder state.

    ``hidden_t`` is transposed with axes ``[1, 2, 0]`` into a column per batch
    entry.  When both the decoder-state weight and bias are given, the column
    is first projected with them.  The logits are the batched matrix product
    with ``encoder_outputs_for_dot_product``, summed over the last axis and
    transposed.  The other parameters keep the shared signature.
    """
    w = weighted_decoder_hidden_state_t_w
    b = weighted_decoder_hidden_state_t_b
    query = np.transpose(hidden_t, axes=[1, 2, 0])
    if not (w is None or b is None):
        query = np.matmul(w, query) + np.expand_dims(b, axis=1)
    scores = np.matmul(encoder_outputs_for_dot_product, query).sum(axis=2)
    return scores.T
288

289

290
def compute_coverage_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    """Additive attention logits with a coverage term.

    Like the regular additive logits, but before the ``tanh`` the encoder
    side also gets ``coverage_prev.T * coverage_weights``, which reuses the
    accumulated attention weights.
    """
    projected_hidden = np.dot(
        hidden_t, weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    coverage_term = coverage_prev.T * coverage_weights
    v = attention_v.reshape(-1)
    energy = np.tanh(weighted_encoder_outputs + coverage_term + projected_hidden)
    return (v * energy).sum(axis=2)
314

315

316
def lstm_with_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    encoder_outputs_transposed,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    weighted_encoder_outputs,
    coverage_weights,
    attention_v,
    attention_zeros,
    compute_attention_logits,
):
    """NumPy reference for an LSTM decoder with a pluggable attention step.

    Each decoder step does the following:
      1. Concatenate the previous hidden state and the previous attention
         context, project them with ``gates_w`` / ``gates_b``, add
         ``input[t]``, and run ``lstm_unit`` (default forget bias) to get the
         new hidden and cell states.
      2. Score the encoder positions with ``compute_attention_logits``
         (called positionally with the shared 11-argument signature).
      3. Take a softmax over the encoder axis. The resulting weights are
         added to a running coverage array and used to average the encoder
         outputs into the next attention context.

    The encoder outputs are ``encoder_outputs_transposed`` transposed with
    axes ``[2, 0, 1]``. The weights are broadcast against them as
    ``(-1, batch_size, 1)``. ``attention_zeros`` is accepted but not used.

    Returns:
        ``(hidden_all, hidden_last, cell_all, cell_last, context_all,
        context_last)``.  The ``*_all`` arrays hold every decoder step; the
        ``*_last`` arrays hold the final step reshaped to ``(1, batch, dim)``.
    """
    encoder_outputs = np.transpose(encoder_outputs_transposed, axes=[2, 0, 1])
    encoder_outputs_for_dot_product = np.transpose(
        encoder_outputs_transposed, [0, 2, 1],
    )

    steps, batch_size, input_dim = input.shape[0], input.shape[1], input.shape[2]
    state_dim = initial_hidden_state.shape[2]
    encoder_length, _, context_dim = encoder_outputs.shape

    # Slot 0 holds the initial values; slot k + 1 holds the result of step k.
    hidden = np.zeros(shape=(steps + 1, batch_size, state_dim))
    cell = np.zeros(shape=(steps + 1, batch_size, state_dim))
    context = np.zeros(shape=(steps + 1, batch_size, context_dim))
    cell[0] = initial_cell_state
    hidden[0] = initial_hidden_state
    context[0] = initial_attention_weighted_encoder_context
    coverage = np.zeros(shape=(steps + 1, batch_size, encoder_length))

    for step in range(steps):
        x_t = input[step].reshape(1, batch_size, input_dim)
        h_prev = hidden[step].reshape(1, batch_size, state_dim)
        c_prev = cell[step].reshape(1, batch_size, state_dim)
        context_prev = context[step].reshape(1, batch_size, context_dim)

        recurrent_in = np.concatenate((h_prev, context_prev), axis=2)
        gates = np.dot(recurrent_in, gates_w.T) + gates_b + x_t
        h_t, c_t = lstm_unit(h_prev, c_prev, gates,
                             decoder_input_lengths, step)
        hidden[step + 1] = h_t
        cell[step + 1] = c_t

        logits = compute_attention_logits(
            h_t,
            weighted_decoder_hidden_state_t_w,
            weighted_decoder_hidden_state_t_b,
            context_prev,
            weighted_prev_attention_context_w,
            weighted_prev_attention_context_b,
            attention_v,
            weighted_encoder_outputs,
            encoder_outputs_for_dot_product,
            coverage[step].reshape(1, batch_size, encoder_length),
            coverage_weights,
        )

        # Softmax over the encoder axis (axis 0 of the logits).
        exp_logits = np.exp(logits)
        weights = exp_logits / np.sum(exp_logits, axis=0).reshape([1, -1])
        coverage[step + 1] = coverage[step] + weights.T
        context[step + 1] = np.sum(
            encoder_outputs * weights.reshape([-1, batch_size, 1]),
            axis=0,
        )

    return (
        hidden[1:],
        hidden[-1].reshape(1, batch_size, state_dim),
        cell[1:],
        cell[-1].reshape(1, batch_size, state_dim),
        context[1:],
        context[-1].reshape(1, batch_size, context_dim),
    )
419

420

421
def lstm_with_regular_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    weighted_encoder_outputs,
    attention_v,
    attention_zeros,
    encoder_outputs_transposed,
):
    """Reference attention decoder using ``compute_regular_attention_logits``.

    The previous-context projection and the coverage weights are not part of
    this attention variant, so they are passed through as ``None``.
    """
    return lstm_with_attention_reference(
        compute_attention_logits=compute_regular_attention_logits,
        input=input,
        initial_hidden_state=initial_hidden_state,
        initial_cell_state=initial_cell_state,
        initial_attention_weighted_encoder_context=initial_attention_weighted_encoder_context,
        gates_w=gates_w,
        gates_b=gates_b,
        decoder_input_lengths=decoder_input_lengths,
        encoder_outputs_transposed=encoder_outputs_transposed,
        weighted_decoder_hidden_state_t_w=weighted_decoder_hidden_state_t_w,
        weighted_decoder_hidden_state_t_b=weighted_decoder_hidden_state_t_b,
        weighted_encoder_outputs=weighted_encoder_outputs,
        attention_v=attention_v,
        attention_zeros=attention_zeros,
        weighted_prev_attention_context_w=None,
        weighted_prev_attention_context_b=None,
        coverage_weights=None,
    )
457

458

459
def lstm_with_recurrent_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    weighted_encoder_outputs,
    attention_v,
    attention_zeros,
    encoder_outputs_transposed,
):
    """Reference attention decoder using ``compute_recurrent_attention_logits``.

    This variant also projects the previous attention context, so its weight
    and bias are forwarded. Coverage weights are not used and are passed as
    ``None``.
    """
    return lstm_with_attention_reference(
        compute_attention_logits=compute_recurrent_attention_logits,
        input=input,
        initial_hidden_state=initial_hidden_state,
        initial_cell_state=initial_cell_state,
        initial_attention_weighted_encoder_context=initial_attention_weighted_encoder_context,
        gates_w=gates_w,
        gates_b=gates_b,
        decoder_input_lengths=decoder_input_lengths,
        encoder_outputs_transposed=encoder_outputs_transposed,
        weighted_prev_attention_context_w=weighted_prev_attention_context_w,
        weighted_prev_attention_context_b=weighted_prev_attention_context_b,
        weighted_decoder_hidden_state_t_w=weighted_decoder_hidden_state_t_w,
        weighted_decoder_hidden_state_t_b=weighted_decoder_hidden_state_t_b,
        weighted_encoder_outputs=weighted_encoder_outputs,
        attention_v=attention_v,
        attention_zeros=attention_zeros,
        coverage_weights=None,
    )
497

498

499
def lstm_with_dot_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    encoder_outputs_transposed,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
):
    """Reference attention decoder using ``compute_dot_attention_logits``.

    Dot-product attention needs only the encoder outputs and an optional
    decoder-state projection. Every other attention input is passed as
    ``None``.
    """
    unused = None
    return lstm_with_attention_reference(
        compute_attention_logits=compute_dot_attention_logits,
        input=input,
        initial_hidden_state=initial_hidden_state,
        initial_cell_state=initial_cell_state,
        initial_attention_weighted_encoder_context=initial_attention_weighted_encoder_context,
        gates_w=gates_w,
        gates_b=gates_b,
        decoder_input_lengths=decoder_input_lengths,
        encoder_outputs_transposed=encoder_outputs_transposed,
        weighted_decoder_hidden_state_t_w=weighted_decoder_hidden_state_t_w,
        weighted_decoder_hidden_state_t_b=weighted_decoder_hidden_state_t_b,
        weighted_prev_attention_context_w=unused,
        weighted_prev_attention_context_b=unused,
        weighted_encoder_outputs=unused,
        coverage_weights=unused,
        attention_v=unused,
        attention_zeros=unused,
    )
532

533

534
def lstm_with_dot_attention_reference_same_dim(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    encoder_outputs_transposed,
):
    """Dot-attention reference without a decoder-state projection.

    Both the projection weight and bias are passed as ``None``, so the
    decoder hidden state is dotted with the encoder outputs directly.
    """
    return lstm_with_dot_attention_reference(
        weighted_decoder_hidden_state_t_w=None,
        weighted_decoder_hidden_state_t_b=None,
        input=input,
        initial_hidden_state=initial_hidden_state,
        initial_cell_state=initial_cell_state,
        initial_attention_weighted_encoder_context=initial_attention_weighted_encoder_context,
        gates_w=gates_w,
        gates_b=gates_b,
        decoder_input_lengths=decoder_input_lengths,
        encoder_outputs_transposed=encoder_outputs_transposed,
    )
558

559

560
def lstm_with_dot_attention_reference_different_dim(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    encoder_outputs_transposed,
):
    """Dot-attention reference that projects the decoder state first.

    The given weight and bias map the decoder hidden state into the space of
    the encoder outputs before the dot product.
    """
    return lstm_with_dot_attention_reference(
        weighted_decoder_hidden_state_t_w=weighted_decoder_hidden_state_t_w,
        weighted_decoder_hidden_state_t_b=weighted_decoder_hidden_state_t_b,
        input=input,
        initial_hidden_state=initial_hidden_state,
        initial_cell_state=initial_cell_state,
        initial_attention_weighted_encoder_context=initial_attention_weighted_encoder_context,
        gates_w=gates_w,
        gates_b=gates_b,
        decoder_input_lengths=decoder_input_lengths,
        encoder_outputs_transposed=encoder_outputs_transposed,
    )
586

587

588
def lstm_with_coverage_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    initial_coverage,
    gates_w,
    gates_b,
    decoder_input_lengths,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    weighted_encoder_outputs,
    coverage_weights,
    attention_v,
    attention_zeros,
    encoder_outputs_transposed,
):
    """Reference attention decoder using ``compute_coverage_attention_logits``.

    NOTE(review): ``initial_coverage`` is accepted but not forwarded. The
    generic reference always starts coverage from zeros, so this is only
    equivalent when callers pass an all-zero initial coverage. Confirm.
    """
    return lstm_with_attention_reference(
        compute_attention_logits=compute_coverage_attention_logits,
        input=input,
        initial_hidden_state=initial_hidden_state,
        initial_cell_state=initial_cell_state,
        initial_attention_weighted_encoder_context=initial_attention_weighted_encoder_context,
        gates_w=gates_w,
        gates_b=gates_b,
        decoder_input_lengths=decoder_input_lengths,
        encoder_outputs_transposed=encoder_outputs_transposed,
        weighted_decoder_hidden_state_t_w=weighted_decoder_hidden_state_t_w,
        weighted_decoder_hidden_state_t_b=weighted_decoder_hidden_state_t_b,
        weighted_encoder_outputs=weighted_encoder_outputs,
        coverage_weights=coverage_weights,
        attention_v=attention_v,
        attention_zeros=attention_zeros,
        weighted_prev_attention_context_w=None,
        weighted_prev_attention_context_b=None,
    )
626

627

628
def milstm_reference(
        input,
        hidden_input,
        cell_input,
        gates_w,
        gates_b,
        alpha,
        beta1,
        beta2,
        b,
        seq_lengths,
        forget_bias,
        drop_states=False):
    """NumPy reference for a multiplicative-integration LSTM (MI-LSTM).

    With ``Wh = hidden[t] @ gates_w.T + gates_b`` and ``x = input[t]``, the
    gates are ``alpha * Wh * x + beta1 * Wh + beta2 * x + b``.  They are then
    passed to ``lstm_unit``.

    Returns:
        ``(hidden_all, hidden_last, cell_all, cell_last)``; the ``*_last``
        arrays are the final step reshaped to ``(1, N, D)``.
    """
    T, N, G = input.shape[0], input.shape[1], input.shape[2]
    D = hidden_input.shape[-1]

    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    cell[0] = cell_input
    hidden[0] = hidden_input

    for step in range(T):
        x_t = input[step].reshape(1, N, G)
        h_prev = hidden[step].reshape(1, N, D)
        c_prev = cell[step].reshape(1, N, D)
        recurrent = np.dot(h_prev, gates_w.T) + gates_b
        gates = alpha * recurrent * x_t + beta1 * recurrent + beta2 * x_t + b
        hidden[step + 1], cell[step + 1] = lstm_unit(
            h_prev,
            c_prev,
            gates,
            seq_lengths,
            step,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )

    final_hidden = hidden[-1].reshape(1, N, D)
    final_cell = cell[-1].reshape(1, N, D)
    return hidden[1:], final_hidden, cell[1:], final_cell
679

680

681
def layer_norm_milstm_reference(
        input,
        hidden_input,
        cell_input,
        gates_w,
        gates_b,
        alpha,
        beta1,
        beta2,
        b,
        gates_t_norm_scale,
        gates_t_norm_bias,
        seq_lengths,
        forget_bias,
        drop_states=False):
    """NumPy reference for an MI-LSTM with layer-normalized gates.

    The gates are ``alpha * Wh * x + beta1 * Wh + beta2 * x + b``, with
    ``Wh = hidden[t] @ gates_w.T + gates_b`` and ``x = input[t]``.  They are
    passed through ``layer_norm_with_scale_and_bias_ref`` with the given
    scale/bias before ``lstm_unit``.

    Returns:
        ``(hidden_all, hidden_last, cell_all, cell_last)``; the ``*_last``
        arrays are the final step reshaped to ``(1, N, D)``.
    """
    T, N, G = input.shape[0], input.shape[1], input.shape[2]
    D = hidden_input.shape[-1]

    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    cell[0] = cell_input
    hidden[0] = hidden_input

    for step in range(T):
        x_t = input[step].reshape(1, N, G)
        h_prev = hidden[step].reshape(1, N, D)
        c_prev = cell[step].reshape(1, N, D)
        recurrent = np.dot(h_prev, gates_w.T) + gates_b
        raw_gates = alpha * recurrent * x_t + beta1 * recurrent + beta2 * x_t + b
        normed_gates = layer_norm_with_scale_and_bias_ref(
            raw_gates, gates_t_norm_scale, gates_t_norm_bias
        )
        hidden[step + 1], cell[step + 1] = lstm_unit(
            h_prev,
            c_prev,
            normed_gates,
            seq_lengths,
            step,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )

    final_hidden = hidden[-1].reshape(1, N, D)
    final_cell = cell[-1].reshape(1, N, D)
    return hidden[1:], final_hidden, cell[1:], final_cell
737

738

739
def lstm_input():
    '''
    Hypothesis strategy producing 3-D input arrays for the LSTM references.

    The first two dimensions (time steps and batch size) are each drawn from
    1..4. The last dimension is a value from 1..4 multiplied by 4, so it is
    always a multiple of 4 and can be split into the four equal gate blocks
    that ``lstm_unit`` expects.
    '''
    def _scaled_array(dims):
        time_steps, batch_size, units = dims
        return hu.arrays([time_steps, batch_size, units * 4])

    one_to_four = st.integers(min_value=1, max_value=4)
    return st.tuples(one_to_four, one_to_four, one_to_four).flatmap(
        _scaled_array)
756

757

758
def _prepare_attention(t, n, dim_in, encoder_dim,
                       forward_only=False, T=None,
                       dim_out=None, residual=False,
                       final_dropout=False):
    """Build a test net: stacked MILSTM decoder layers wrapped in attention.

    Side effects on the global workspace:
      * feeds random float32 initial hidden/cell states ``(1, n, d)`` for each
        entry of ``dim_out``, then a random ``(1, n, encoder_dim)`` initial
        attention context;
      * runs ``model.param_init_net``;
      * feeds random int32 sequence lengths in ``[1, t]`` for the batch.

    ``input_blob``, ``encoder_outputs`` and ``weighted_encoder_outputs`` are
    declared here but not fed.

    Returns:
        dict with ``final_output``, ``net``, ``initial_states``,
        ``input_blob``, ``encoder_outputs``, ``weighted_encoder_outputs`` and
        ``outputs_with_grads``.
    """
    if dim_out is None:
        dim_out = [dim_in]
    print("Dims: t={} n={} dim_in={} dim_out={}".format(t, n, dim_in, dim_out))

    model = ModelHelper(name='external')

    def random_state(shape):
        return np.random.random(shape).astype(np.float32)

    initial_states = []
    for layer_id, hidden_size in enumerate(dim_out):
        hidden_init, cell_init = model.net.AddExternalInputs(
            "hidden_init_{}".format(layer_id),
            "cell_init_{}".format(layer_id),
        )
        initial_states += [hidden_init, cell_init]
        workspace.FeedBlob(hidden_init, random_state((1, n, hidden_size)))
        workspace.FeedBlob(cell_init, random_state((1, n, hidden_size)))

    awec_init = model.net.AddExternalInputs([
        'initial_attention_weighted_encoder_context',
    ])
    initial_states.append(awec_init)
    workspace.FeedBlob(awec_init, random_state((1, n, encoder_dim)))

    # The RNN scoping logic is convoluted, so build everything inside a name
    # scope to make sure that path works too.
    with scope.NameScope("test_name_scope"):
        (
            input_blob,
            seq_lengths,
            encoder_outputs,
            weighted_encoder_outputs,
        ) = model.net.AddScopedExternalInputs(
            'input_blob',
            'seq_lengths',
            'encoder_outputs',
            'weighted_encoder_outputs',
        )

        cells = []
        in_size = dim_in
        for layer_id, hidden_size in enumerate(dim_out):
            cells.append(rnn_cell.MILSTMCell(
                name='decoder_{}'.format(layer_id),
                forward_only=forward_only,
                input_size=in_size,
                hidden_size=hidden_size,
                forget_bias=0.0,
                memory_optimization=False,
            ))
            in_size = hidden_size

        residual_layers = range(1, len(cells)) if residual else None
        decoder_cell = rnn_cell.MultiRNNCell(
            cells,
            name='decoder',
            residual_output_layers=residual_layers,
        )

        attention_cell = rnn_cell.AttentionCell(
            encoder_output_dim=encoder_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=None,
            decoder_cell=decoder_cell,
            decoder_state_dim=dim_out[-1],
            name='attention_decoder',
            attention_type=AttentionType.Recurrent,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=True,
        )
        if final_dropout:
            # A 0.0 dropout ratio exercises the wrapper without perturbing
            # the numerical comparisons.
            attention_cell = rnn_cell.DropoutCell(
                internal_cell=attention_cell,
                dropout_ratio=0.0,
                name='dropout',
                forward_only=forward_only,
                is_test=False,
            )

        if T is not None:
            attention_cell = rnn_cell.UnrolledCell(attention_cell, T)

        # NOTE(review): this appends to decoder_cell.output_indices in place
        # (no copy) — kept as-is; confirm whether that mutation is intended.
        output_indices = decoder_cell.output_indices
        output_indices.append(2 * len(cells))
        outputs_with_grads = [2 * i for i in output_indices]

        final_output, _ = attention_cell.apply_over_sequence(
            model=model,
            inputs=input_blob,
            seq_lengths=seq_lengths,
            initial_states=initial_states,
            outputs_with_grads=outputs_with_grads,
        )

    workspace.RunNetOnce(model.param_init_net)

    lengths = np.random.randint(1, t + 1, size=(n,)).astype(np.int32)
    workspace.FeedBlob(seq_lengths, lengths)

    return {
        'final_output': final_output,
        'net': model.net,
        'initial_states': initial_states,
        'input_blob': input_blob,
        'encoder_outputs': encoder_outputs,
        'weighted_encoder_outputs': weighted_encoder_outputs,
        'outputs_with_grads': outputs_with_grads,
    }
881

882

883
class MulCell(rnn_cell.RNNCell):
    """Toy RNN cell with a single state: new state = input_t * previous state.

    Used by the unroll tests as the simplest possible recurrent step.
    """

    def _apply(self, model, input_t,
               seq_lengths, states, timestep, extra_inputs):
        # This cell carries exactly one recurrent state.
        assert len(states) == 1
        product = model.net.Mul([input_t, states[0]])
        # Register the product as an external output of the net being built.
        model.net.AddExternalOutput(product)
        return [product]

    def get_state_names(self):
        """Return the (scoped) name of the single state blob."""
        return [self.scope("state")]
893

894

895
def prepare_mul_rnn(model, input_blob, shape, T, outputs_with_grad, num_layers):
    """Stack ``num_layers`` MulCells and apply them over ``input_blob``.

    Args:
        model: ModelHelper that receives the ops.
        input_blob: the input sequence blob.
        shape: ``(t, n, d)`` input shape; ``n`` and ``d`` size the initial
            states (``[1, n, d]`` each).
        T: when not None, the stack is wrapped in ``rnn_cell.UnrolledCell``
            with this many steps; otherwise it is used as-is.
        outputs_with_grad: requested gradient output indices, shifted by
            ``2 * (num_layers - 1)`` before being passed on.
        num_layers: number of stacked MulCells.

    Returns:
        The last two entries of the state outputs from
        ``apply_over_sequence``.
    """
    print("Shape: ", shape)
    _, batch_size, dim = shape
    layers = [MulCell(name="layer_{}".format(i)) for i in range(num_layers)]
    cell = rnn_cell.MultiRNNCell(name="multi_mul_rnn", cells=layers)
    if T is not None:
        cell = rnn_cell.UnrolledCell(cell, T=T)

    # One all-ones initial state per layer.
    initial_states = []
    for layer_idx in range(num_layers):
        initial_states.append(
            model.param_init_net.ConstantFill(
                [], "initial_state_{}".format(layer_idx),
                value=1.0, shape=[1, batch_size, dim],
            )
        )

    # Presumably two outputs per layer, so shift the requested indices onto
    # the last layer's outputs.
    grad_offset = 2 * (num_layers - 1)
    _, state_outputs = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        initial_states=initial_states,
        outputs_with_grads=[idx + grad_offset for idx in outputs_with_grad],
        seq_lengths=None,
    )
    return state_outputs[-2:]
916

917

918
class RNNCellTest(hu.HypothesisTestCase):
919
    @given(
        input_tensor=hu.tensor(min_dim=3, max_dim=3, max_value=3),
        num_layers=st.integers(1, 4),
        outputs_with_grad=st.sampled_from(
            [[0], [1], [0, 1]]
        ),
    )
    @ht_settings(max_examples=10, deadline=None)
    def test_unroll_mul(self, input_tensor, num_layers, outputs_with_grad):
        """Unrolled and dynamic MulCell stacks must agree, gradients included."""
        seq_len = input_tensor.shape[0]
        all_outputs = []
        all_nets = []
        input_blob = None
        for T in (seq_len, None):
            mode = "unroll" if T else "dynamic"
            model = ModelHelper("rnn_mul_{}".format(mode))
            input_blob = model.net.AddExternalInputs("input_blob")
            all_outputs.append(
                prepare_mul_rnn(model, input_blob, input_tensor.shape, T,
                                outputs_with_grad, num_layers)
            )
            workspace.RunNetOnce(model.param_init_net)
            all_nets.append(model.net)
            workspace.blobs[input_blob] = input_tensor

        gradient_checker.NetGradientChecker.CompareNets(
            all_nets, all_outputs, outputs_with_grad_ids=outputs_with_grad,
            inputs_with_grads=[input_blob],
        )
946

947
    @given(
        input_tensor=hu.tensor(min_dim=3, max_dim=3, max_value=3),
        forget_bias=st.floats(-10.0, 10.0),
        drop_states=st.booleans(),
        dim_out=st.lists(
            elements=st.integers(min_value=1, max_value=3),
            min_size=1, max_size=3,
        ),
        outputs_with_grads=st.sampled_from(
            [[0], [1], [0, 1], [0, 2], [0, 1, 2, 3]]
        )
    )
    @ht_settings(max_examples=10, deadline=None)
    @utils.debug
    def test_unroll_lstm(self, input_tensor, dim_out, outputs_with_grads,
                         **kwargs):
        """Unrolled and dynamic multi-layer LSTMs must agree, gradients included.

        ``kwargs`` carries the hypothesis draws for ``forget_bias`` and
        ``drop_states`` straight through to ``_prepare_rnn``.
        """
        def build(T):
            return _prepare_rnn(
                *input_tensor.shape,
                create_rnn=rnn_cell.LSTM,
                outputs_with_grads=outputs_with_grads,
                T=T,
                two_d_initial_states=False,
                dim_out=dim_out,
                **kwargs
            )

        unrolled_lstm = build(input_tensor.shape[0])
        dynamic_lstm = build(None)
        outputs, nets, inputs = zip(unrolled_lstm, dynamic_lstm)
        # The last input name receives the input sequence; both nets must
        # use the same input names so the fed values are shared.
        workspace.FeedBlob(inputs[0][-1], input_tensor)

        assert inputs[0] == inputs[1]
        gradient_checker.NetGradientChecker.CompareNets(
            nets, outputs, outputs_with_grads,
            inputs_with_grads=inputs[0],
        )
982

983
    @given(
        input_tensor=hu.tensor(min_dim=3, max_dim=3, max_value=3),
        encoder_length=st.integers(min_value=1, max_value=3),
        encoder_dim=st.integers(min_value=1, max_value=3),
        hidden_units=st.integers(min_value=1, max_value=3),
        num_layers=st.integers(min_value=1, max_value=3),
        residual=st.booleans(),
        final_dropout=st.booleans(),
    )
    @ht_settings(max_examples=10, deadline=None)
    @utils.debug
    def test_unroll_attention(self, input_tensor, encoder_length,
                              encoder_dim, hidden_units,
                              num_layers, residual,
                              final_dropout):
        """Unrolled and dynamic attention decoders must agree, gradients included."""
        seq_len, batch_size, dim_in = input_tensor.shape
        dim_out = [hidden_units] * num_layers
        encoder_tensor = np.random.random(
            (encoder_length, batch_size, encoder_dim),
        ).astype('float32')

        print('Decoder input shape: {}'.format(input_tensor.shape))
        print('Encoder output shape: {}'.format(encoder_tensor.shape))

        # Necessary because otherwise test fails for networks with fewer
        # layers than previous test. TODO: investigate why.
        workspace.ResetWorkspace()

        # The first build uses T=seq_len (UnrolledCell), the second T=None
        # (the cell is used without unrolling).
        unrolled, dynamic = [
            _prepare_attention(
                t=seq_len,
                n=batch_size,
                dim_in=dim_in,
                encoder_dim=encoder_dim,
                T=T,
                dim_out=dim_out,
                residual=residual,
                final_dropout=final_dropout,
            ) for T in [seq_len, None]
        ]

        workspace.FeedBlob(unrolled['input_blob'], input_tensor)
        workspace.FeedBlob(unrolled['encoder_outputs'], encoder_tensor)
        workspace.FeedBlob(
            unrolled['weighted_encoder_outputs'],
            np.random.random(encoder_tensor.shape).astype('float32'),
        )

        # Both graphs must read identically named blobs so the values fed
        # above are shared between them.
        for input_name in (
            'input_blob',
            'encoder_outputs',
            'weighted_encoder_outputs',
        ):
            assert unrolled[input_name] == dynamic[input_name]
        for unrolled_state, dynamic_state in zip(
            unrolled['initial_states'],
            dynamic['initial_states'],
        ):
            assert unrolled_state == dynamic_state

        inputs_with_grads = unrolled['initial_states'] + [
            unrolled['input_blob'],
            unrolled['encoder_outputs'],
            unrolled['weighted_encoder_outputs'],
        ]

        gradient_checker.NetGradientChecker.CompareNets(
            [unrolled['net'], dynamic['net']],
            [[unrolled['final_output']], [dynamic['final_output']]],
            [0],
            inputs_with_grads=inputs_with_grads,
            threshold=0.000001,
        )
1056

1057
    @given(
        input_tensor=hu.tensor(min_dim=3, max_dim=3),
        forget_bias=st.floats(-10.0, 10.0),
        forward_only=st.booleans(),
        drop_states=st.booleans(),
    )
    @ht_settings(max_examples=10, deadline=None)
    def test_layered_lstm(self, input_tensor, **kwargs):
        """Smoke test: LSTM nets build and run for every combination of
        gradient outputs and memory optimization.

        ``kwargs`` forwards the ``forget_bias``, ``forward_only`` and
        ``drop_states`` draws to ``_prepare_rnn``.
        """
        configs = [
            (outputs_with_grads, memory_optim)
            for outputs_with_grads in ([0], [1], [0, 1, 2, 3])
            for memory_optim in (False, True)
        ]
        for outputs_with_grads, memory_optim in configs:
            _, net, inputs = _prepare_rnn(
                *input_tensor.shape,
                create_rnn=rnn_cell.LSTM,
                outputs_with_grads=outputs_with_grads,
                memory_optim=memory_optim,
                **kwargs
            )
            workspace.FeedBlob(inputs[-1], input_tensor)
            workspace.RunNetOnce(net)
            workspace.ResetWorkspace()
1077

1078
    def test_lstm(self):
        """Check rnn_cell.LSTM against lstm_reference via lstm_base."""
        lstm_type = (rnn_cell.LSTM, lstm_reference)
        self.lstm_base(lstm_type=lstm_type)
1080

1081
    def test_milstm(self):
        """Check rnn_cell.MILSTM against milstm_reference via lstm_base."""
        lstm_type = (rnn_cell.MILSTM, milstm_reference)
        self.lstm_base(lstm_type=lstm_type)
1083

1084
    @unittest.skip("This is currently numerically unstable")
    def test_norm_lstm(self):
        """Check rnn_cell.LayerNormLSTM against its reference (skipped)."""
        lstm_type = (rnn_cell.LayerNormLSTM, layer_norm_lstm_reference)
        self.lstm_base(lstm_type=lstm_type)
1089

1090
    @unittest.skip("This is currently numerically unstable")
    def test_norm_milstm(self):
        """Check rnn_cell.LayerNormMILSTM against its reference (skipped)."""
        lstm_type = (rnn_cell.LayerNormMILSTM, layer_norm_milstm_reference)
        self.lstm_base(lstm_type=lstm_type)
1095

1096
    @given(
        seed=st.integers(0, 2**32 - 1),
        input_tensor=lstm_input(),
        forget_bias=st.floats(-10.0, 10.0),
        fwd_only=st.booleans(),
        drop_states=st.booleans(),
        memory_optim=st.booleans(),
        outputs_with_grads=st.sampled_from([[0], [1], [0, 1, 2, 3]]),
    )
    @ht_settings(max_examples=10, deadline=None)
    def lstm_base(self, seed, lstm_type, outputs_with_grads, memory_optim,
                  input_tensor, forget_bias, fwd_only, drop_states):
        """Check an LSTM-family RecurrentNetwork op against a NumPy reference.

        Args:
            seed: NumPy RNG seed.
            lstm_type: ``(create_lstm, reference)`` pair. ``create_lstm`` is
                passed to ``_prepare_rnn``; ``reference`` computes the
                expected outputs.
            outputs_with_grads: op outputs that receive gradients.
            memory_optim: forwarded to ``_prepare_rnn``.
            input_tensor: value fed as the hard-coded ``i2h`` blob; its last
                dimension must be divisible by 4.
            forget_bias, drop_states: forwarded to both the op and the
                reference.
            fwd_only: build the op in forward-only mode and skip the
                gradient checks.
        """
        np.random.seed(seed)
        create_lstm, ref = lstm_type
        # Bind the hyper-parameters once, so the reference sees exactly the
        # values the op is built with. (The original wrapped ``ref`` twice
        # with the same ``forget_bias``; the first wrap was dead.)
        ref = partial(ref, forget_bias=forget_bias, drop_states=drop_states)

        t, n, d = input_tensor.shape
        # The fed tensor packs 4 * hidden_dim values per step (presumably the
        # 4 gate pre-activations).
        assert d % 4 == 0
        d = d // 4

        net = _prepare_rnn(t, n, d, create_lstm,
                           outputs_with_grads=outputs_with_grads,
                           memory_optim=memory_optim,
                           forget_bias=forget_bias,
                           forward_only=fwd_only,
                           drop_states=drop_states)[1]
        # here we don't provide a real input for the net but just for one of
        # its ops (RecurrentNetworkOp). So have to hardcode this name
        workspace.FeedBlob("test_name_scope/external/recurrent/i2h",
                           input_tensor)
        op = net._net.op[-1]
        inputs = [workspace.FetchBlob(name) for name in op.input]

        # Validate forward only mode is in effect
        if fwd_only:
            for arg in op.arg:
                self.assertFalse(arg.name == 'backward_step_net')

        self.assertReferenceChecks(
            hu.cpu_do,
            op,
            inputs,
            ref,
            outputs_to_check=list(range(4)),
        )

        # Checking for input, gates_t_w and gates_t_b gradients
        if not fwd_only:
            for param in range(5):
                self.assertGradientChecks(
                    device_option=hu.cpu_do,
                    op=op,
                    inputs=inputs,
                    outputs_to_check=param,
                    outputs_with_grads=outputs_with_grads,
                    threshold=0.01,
                    stepsize=0.005,
                )
1155

1156
    def test_lstm_extract_predictor_net(self):
        """Extract a CPU:1 predictor net from an LSTM model and run it."""
        model = ModelHelper(name="lstm_extract_test")

        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 0)):
            output, _, _, _ = rnn_cell.LSTM(
                model=model,
                input_blob="input",
                seq_lengths="seqlengths",
                initial_states=("hidden_init", "cell_init"),
                dim_in=20,
                dim_out=40,
                scope="test",
                drop_states=True,
                return_last_layer_only=True,
            )

        # Run the param init net to record the shape of every blob it creates.
        workspace.RunNetOnce(model.param_init_net)
        shapes = {
            blob: workspace.FetchBlob(blob).shape
            for blob in workspace.Blobs()
        }

        # Export the predictor for a different CPU device id.
        predict_net, export_blobs = ExtractPredictorNet(
            net_proto=model.net.Proto(),
            input_blobs=["input"],
            output_blobs=[output],
            device=core.DeviceOption(caffe2_pb2.CPU, 1),
        )

        # Start from a clean workspace. Feed every external input with
        # correctly shaped data and check that each non-input external blob
        # was exported.
        workspace.ResetWorkspace()
        shapes.update({
            'input': [10, 4, 20],
            'cell_init': [1, 4, 40],
            'hidden_init': [1, 4, 40],
        })

        print(predict_net.Proto().external_input)
        self.assertTrue('seqlengths' in predict_net.Proto().external_input)
        for blob_name in predict_net.Proto().external_input:
            if blob_name == 'seqlengths':
                workspace.FeedBlob(
                    "seqlengths",
                    np.array([10] * 4, dtype=np.int32)
                )
                continue
            workspace.FeedBlob(
                blob_name,
                np.zeros(shapes[blob_name]).astype(np.float32),
            )
            if blob_name != 'input':
                self.assertTrue(blob_name in export_blobs)

        print(str(predict_net.Proto()))
        self.assertTrue(workspace.CreateNet(predict_net.Proto()))
        self.assertTrue(workspace.RunNet(predict_net.Proto().name))

        # The step net ops must run on CPU device 1. Any backward step net
        # must be empty.
        recurrent_ops = (
            op for op in predict_net.Proto().op
            if op.type == 'RecurrentNetwork'
        )
        for op in recurrent_ops:
            for arg in op.arg:
                if arg.name == "step_net":
                    for step_op in arg.n.op:
                        self.assertEqual(0, step_op.device_option.device_type)
                        self.assertEqual(1, step_op.device_option.device_id)
                elif arg.name == 'backward_step_net':
                    self.assertEqual(caffe2_pb2.NetDef(), arg.n)
1223

1224
    def test_lstm_params(self):
        """Every parameter created by rnn_cell.LSTM must have param info."""
        model = ModelHelper(name="lstm_params_test")

        cpu0 = core.DeviceOption(caffe2_pb2.CPU, 0)
        with core.DeviceScope(cpu0):
            _output, _, _, _ = rnn_cell.LSTM(
                model=model,
                input_blob="input",
                seq_lengths="seqlengths",
                initial_states=None,
                dim_in=20,
                dim_out=40,
                scope="test",
                drop_states=True,
                return_last_layer_only=True,
            )
        for param in model.GetParams():
            self.assertNotEqual(model.get_param_info(param), None)
1241

1242
    def test_milstm_params(self):
        """Every parameter of a two-layer rnn_cell.MILSTM must have param info."""
        model = ModelHelper(name="milstm_params_test")

        cpu0 = core.DeviceOption(caffe2_pb2.CPU, 0)
        with core.DeviceScope(cpu0):
            _output, _, _, _ = rnn_cell.MILSTM(
                model=model,
                input_blob="input",
                seq_lengths="seqlengths",
                initial_states=None,
                dim_in=20,
                dim_out=[40, 20],
                scope="test",
                drop_states=True,
                return_last_layer_only=True,
            )
        for param in model.GetParams():
            self.assertNotEqual(model.get_param_info(param), None)
1259

1260
    def test_layer_norm_lstm_params(self):
        """Every parameter of rnn_cell.LayerNormLSTM must have param info."""
        model = ModelHelper(name="layer_norm_lstm_params_test")

        cpu0 = core.DeviceOption(caffe2_pb2.CPU, 0)
        with core.DeviceScope(cpu0):
            _output, _, _, _ = rnn_cell.LayerNormLSTM(
                model=model,
                input_blob="input",
                seq_lengths="seqlengths",
                initial_states=None,
                dim_in=20,
                dim_out=40,
                scope="test",
                drop_states=True,
                return_last_layer_only=True,
            )
        for param in model.GetParams():
            self.assertNotEqual(model.get_param_info(param), None)
1277

1278
    @given(encoder_output_length=st.integers(1, 3),
           encoder_output_dim=st.integers(1, 3),
           decoder_input_length=st.integers(1, 3),
           decoder_state_dim=st.integers(1, 3),
           batch_size=st.integers(1, 3),
           **hu.gcs)
    @ht_settings(max_examples=10, deadline=None)
    def test_lstm_with_regular_attention(
        self,
        encoder_output_length,
        encoder_output_dim,
        decoder_input_length,
        decoder_state_dim,
        batch_size,
        gc,
        dc,
    ):
        """LSTMWithAttention (Regular attention) matches its NumPy reference."""
        create_cell = partial(
            rnn_cell.LSTMWithAttention,
            attention_type=AttentionType.Regular,
        )
        self.lstm_with_attention(
            create_lstm_with_attention=create_cell,
            encoder_output_length=encoder_output_length,
            encoder_output_dim=encoder_output_dim,
            decoder_input_length=decoder_input_length,
            decoder_state_dim=decoder_state_dim,
            batch_size=batch_size,
            ref=lstm_with_regular_attention_reference,
            gc=gc,
        )
1308

1309
    @given(encoder_output_length=st.integers(1, 3),
           encoder_output_dim=st.integers(1, 3),
           decoder_input_length=st.integers(1, 3),
           decoder_state_dim=st.integers(1, 3),
           batch_size=st.integers(1, 3),
           **hu.gcs)
    @ht_settings(max_examples=10, deadline=None)
    def test_lstm_with_recurrent_attention(
        self,
        encoder_output_length,
        encoder_output_dim,
        decoder_input_length,
        decoder_state_dim,
        batch_size,
        gc,
        dc,
    ):
        """LSTMWithAttention (Recurrent attention) matches its NumPy reference."""
        create_cell = partial(
            rnn_cell.LSTMWithAttention,
            attention_type=AttentionType.Recurrent,
        )
        self.lstm_with_attention(
            create_lstm_with_attention=create_cell,
            encoder_output_length=encoder_output_length,
            encoder_output_dim=encoder_output_dim,
            decoder_input_length=decoder_input_length,
            decoder_state_dim=decoder_state_dim,
            batch_size=batch_size,
            ref=lstm_with_recurrent_attention_reference,
            gc=gc,
        )
1339

1340
    @given(encoder_output_length=st.integers(2, 2),
           encoder_output_dim=st.integers(4, 4),
           decoder_input_length=st.integers(3, 3),
           decoder_state_dim=st.integers(4, 4),
           batch_size=st.integers(5, 5),
           **hu.gcs)
    @ht_settings(max_examples=2, deadline=None)
    def test_lstm_with_dot_attention_same_dim(
        self,
        encoder_output_length,
        encoder_output_dim,
        decoder_input_length,
        decoder_state_dim,
        batch_size,
        gc,
        dc,
    ):
        """Dot attention with equal encoder-output and decoder-state sizes."""
        create_cell = partial(
            rnn_cell.LSTMWithAttention,
            attention_type=AttentionType.Dot,
        )
        self.lstm_with_attention(
            create_lstm_with_attention=create_cell,
            encoder_output_length=encoder_output_length,
            encoder_output_dim=encoder_output_dim,
            decoder_input_length=decoder_input_length,
            decoder_state_dim=decoder_state_dim,
            batch_size=batch_size,
            ref=lstm_with_dot_attention_reference_same_dim,
            gc=gc,
        )
1370

1371
    @given(encoder_output_length=st.integers(1, 3),
           encoder_output_dim=st.integers(4, 4),
           decoder_input_length=st.integers(1, 3),
           decoder_state_dim=st.integers(5, 5),
           batch_size=st.integers(1, 3),
           **hu.gcs)
    @ht_settings(max_examples=2, deadline=None)
    def test_lstm_with_dot_attention_different_dim(
        self,
        encoder_output_length,
        encoder_output_dim,
        decoder_input_length,
        decoder_state_dim,
        batch_size,
        gc,
        dc,
    ):
        """Dot attention with different encoder-output and decoder-state sizes."""
        create_cell = partial(
            rnn_cell.LSTMWithAttention,
            attention_type=AttentionType.Dot,
        )
        self.lstm_with_attention(
            create_lstm_with_attention=create_cell,
            encoder_output_length=encoder_output_length,
            encoder_output_dim=encoder_output_dim,
            decoder_input_length=decoder_input_length,
            decoder_state_dim=decoder_state_dim,
            batch_size=batch_size,
            ref=lstm_with_dot_attention_reference_different_dim,
            gc=gc,
        )
1401

1402
    @given(encoder_output_length=st.integers(2, 3),
           encoder_output_dim=st.integers(1, 3),
           decoder_input_length=st.integers(1, 3),
           decoder_state_dim=st.integers(1, 3),
           batch_size=st.integers(1, 3),
           **hu.gcs)
    @ht_settings(max_examples=5, deadline=None)
    def test_lstm_with_coverage_attention(
        self,
        encoder_output_length,
        encoder_output_dim,
        decoder_input_length,
        decoder_state_dim,
        batch_size,
        gc,
        dc,
    ):
        """LSTMWithAttention (SoftCoverage attention) matches its NumPy reference."""
        create_cell = partial(
            rnn_cell.LSTMWithAttention,
            attention_type=AttentionType.SoftCoverage,
        )
        self.lstm_with_attention(
            create_lstm_with_attention=create_cell,
            encoder_output_length=encoder_output_length,
            encoder_output_dim=encoder_output_dim,
            decoder_input_length=decoder_input_length,
            decoder_state_dim=decoder_state_dim,
            batch_size=batch_size,
            ref=lstm_with_coverage_attention_reference,
            gc=gc,
        )
1432

1433
    def lstm_with_attention(
        self,
        create_lstm_with_attention,
        encoder_output_length,
        encoder_output_dim,
        decoder_input_length,
        decoder_state_dim,
        batch_size,
        ref,
        gc,
    ):
        """Build an attention decoder and check the op it produces.

        Runs a reference check on outputs 0-5 of the op under test, then
        gradient checks on every op input not named
        ``decoder_input_lengths``.

        Args:
            create_lstm_with_attention: factory (e.g. a ``partial`` of
                ``rnn_cell.LSTMWithAttention``) that adds the decoder to a
                model.
            encoder_output_length: encoder sequence length.
            encoder_output_dim: encoder output feature size.
            decoder_input_length: number of decoder timesteps.
            decoder_state_dim: decoder state size (also passed as the
                decoder input size).
            batch_size: batch size.
            ref: reference implementation passed to assertReferenceChecks.
            gc: device option used to build and check the op.
        """
        model = ModelHelper(name='external')
        with core.DeviceScope(gc):
            (
                encoder_outputs,
                decoder_inputs,
                decoder_input_lengths,
                initial_decoder_hidden_state,
                initial_decoder_cell_state,
                initial_attention_weighted_encoder_context,
            ) = model.net.AddExternalInputs(
                'encoder_outputs',
                'decoder_inputs',
                'decoder_input_lengths',
                'initial_decoder_hidden_state',
                'initial_decoder_cell_state',
                'initial_attention_weighted_encoder_context',
            )
            create_lstm_with_attention(
                model=model,
                decoder_inputs=decoder_inputs,
                decoder_input_lengths=decoder_input_lengths,
                initial_decoder_hidden_state=initial_decoder_hidden_state,
                initial_decoder_cell_state=initial_decoder_cell_state,
                initial_attention_weighted_encoder_context=(
                    initial_attention_weighted_encoder_context
                ),
                encoder_output_dim=encoder_output_dim,
                encoder_outputs=encoder_outputs,
                encoder_lengths=None,
                decoder_input_dim=decoder_state_dim,
                decoder_state_dim=decoder_state_dim,
                scope='external/LSTMWithAttention',
            )
            # The op under test is the second-to-last op in the net.
            op = model.net._net.op[-2]
        workspace.RunNetOnce(model.param_init_net)

        def feed_randn(blob, *shape):
            # Feed a standard-normal float32 tensor of the given shape.
            workspace.FeedBlob(blob, np.random.randn(*shape).astype(np.float32))

        # This is original decoder_inputs after linear layer
        decoder_input_blob = op.input[0]

        # The blobs are fed in a fixed order so that the random draws come
        # out in the same sequence every time.
        feed_randn(
            decoder_input_blob,
            decoder_input_length, batch_size, decoder_state_dim * 4,
        )
        feed_randn(
            'external/LSTMWithAttention/encoder_outputs_transposed',
            batch_size, encoder_output_dim, encoder_output_length,
        )
        feed_randn(
            'external/LSTMWithAttention/weighted_encoder_outputs',
            encoder_output_length, batch_size, encoder_output_dim,
        )
        feed_randn(
            'external/LSTMWithAttention/coverage_weights',
            encoder_output_length, batch_size, encoder_output_dim,
        )
        # Lengths may be 0..decoder_input_length inclusive.
        workspace.FeedBlob(
            decoder_input_lengths,
            np.random.randint(
                0, decoder_input_length + 1, size=(batch_size,)
            ).astype(np.int32),
        )
        feed_randn(initial_decoder_hidden_state, 1, batch_size, decoder_state_dim)
        feed_randn(initial_decoder_cell_state, 1, batch_size, decoder_state_dim)
        feed_randn(
            initial_attention_weighted_encoder_context,
            1, batch_size, encoder_output_dim,
        )
        workspace.FeedBlob(
            'external/LSTMWithAttention/initial_coverage',
            np.zeros((1, batch_size, encoder_output_length)).astype(np.float32),
        )

        inputs = [workspace.FetchBlob(name) for name in op.input]
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=inputs,
            reference=ref,
            grad_reference=None,
            output_to_grad=None,
            outputs_to_check=list(range(6)),
        )

        # Integer sequence lengths have no gradient; check everything else.
        gradients_to_check = [
            idx for idx, input_name in enumerate(op.input)
            if input_name != 'decoder_input_lengths'
        ]
        for param in gradients_to_check:
            self.assertGradientChecks(
                device_option=gc,
                op=op,
                inputs=inputs,
                outputs_to_check=param,
                outputs_with_grads=[0, 4],
                threshold=0.01,
                stepsize=0.001,
            )
1562

1563
    @given(seed=st.integers(0, 2**32 - 1),
           n=st.integers(1, 10),
           d=st.integers(1, 10),
           t=st.integers(1, 10),
           dtype=st.sampled_from([np.float32, np.float16]),
           use_sequence_lengths=st.booleans(),
           **hu.gcs)
    @ht_settings(max_examples=10, deadline=None)
    def test_lstm_unit_recurrent_network(
            self, seed, n, d, t, dtype, dc, use_sequence_lengths, gc):
        """Check the LSTMUnit op across devices, against ``lstm_unit``, and
        with gradient checks, both with and without sequence lengths."""
        np.random.seed(seed)
        is_half = dtype == np.float16
        if is_half:
            # only supported with CUDA/HIP
            assume(gc.device_type == workspace.GpuDeviceType)
            dc = [do for do in dc if do.device_type == workspace.GpuDeviceType]

        op_inputs = ['hidden_t_prev', 'cell_t_prev', 'gates_t']
        if use_sequence_lengths:
            op_inputs.append('seq_lengths')
        op_inputs.append('timestep')
        op = core.CreateOperator(
            'LSTMUnit',
            op_inputs,
            ['hidden_t', 'cell_t'],
            sequence_lengths=use_sequence_lengths,
        )

        # Every value is drawn, even when unused, so that the sequence of
        # random draws is the same for a given seed.
        cell_t_prev = np.random.randn(1, n, d).astype(dtype)
        hidden_t_prev = np.random.randn(1, n, d).astype(dtype)
        gates = np.random.randn(1, n, 4 * d).astype(dtype)
        seq_lengths = np.random.randint(1, t + 1, size=(n,)).astype(np.int32)
        timestep = np.random.randint(0, t, size=(1,)).astype(np.int32)

        inputs = [hidden_t_prev, cell_t_prev, gates]
        if use_sequence_lengths:
            inputs.append(seq_lengths)
        inputs.append(timestep)

        input_device_options = {'timestep': hu.cpu_do}
        self.assertDeviceChecks(
            dc, op, inputs, [0],
            input_device_options=input_device_options)

        # fp16 needs looser tolerance; the default is 1e-4.
        ref_kwargs = {'threshold': 1e-1} if is_half else {}

        def lstm_unit_reference(*args, **kwargs):
            return lstm_unit(*args, sequence_lengths=use_sequence_lengths, **kwargs)

        self.assertReferenceChecks(
            gc, op, inputs, lstm_unit_reference,
            input_device_options=input_device_options,
            **ref_kwargs)

        # fp16 needs looser tolerance; the default is 0.005.
        grad_kwargs = {'threshold': 0.5} if is_half else {}
        for input_idx in range(2):
            self.assertGradientChecks(
                gc, op, inputs, input_idx, [0, 1],
                input_device_options=input_device_options,
                **grad_kwargs)
1625

1626
    @given(input_length=st.integers(2, 5),
           dim_in=st.integers(1, 3),
           max_num_units=st.integers(1, 3),
           num_layers=st.integers(2, 3),
           batch_size=st.integers(1, 3))
    @ht_settings(max_examples=10, deadline=None)
    def test_multi_lstm(
        self,
        input_length,
        dim_in,
        max_num_units,
        num_layers,
        batch_size,
    ):
        """Check a stacked ``rnn_cell.LSTM`` against ``multi_lstm_reference``.

        Builds an LSTM with ``num_layers`` layers. Each layer's output size
        is drawn at random from ``[1, max_num_units]``. The network is run on
        a random input of shape ``(input_length, batch_size, dim_in)`` with
        random sequence lengths in ``[1, input_length]``.

        All four outputs (h_all, h_last, c_all, c_last) are compared with the
        NumPy reference, which is fed the same fetched initial states and
        weights. Finally, every model parameter is gradient-checked against
        the scalar loss ``tanh(mean(h_all))``.
        """
        model = ModelHelper(name='external')
        (
            input_sequence,
            seq_lengths,
        ) = model.net.AddExternalInputs(
            'input_sequence',
            'seq_lengths',
        )
        # Per-layer hidden sizes, chosen at random for each example.
        dim_out = [
            np.random.randint(1, max_num_units + 1)
            for _ in range(num_layers)
        ]
        h_all, h_last, c_all, c_last = rnn_cell.LSTM(
            model=model,
            input_blob=input_sequence,
            seq_lengths=seq_lengths,
            initial_states=None,
            dim_in=dim_in,
            dim_out=dim_out,
            outputs_with_grads=(0,),
            return_params=False,
            memory_optimization=False,
            forget_bias=0.0,
            forward_only=False,
            return_last_layer_only=True,
        )

        workspace.RunNetOnce(model.param_init_net)

        seq_lengths_val = np.random.randint(
            1,
            input_length + 1,
            size=batch_size,
        ).astype(np.int32)
        input_sequence_val = np.random.randn(
            input_length,
            batch_size,
            dim_in,
        ).astype(np.float32)
        workspace.FeedBlob(seq_lengths, seq_lengths_val)
        workspace.FeedBlob(input_sequence, input_sequence_val)

        # Collect the initialized states and weights of every layer so the
        # NumPy reference runs with exactly the same values as the net.
        hidden_input_list = []
        cell_input_list = []
        i2h_w_list = []
        i2h_b_list = []
        gates_w_list = []
        gates_b_list = []

        for i in range(num_layers):
            hidden_input_list.append(
                workspace.FetchBlob(
                    'layer_{}/initial_hidden_state'.format(i)),
            )
            cell_input_list.append(
                workspace.FetchBlob(
                    'layer_{}/initial_cell_state'.format(i)),
            )
            # Input projection for the first layer is produced outside
            # of the cell and thus not scoped
            prefix = 'layer_{}/'.format(i) if i > 0 else ''
            i2h_w_list.append(
                workspace.FetchBlob('{}i2h_w'.format(prefix)),
            )
            i2h_b_list.append(
                workspace.FetchBlob('{}i2h_b'.format(prefix)),
            )
            gates_w_list.append(
                workspace.FetchBlob('layer_{}/gates_t_w'.format(i)),
            )
            gates_b_list.append(
                workspace.FetchBlob('layer_{}/gates_t_b'.format(i)),
            )

        workspace.RunNetOnce(model.net)
        h_all_calc = workspace.FetchBlob(h_all)
        h_last_calc = workspace.FetchBlob(h_last)
        c_all_calc = workspace.FetchBlob(c_all)
        c_last_calc = workspace.FetchBlob(c_last)

        h_all_ref, h_last_ref, c_all_ref, c_last_ref = multi_lstm_reference(
            input_sequence_val,
            hidden_input_list,
            cell_input_list,
            i2h_w_list,
            i2h_b_list,
            gates_w_list,
            gates_b_list,
            seq_lengths_val,
            forget_bias=0.0,
        )

        # Compare using the total absolute difference over all elements.
        h_all_delta = np.abs(h_all_ref - h_all_calc).sum()
        h_last_delta = np.abs(h_last_ref - h_last_calc).sum()
        c_all_delta = np.abs(c_all_ref - c_all_calc).sum()
        c_last_delta = np.abs(c_last_ref - c_last_calc).sum()

        self.assertAlmostEqual(h_all_delta, 0.0, places=5)
        self.assertAlmostEqual(h_last_delta, 0.0, places=5)
        self.assertAlmostEqual(c_all_delta, 0.0, places=5)
        self.assertAlmostEqual(c_last_delta, 0.0, places=5)

        # External inputs plus current parameter values, handed to the
        # gradient checker.
        input_values = {
            'input_sequence': input_sequence_val,
            'seq_lengths': seq_lengths_val,
        }
        for param in model.GetParams():
            value = workspace.FetchBlob(param)
            input_values[str(param)] = value

        # Reduce h_all to a scalar loss (tanh of its mean) for gradient
        # checking.
        output_sum = model.net.SumElements(
            [h_all],
            'output_sum',
            average=True,
        )
        fake_loss = model.net.Tanh(
            output_sum,
        )
        for param in model.GetParams():
            gradient_checker.NetGradientChecker.Check(
                model.net,
                outputs_with_grad=[fake_loss],
                input_values=input_values,
                input_to_check=str(param),
                print_net=False,
                step_size=0.0001,
                threshold=0.05,
            )
1770

1771
if __name__ == "__main__":
    # Initialize the caffe2 runtime (log level 0) before running the tests.
    workspace.GlobalInit([
        'caffe2',
        '--caffe2_log_level=0',
    ])
    unittest.main()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.