# Tests for caffe2.python.rnn_cell (extraction/header residue removed).
6from caffe2.python import (
7core, gradient_checker, rnn_cell, workspace, scope, utils
8)
9from caffe2.python.attention import AttentionType
10from caffe2.python.model_helper import ModelHelper, ExtractPredictorNet
11from caffe2.python.rnn.rnn_cell_test_util import sigmoid, tanh, _prepare_rnn
12from caffe2.proto import caffe2_pb2
13import caffe2.python.hypothesis_test_util as hu
14
15from functools import partial
16from hypothesis import assume, given
17from hypothesis import settings as ht_settings
18import hypothesis.strategies as st
19import numpy as np
20import unittest
21
22
def lstm_unit(*args, **kwargs):
    """NumPy reference for a single LSTM cell step.

    Consumes the previous hidden/cell states and the pre-activation gate
    values for one timestep and produces the next hidden/cell states.
    When `sequence_lengths` is True, positions at or past their sequence
    length are masked: states are either carried through unchanged or
    zeroed, depending on `drop_states`.
    """
    forget_bias = kwargs.get('forget_bias', 0.0)
    drop_states = kwargs.get('drop_states', False)
    sequence_lengths = kwargs.get('sequence_lengths', True)

    if sequence_lengths:
        hidden_t_prev, cell_t_prev, gates, seq_lengths, timestep = args
    else:
        hidden_t_prev, cell_t_prev, gates, timestep = args
        seq_lengths = None
    D = cell_t_prev.shape[2]
    G = gates.shape[2]
    N = gates.shape[1]
    t = (timestep * np.ones(shape=(N, D))).astype(np.int32)
    assert t.shape == (N, D)
    assert G == 4 * D
    # Work on (N, D) views so broadcasting cannot surprise us.
    split_gates = gates.reshape(N, 4, D)
    prev_cell = cell_t_prev.reshape(N, D)
    i_t = sigmoid(split_gates[:, 0, :].reshape(N, D))
    f_t = sigmoid(split_gates[:, 1, :].reshape(N, D) + forget_bias)
    o_t = sigmoid(split_gates[:, 2, :].reshape(N, D))
    g_t = tanh(split_gates[:, 3, :].reshape(N, D))
    if sequence_lengths:
        lengths = (np.ones(shape=(N, D)) *
                   seq_lengths.reshape(N, 1)).astype(np.int32)
        assert lengths.shape == (N, D)
        valid = (t < lengths).astype(np.int32)
    else:
        valid = np.ones(shape=(N, D))
    assert valid.shape == (N, D)
    expired = 1 - valid
    cell_t = ((f_t * prev_cell) + (i_t * g_t)) * valid + \
        expired * prev_cell * (1 - drop_states)
    assert cell_t.shape == (N, D)
    hidden_t = (o_t * tanh(cell_t)) * valid + \
        hidden_t_prev * expired * (1 - drop_states)
    return hidden_t.reshape(1, N, D), cell_t.reshape(1, N, D)
65
66
def layer_norm_with_scale_and_bias_ref(X, scale, bias, axis=-1, epsilon=1e-4):
    """Reference layer normalization followed by an affine transform.

    Collapses every dimension before `axis` into rows, normalizes each
    row to zero mean / unit variance (epsilon added to the variance for
    numerical stability), restores the original shape, and applies
    ``scale * norm + bias`` element-wise.
    """
    rows = np.prod(X.shape[:axis])
    flat = np.reshape(X, [rows, -1])
    row_mean = np.mean(flat, axis=1).reshape([rows, 1])
    # Var = E[x^2] - E[x]^2, stabilized with epsilon.
    row_stdev = np.sqrt(
        np.mean(np.square(flat), axis=1).reshape([rows, 1]) -
        np.square(row_mean) + epsilon
    )
    normalized = np.reshape((flat - row_mean) / row_stdev, X.shape)
    return scale * normalized + bias
80
81
def layer_norm_lstm_reference(
    input,
    hidden_input,
    cell_input,
    gates_w,
    gates_b,
    gates_t_norm_scale,
    gates_t_norm_bias,
    seq_lengths,
    forget_bias,
    drop_states=False
):
    """Reference LSTM with layer normalization of the gate pre-activations.

    input: (T, N, G) precomputed input projections; hidden_input and
    cell_input: (1, N, D) initial states.  Returns
    (hidden_all, hidden_last, cell_all, cell_last) with the initial-state
    slot stripped from the *_all sequences.
    """
    T = input.shape[0]
    N = input.shape[1]
    G = input.shape[2]
    D = hidden_input.shape[hidden_input.ndim - 1]
    # Slot 0 holds the initial states; step t writes slot t + 1.
    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    assert hidden.shape[0] == T + 1
    assert cell.shape[0] == T + 1
    assert hidden.shape[1] == N
    assert cell.shape[1] == N
    cell[0, :, :] = cell_input
    hidden[0, :, :] = hidden_input
    for t in range(T):
        input_t = input[t].reshape(1, N, G)
        hidden_t_prev = hidden[t].reshape(1, N, D)
        cell_t_prev = cell[t].reshape(1, N, D)
        # Recurrent projection plus the precomputed input projection.
        # (A stray debug print of input_t.shape, absent from the otherwise
        # identical lstm_reference, was removed here.)
        gates = np.dot(hidden_t_prev, gates_w.T) + gates_b
        gates = gates + input_t
        # Normalize the pre-activations before the LSTM nonlinearities.
        gates = layer_norm_with_scale_and_bias_ref(
            gates, gates_t_norm_scale, gates_t_norm_bias
        )
        hidden_t, cell_t = lstm_unit(
            hidden_t_prev,
            cell_t_prev,
            gates,
            seq_lengths,
            t,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )
        hidden[t + 1] = hidden_t
        cell[t + 1] = cell_t
    return (
        hidden[1:],
        hidden[-1].reshape(1, N, D),
        cell[1:],
        cell[-1].reshape(1, N, D)
    )
135
136
def lstm_reference(input, hidden_input, cell_input,
                   gates_w, gates_b, seq_lengths, forget_bias,
                   drop_states=False):
    """Plain NumPy reference implementation of a single LSTM layer.

    input: (T, N, G) precomputed input projections; hidden_input and
    cell_input: (1, N, D) initial states.  Returns
    (hidden_all, hidden_last, cell_all, cell_last).
    """
    T, N, G = input.shape
    D = hidden_input.shape[hidden_input.ndim - 1]
    # Slot 0 holds the initial states; step s writes slot s + 1.
    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    assert hidden.shape[0] == T + 1
    assert cell.shape[0] == T + 1
    assert hidden.shape[1] == N
    assert cell.shape[1] == N
    hidden[0, :, :] = hidden_input
    cell[0, :, :] = cell_input
    for step in range(T):
        x_t = input[step].reshape(1, N, G)
        h_prev = hidden[step].reshape(1, N, D)
        c_prev = cell[step].reshape(1, N, D)
        pre_gates = np.dot(h_prev, gates_w.T) + gates_b + x_t
        h_t, c_t = lstm_unit(
            h_prev,
            c_prev,
            pre_gates,
            seq_lengths,
            step,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )
        hidden[step + 1] = h_t
        cell[step + 1] = c_t
    return (
        hidden[1:],
        hidden[-1].reshape(1, N, D),
        cell[1:],
        cell[-1].reshape(1, N, D)
    )
175
176
def multi_lstm_reference(input, hidden_input_list, cell_input_list,
                         i2h_w_list, i2h_b_list, gates_w_list, gates_b_list,
                         seq_lengths, forget_bias, drop_states=False):
    """Stacked-LSTM reference.

    Projects the running input with each layer's i2h weights, runs
    lstm_reference for that layer, and feeds the full hidden sequence
    into the next layer.  Returns the top layer's outputs.
    """
    num_layers = len(hidden_input_list)
    for per_layer in (cell_input_list, i2h_w_list, i2h_b_list,
                      gates_w_list, gates_b_list):
        assert len(per_layer) == num_layers

    layer_input = input
    for layer in range(num_layers):
        projected = np.dot(layer_input, i2h_w_list[layer].T) + \
            i2h_b_list[layer]
        h_all, h_last, c_all, c_last = lstm_reference(
            projected,
            hidden_input_list[layer],
            cell_input_list[layer],
            gates_w_list[layer],
            gates_b_list[layer],
            seq_lengths,
            forget_bias,
            drop_states=drop_states,
        )
        layer_input = h_all
    return h_all, h_last, c_all, c_last
201
202
def compute_regular_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    """Additive attention logits.

    Only the projected decoder hidden state and the precomputed encoder
    projections participate; the remaining parameters exist so every
    compute_*_attention_logits variant shares one call signature.
    """
    decoder_part = np.dot(
        hidden_t,
        weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    v = attention_v.reshape([-1])
    return np.sum(
        v * np.tanh(weighted_encoder_outputs + decoder_part),
        axis=2,
    )
225
226
def compute_recurrent_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    """Recurrent attention logits.

    Like the additive variant, but the previous attention context is
    projected and added into the tanh argument as well.
    """
    decoder_part = np.dot(
        hidden_t,
        weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    context_part = np.dot(
        attention_weighted_encoder_context_t_prev,
        weighted_prev_attention_context_w.T
    ) + weighted_prev_attention_context_b
    v = attention_v.reshape([-1])
    return np.sum(
        v * np.tanh(
            weighted_encoder_outputs + decoder_part + context_part
        ),
        axis=2,
    )
256
257
def compute_dot_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    """Dot-product attention logits.

    Scores each encoder position by its inner product with the decoder
    hidden state.  When the (w, b) projection is supplied, the hidden
    state is first mapped into the encoder's dimensionality.
    """
    hidden_cols = np.transpose(hidden_t, axes=[1, 2, 0])
    has_projection = (
        weighted_decoder_hidden_state_t_w is not None and
        weighted_decoder_hidden_state_t_b is not None
    )
    if has_projection:
        hidden_cols = np.matmul(
            weighted_decoder_hidden_state_t_w,
            hidden_cols,
        ) + np.expand_dims(weighted_decoder_hidden_state_t_b, axis=1)
    logits = np.sum(
        np.matmul(
            encoder_outputs_for_dot_product,
            hidden_cols,
        ),
        axis=2,
    )
    return np.transpose(logits)
288
289
def compute_coverage_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    """Coverage attention logits.

    Additive attention whose encoder term is augmented with the
    accumulated attention mass (coverage_prev) scaled by
    coverage_weights.
    """
    decoder_part = np.dot(
        hidden_t,
        weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    encoder_part = weighted_encoder_outputs + \
        coverage_prev.T * coverage_weights
    v = attention_v.reshape([-1])
    return np.sum(
        v * np.tanh(encoder_part + decoder_part),
        axis=2,
    )
314
315
def lstm_with_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    encoder_outputs_transposed,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    weighted_encoder_outputs,
    coverage_weights,
    attention_v,
    attention_zeros,
    compute_attention_logits,
):
    """Reference decoder LSTM with attention over encoder outputs.

    Runs a single-layer LSTM whose gate input is the concatenation of the
    previous hidden state and the previous attention-weighted encoder
    context.  At every step `compute_attention_logits` (one of the
    compute_*_attention_logits functions in this file) scores the encoder
    positions; a softmax over those scores produces the next context.
    Attention weights are also accumulated into a coverage tensor for the
    coverage-attention variant.

    Returns hidden/cell/context sequences (without the initial-state slot)
    plus each one's final timestep.
    """
    # encoder_outputs_transposed is (batch, dim, time); recover
    # (time, batch, dim) and (batch, time, dim) layouts.
    encoder_outputs = np.transpose(encoder_outputs_transposed, axes=[2, 0, 1])
    encoder_outputs_for_dot_product = np.transpose(
        encoder_outputs_transposed,
        [0, 2, 1],
    )
    decoder_input_length = input.shape[0]
    batch_size = input.shape[1]
    decoder_input_dim = input.shape[2]
    decoder_state_dim = initial_hidden_state.shape[2]
    encoder_output_dim = encoder_outputs.shape[2]
    # Slot 0 holds the initial states; step t writes slot t + 1.
    hidden = np.zeros(
        shape=(decoder_input_length + 1, batch_size, decoder_state_dim))
    cell = np.zeros(
        shape=(decoder_input_length + 1, batch_size, decoder_state_dim))
    attention_weighted_encoder_context = np.zeros(
        shape=(decoder_input_length + 1, batch_size, encoder_output_dim))
    cell[0, :, :] = initial_cell_state
    hidden[0, :, :] = initial_hidden_state
    attention_weighted_encoder_context[0, :, :] = (
        initial_attention_weighted_encoder_context
    )
    encoder_length = encoder_outputs.shape[0]
    # Coverage always starts at zero here (no initial-coverage input).
    coverage = np.zeros(
        shape=(decoder_input_length + 1, batch_size, encoder_length))
    for t in range(decoder_input_length):
        input_t = input[t].reshape(1, batch_size, decoder_input_dim)
        hidden_t_prev = hidden[t].reshape(1, batch_size, decoder_state_dim)
        cell_t_prev = cell[t].reshape(1, batch_size, decoder_state_dim)
        attention_weighted_encoder_context_t_prev = (
            attention_weighted_encoder_context[t].reshape(
                1, batch_size, encoder_output_dim)
        )
        # The recurrent projection sees [hidden, attention context].
        gates_input = np.concatenate(
            (hidden_t_prev, attention_weighted_encoder_context_t_prev),
            axis=2,
        )
        gates = np.dot(gates_input, gates_w.T) + gates_b
        gates = gates + input_t
        hidden_t, cell_t = lstm_unit(hidden_t_prev, cell_t_prev, gates,
                                     decoder_input_lengths, t)
        hidden[t + 1] = hidden_t
        cell[t + 1] = cell_t

        coverage_prev = coverage[t].reshape(1, batch_size, encoder_length)

        attention_logits_t = compute_attention_logits(
            hidden_t,
            weighted_decoder_hidden_state_t_w,
            weighted_decoder_hidden_state_t_b,
            attention_weighted_encoder_context_t_prev,
            weighted_prev_attention_context_w,
            weighted_prev_attention_context_b,
            attention_v,
            weighted_encoder_outputs,
            encoder_outputs_for_dot_product,
            coverage_prev,
            coverage_weights,
        )

        # Softmax over the encoder-length axis (axis 0 of the logits).
        attention_logits_t_exp = np.exp(attention_logits_t)
        attention_weights_t = (
            attention_logits_t_exp /
            np.sum(attention_logits_t_exp, axis=0).reshape([1, -1])
        )
        # Coverage accumulates the attention mass each position received.
        coverage[t + 1, :, :] = coverage[t, :, :] + attention_weights_t.T
        attention_weighted_encoder_context[t + 1] = np.sum(
            (
                encoder_outputs *
                attention_weights_t.reshape([-1, batch_size, 1])
            ),
            axis=0,
        )
    return (
        hidden[1:],
        hidden[-1].reshape(1, batch_size, decoder_state_dim),
        cell[1:],
        cell[-1].reshape(1, batch_size, decoder_state_dim),
        attention_weighted_encoder_context[1:],
        attention_weighted_encoder_context[-1].reshape(
            1,
            batch_size,
            encoder_output_dim,
        )
    )
419
420
def lstm_with_regular_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    weighted_encoder_outputs,
    attention_v,
    attention_zeros,
    encoder_outputs_transposed,
):
    """Reference for an LSTM decoder with regular (additive) attention.

    Thin adapter over lstm_with_attention_reference: the recurrent-context
    and coverage parameters are unused by this variant and passed as None.
    """
    forwarded = dict(
        input=input,
        initial_hidden_state=initial_hidden_state,
        initial_cell_state=initial_cell_state,
        initial_attention_weighted_encoder_context=(
            initial_attention_weighted_encoder_context
        ),
        gates_w=gates_w,
        gates_b=gates_b,
        decoder_input_lengths=decoder_input_lengths,
        encoder_outputs_transposed=encoder_outputs_transposed,
        weighted_decoder_hidden_state_t_w=weighted_decoder_hidden_state_t_w,
        weighted_decoder_hidden_state_t_b=weighted_decoder_hidden_state_t_b,
        weighted_encoder_outputs=weighted_encoder_outputs,
        attention_v=attention_v,
        attention_zeros=attention_zeros,
    )
    return lstm_with_attention_reference(
        weighted_prev_attention_context_w=None,
        weighted_prev_attention_context_b=None,
        coverage_weights=None,
        compute_attention_logits=compute_regular_attention_logits,
        **forwarded
    )
457
458
def lstm_with_recurrent_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    weighted_encoder_outputs,
    attention_v,
    attention_zeros,
    encoder_outputs_transposed,
):
    """Reference for an LSTM decoder with recurrent attention, where the
    previous attention context feeds back into the logits.

    Adapter over lstm_with_attention_reference; only the coverage
    parameter is unused here.
    """
    forwarded = dict(
        input=input,
        initial_hidden_state=initial_hidden_state,
        initial_cell_state=initial_cell_state,
        initial_attention_weighted_encoder_context=(
            initial_attention_weighted_encoder_context
        ),
        gates_w=gates_w,
        gates_b=gates_b,
        decoder_input_lengths=decoder_input_lengths,
        encoder_outputs_transposed=encoder_outputs_transposed,
        weighted_prev_attention_context_w=weighted_prev_attention_context_w,
        weighted_prev_attention_context_b=weighted_prev_attention_context_b,
        weighted_decoder_hidden_state_t_w=weighted_decoder_hidden_state_t_w,
        weighted_decoder_hidden_state_t_b=weighted_decoder_hidden_state_t_b,
        weighted_encoder_outputs=weighted_encoder_outputs,
        attention_v=attention_v,
        attention_zeros=attention_zeros,
    )
    return lstm_with_attention_reference(
        coverage_weights=None,
        compute_attention_logits=compute_recurrent_attention_logits,
        **forwarded
    )
497
498
def lstm_with_dot_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    encoder_outputs_transposed,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
):
    """Reference for an LSTM decoder with dot-product attention.

    Adapter over lstm_with_attention_reference; every parameter the dot
    variant does not use is forwarded as None.
    """
    unused = {
        name: None
        for name in (
            'weighted_prev_attention_context_w',
            'weighted_prev_attention_context_b',
            'weighted_encoder_outputs',
            'coverage_weights',
            'attention_v',
            'attention_zeros',
        )
    }
    return lstm_with_attention_reference(
        input=input,
        initial_hidden_state=initial_hidden_state,
        initial_cell_state=initial_cell_state,
        initial_attention_weighted_encoder_context=(
            initial_attention_weighted_encoder_context
        ),
        gates_w=gates_w,
        gates_b=gates_b,
        decoder_input_lengths=decoder_input_lengths,
        encoder_outputs_transposed=encoder_outputs_transposed,
        weighted_decoder_hidden_state_t_w=weighted_decoder_hidden_state_t_w,
        weighted_decoder_hidden_state_t_b=weighted_decoder_hidden_state_t_b,
        compute_attention_logits=compute_dot_attention_logits,
        **unused
    )
532
533
def lstm_with_dot_attention_reference_same_dim(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    encoder_outputs_transposed,
):
    """Dot attention where encoder and decoder dimensions match, so no
    hidden-state projection is needed (w and b are None)."""
    return lstm_with_dot_attention_reference(
        input,
        initial_hidden_state,
        initial_cell_state,
        initial_attention_weighted_encoder_context,
        gates_w,
        gates_b,
        decoder_input_lengths,
        encoder_outputs_transposed,
        weighted_decoder_hidden_state_t_w=None,
        weighted_decoder_hidden_state_t_b=None,
    )
558
559
def lstm_with_dot_attention_reference_different_dim(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    encoder_outputs_transposed,
):
    """Dot attention with differing encoder/decoder dimensions: the hidden
    state is projected by (w, b) before the dot product."""
    return lstm_with_dot_attention_reference(
        input,
        initial_hidden_state,
        initial_cell_state,
        initial_attention_weighted_encoder_context,
        gates_w,
        gates_b,
        decoder_input_lengths,
        encoder_outputs_transposed,
        weighted_decoder_hidden_state_t_w=weighted_decoder_hidden_state_t_w,
        weighted_decoder_hidden_state_t_b=weighted_decoder_hidden_state_t_b,
    )
586
587
def lstm_with_coverage_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    initial_coverage,
    gates_w,
    gates_b,
    decoder_input_lengths,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    weighted_encoder_outputs,
    coverage_weights,
    attention_v,
    attention_zeros,
    encoder_outputs_transposed,
):
    """Reference for an LSTM decoder with coverage attention.

    NOTE(review): `initial_coverage` is accepted but not forwarded —
    lstm_with_attention_reference always starts coverage at zero; confirm
    callers only pass a zero initial coverage.
    """
    forwarded = dict(
        input=input,
        initial_hidden_state=initial_hidden_state,
        initial_cell_state=initial_cell_state,
        initial_attention_weighted_encoder_context=(
            initial_attention_weighted_encoder_context
        ),
        gates_w=gates_w,
        gates_b=gates_b,
        decoder_input_lengths=decoder_input_lengths,
        encoder_outputs_transposed=encoder_outputs_transposed,
        weighted_prev_attention_context_w=None,
        weighted_prev_attention_context_b=None,
        weighted_decoder_hidden_state_t_w=weighted_decoder_hidden_state_t_w,
        weighted_decoder_hidden_state_t_b=weighted_decoder_hidden_state_t_b,
        weighted_encoder_outputs=weighted_encoder_outputs,
        coverage_weights=coverage_weights,
        attention_v=attention_v,
        attention_zeros=attention_zeros,
        compute_attention_logits=compute_coverage_attention_logits,
    )
    return lstm_with_attention_reference(**forwarded)
626
627
def milstm_reference(
        input,
        hidden_input,
        cell_input,
        gates_w,
        gates_b,
        alpha,
        beta1,
        beta2,
        b,
        seq_lengths,
        forget_bias,
        drop_states=False):
    """NumPy reference for a multiplicative-integration LSTM (MILSTM).

    Identical to lstm_reference except that the gate pre-activations
    combine the recurrent projection r and input projection x as
    alpha * r * x + beta1 * r + beta2 * x + b.
    """
    T, N, G = input.shape
    D = hidden_input.shape[hidden_input.ndim - 1]
    # Slot 0 holds the initial states; step s writes slot s + 1.
    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    assert hidden.shape == (T + 1, N, D)
    assert cell.shape == (T + 1, N, D)
    hidden[0, :, :] = hidden_input
    cell[0, :, :] = cell_input
    for step in range(T):
        x_t = input[step].reshape(1, N, G)
        h_prev = hidden[step].reshape(1, N, D)
        c_prev = cell[step].reshape(1, N, D)
        recurrent = np.dot(h_prev, gates_w.T) + gates_b
        pre_gates = (alpha * recurrent * x_t) + \
            (beta1 * recurrent) + \
            (beta2 * x_t) + \
            b
        h_t, c_t = lstm_unit(
            h_prev,
            c_prev,
            pre_gates,
            seq_lengths,
            step,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )
        hidden[step + 1] = h_t
        cell[step + 1] = c_t
    return (
        hidden[1:],
        hidden[-1].reshape(1, N, D),
        cell[1:],
        cell[-1].reshape(1, N, D)
    )
679
680
def layer_norm_milstm_reference(
    input,
    hidden_input,
    cell_input,
    gates_w,
    gates_b,
    alpha,
    beta1,
    beta2,
    b,
    gates_t_norm_scale,
    gates_t_norm_bias,
    seq_lengths,
    forget_bias,
    drop_states=False):
    """MILSTM reference with layer normalization of the gate
    pre-activations at every timestep (see milstm_reference and
    layer_norm_with_scale_and_bias_ref).

    Returns (hidden_all, hidden_last, cell_all, cell_last), each without
    the initial-state slot.
    """
    T = input.shape[0]
    N = input.shape[1]
    G = input.shape[2]
    D = hidden_input.shape[hidden_input.ndim - 1]
    # Slot 0 holds the initial states; step t writes slot t + 1.
    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    assert hidden.shape[0] == T + 1
    assert cell.shape[0] == T + 1
    assert hidden.shape[1] == N
    assert cell.shape[1] == N
    cell[0, :, :] = cell_input
    hidden[0, :, :] = hidden_input
    for t in range(T):
        input_t = input[t].reshape(1, N, G)
        hidden_t_prev = hidden[t].reshape(1, N, D)
        cell_t_prev = cell[t].reshape(1, N, D)
        gates = np.dot(hidden_t_prev, gates_w.T) + gates_b
        # Multiplicative-integration combination of the recurrent and
        # input projections.
        gates = (alpha * gates * input_t) + \
            (beta1 * gates) + \
            (beta2 * input_t) + \
            b
        # Normalize pre-activations before the LSTM nonlinearities.
        gates = layer_norm_with_scale_and_bias_ref(
            gates, gates_t_norm_scale, gates_t_norm_bias
        )
        hidden_t, cell_t = lstm_unit(
            hidden_t_prev,
            cell_t_prev,
            gates,
            seq_lengths,
            t,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )
        hidden[t + 1] = hidden_t
        cell[t + 1] = cell_t
    return (
        hidden[1:],
        hidden[-1].reshape(1, N, D),
        cell[1:],
        cell[-1].reshape(1, N, D)
    )
737
738
def lstm_input():
    """Hypothesis strategy for LSTM input tensors.

    Draws ndim=3 shapes with every dimension in [1, 4], then multiplies
    the last dimension by 4 so it always holds the four stacked LSTM
    gates.
    """
    shape_strategy = st.tuples(
        st.integers(min_value=1, max_value=4),  # t
        st.integers(min_value=1, max_value=4),  # n
        st.integers(min_value=1, max_value=4),  # d
    )

    def to_array_strategy(shape):
        t, n, d = shape
        return hu.arrays([t, n, d * 4])

    return shape_strategy.flatmap(to_array_strategy)
756
757
def _prepare_attention(t, n, dim_in, encoder_dim,
                       forward_only=False, T=None,
                       dim_out=None, residual=False,
                       final_dropout=False):
    """Build a MILSTM decoder stack wrapped in a recurrent AttentionCell.

    When T is not None the cell is statically unrolled for T steps;
    otherwise the dynamic recurrent-network path is used.  Feeds random
    initial states / sequence lengths into the workspace and returns a
    dict of the blobs and nets needed by the unroll-vs-dynamic
    comparison test.
    """
    if dim_out is None:
        dim_out = [dim_in]
    print("Dims: t={} n={} dim_in={} dim_out={}".format(t, n, dim_in, dim_out))

    model = ModelHelper(name='external')

    def generate_input_state(shape):
        return np.random.random(shape).astype(np.float32)

    # One (hidden, cell) initial-state pair per decoder layer.
    initial_states = []
    for layer_id, d in enumerate(dim_out):
        h, c = model.net.AddExternalInputs(
            "hidden_init_{}".format(layer_id),
            "cell_init_{}".format(layer_id),
        )
        initial_states.extend([h, c])
        workspace.FeedBlob(h, generate_input_state((1, n, d)))
        workspace.FeedBlob(c, generate_input_state((1, n, d)))

    # The attention cell carries one extra recurrent state: the
    # attention-weighted encoder context.
    awec_init = model.net.AddExternalInputs([
        'initial_attention_weighted_encoder_context',
    ])
    initial_states.append(awec_init)
    workspace.FeedBlob(
        awec_init,
        generate_input_state((1, n, encoder_dim)),
    )

    # Due to convoluted RNN scoping logic we make sure that things
    # work from a namescope
    with scope.NameScope("test_name_scope"):
        (
            input_blob,
            seq_lengths,
            encoder_outputs,
            weighted_encoder_outputs,
        ) = model.net.AddScopedExternalInputs(
            'input_blob',
            'seq_lengths',
            'encoder_outputs',
            'weighted_encoder_outputs',
        )

        # Stack MILSTM layers; each layer's output feeds the next.
        layer_input_dim = dim_in
        cells = []
        for layer_id, d in enumerate(dim_out):

            cell = rnn_cell.MILSTMCell(
                name='decoder_{}'.format(layer_id),
                forward_only=forward_only,
                input_size=layer_input_dim,
                hidden_size=d,
                forget_bias=0.0,
                memory_optimization=False,
            )
            cells.append(cell)
            layer_input_dim = d

        decoder_cell = rnn_cell.MultiRNNCell(
            cells,
            name='decoder',
            residual_output_layers=range(1, len(cells)) if residual else None,
        )

        attention_cell = rnn_cell.AttentionCell(
            encoder_output_dim=encoder_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=None,
            decoder_cell=decoder_cell,
            decoder_state_dim=dim_out[-1],
            name='attention_decoder',
            attention_type=AttentionType.Recurrent,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=True,
        )
        if final_dropout:
            # dropout ratio of 0.0 used to test mechanism but not interfere
            # with numerical tests
            attention_cell = rnn_cell.DropoutCell(
                internal_cell=attention_cell,
                dropout_ratio=0.0,
                name='dropout',
                forward_only=forward_only,
                is_test=False,
            )

        # T set -> statically unrolled; T None -> dynamic RecurrentNetwork.
        attention_cell = (
            attention_cell if T is None
            else rnn_cell.UnrolledCell(attention_cell, T)
        )

        # Shift indices so they address hidden (even-numbered) outputs,
        # including the attention context appended after the LSTM states.
        output_indices = decoder_cell.output_indices
        output_indices.append(2 * len(cells))
        outputs_with_grads = [2 * i for i in output_indices]

        final_output, state_outputs = attention_cell.apply_over_sequence(
            model=model,
            inputs=input_blob,
            seq_lengths=seq_lengths,
            initial_states=initial_states,
            outputs_with_grads=outputs_with_grads,
        )

    workspace.RunNetOnce(model.param_init_net)

    # Random per-batch-item sequence lengths in [1, t].
    workspace.FeedBlob(
        seq_lengths,
        np.random.randint(1, t + 1, size=(n,)).astype(np.int32)
    )

    return {
        'final_output': final_output,
        'net': model.net,
        'initial_states': initial_states,
        'input_blob': input_blob,
        'encoder_outputs': encoder_outputs,
        'weighted_encoder_outputs': weighted_encoder_outputs,
        'outputs_with_grads': outputs_with_grads,
    }
881
882
class MulCell(rnn_cell.RNNCell):
    """Minimal RNN cell whose state update is element-wise multiplication.

    Used to exercise the RNN scaffolding (MultiRNNCell, unrolling) with
    the simplest possible recurrence: state_t = input_t * state_{t-1}.
    """

    def _apply(self, model, input_t,
               seq_lengths, states, timestep, extra_inputs):
        # This cell carries exactly one recurrent state.
        assert len(states) == 1
        product = model.net.Mul([input_t, states[0]])
        model.net.AddExternalOutput(product)
        return [product]

    def get_state_names(self):
        return [self.scope("state")]
893
894
def prepare_mul_rnn(model, input_blob, shape, T, outputs_with_grad, num_layers):
    """Wire a stack of MulCells over `input_blob` and return the last two
    outputs of apply_over_sequence.

    T not None selects the statically unrolled cell; None uses the
    dynamic recurrent network.  `outputs_with_grad` indices are shifted
    so they address the top layer of the MultiRNNCell output list.
    """
    print("Shape: ", shape)
    t, n, d = shape
    cells = [MulCell(name="layer_{}".format(i)) for i in range(num_layers)]
    cell = rnn_cell.MultiRNNCell(name="multi_mul_rnn", cells=cells)
    if T is not None:
        cell = rnn_cell.UnrolledCell(cell, T=T)
    # Every layer starts from an all-ones state so the multiplicative
    # recurrence is well-defined.
    states = [
        model.param_init_net.ConstantFill(
            [], "initial_state_{}".format(i), value=1.0, shape=[1, n, d])
        for i in range(num_layers)]
    _, results = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        initial_states=states,
        outputs_with_grads=[
            x + 2 * (num_layers - 1) for x in outputs_with_grad
        ],
        seq_lengths=None,
    )
    return results[-2:]
916
917
918class RNNCellTest(hu.HypothesisTestCase):
919@given(
920input_tensor=hu.tensor(min_dim=3, max_dim=3, max_value=3),
921num_layers=st.integers(1, 4),
922outputs_with_grad=st.sampled_from(
923[[0], [1], [0, 1]]
924),
925)
926@ht_settings(max_examples=10, deadline=None)
    def test_unroll_mul(self, input_tensor, num_layers, outputs_with_grad):
        """Check that the unrolled and dynamic mul-RNNs produce matching
        outputs and gradients."""
        outputs = []
        nets = []
        input_blob = None
        # First iteration builds the statically unrolled net, the second
        # the dynamic one; both register the same "input_blob" name.
        for T in [input_tensor.shape[0], None]:
            model = ModelHelper("rnn_mul_{}".format(
                "unroll" if T else "dynamic"))
            input_blob = model.net.AddExternalInputs("input_blob")
            outputs.append(
                prepare_mul_rnn(model, input_blob, input_tensor.shape, T,
                                outputs_with_grad, num_layers))
            workspace.RunNetOnce(model.param_init_net)
            nets.append(model.net)
            workspace.blobs[input_blob] = input_tensor

        gradient_checker.NetGradientChecker.CompareNets(
            nets, outputs, outputs_with_grad_ids=outputs_with_grad,
            inputs_with_grads=[input_blob],
        )
946
947@given(
948input_tensor=hu.tensor(min_dim=3, max_dim=3, max_value=3),
949forget_bias=st.floats(-10.0, 10.0),
950drop_states=st.booleans(),
951dim_out=st.lists(
952elements=st.integers(min_value=1, max_value=3),
953min_size=1, max_size=3,
954),
955outputs_with_grads=st.sampled_from(
956[[0], [1], [0, 1], [0, 2], [0, 1, 2, 3]]
957)
958)
959@ht_settings(max_examples=10, deadline=None)
960@utils.debug
    def test_unroll_lstm(self, input_tensor, dim_out, outputs_with_grads,
                         **kwargs):
        """Compare outputs and gradients between the statically unrolled
        and dynamic LSTM nets built by _prepare_rnn."""
        # Build the same LSTM twice: unrolled (T = sequence length) and
        # dynamic (T = None).
        lstms = [
            _prepare_rnn(
                *input_tensor.shape,
                create_rnn=rnn_cell.LSTM,
                outputs_with_grads=outputs_with_grads,
                T=T,
                two_d_initial_states=False,
                dim_out=dim_out,
                **kwargs
            ) for T in [input_tensor.shape[0], None]
        ]
        outputs, nets, inputs = zip(*lstms)
        workspace.FeedBlob(inputs[0][-1], input_tensor)

        # Both variants must have registered identical input blobs, so a
        # single FeedBlob serves both nets.
        assert inputs[0] == inputs[1]
        gradient_checker.NetGradientChecker.CompareNets(
            nets, outputs, outputs_with_grads,
            inputs_with_grads=inputs[0],
        )
982
983@given(
984input_tensor=hu.tensor(min_dim=3, max_dim=3, max_value=3),
985encoder_length=st.integers(min_value=1, max_value=3),
986encoder_dim=st.integers(min_value=1, max_value=3),
987hidden_units=st.integers(min_value=1, max_value=3),
988num_layers=st.integers(min_value=1, max_value=3),
989residual=st.booleans(),
990final_dropout=st.booleans(),
991)
992@ht_settings(max_examples=10, deadline=None)
993@utils.debug
    def test_unroll_attention(self, input_tensor, encoder_length,
                              encoder_dim, hidden_units,
                              num_layers, residual,
                              final_dropout):
        """Compare unrolled vs dynamic attention decoder nets built by
        _prepare_attention."""

        dim_out = [hidden_units] * num_layers
        encoder_tensor = np.random.random(
            (encoder_length, input_tensor.shape[1], encoder_dim),
        ).astype('float32')

        print('Decoder input shape: {}'.format(input_tensor.shape))
        print('Encoder output shape: {}'.format(encoder_tensor.shape))

        # Necessary because otherwise test fails for networks with fewer
        # layers than previous test. TODO: investigate why.
        workspace.ResetWorkspace()

        # NOTE(review): `net` receives the T=seq-length (unrolled) build
        # and `unrolled` the dynamic one — the names look swapped, but
        # CompareNets treats the two symmetrically.
        net, unrolled = [
            _prepare_attention(
                t=input_tensor.shape[0],
                n=input_tensor.shape[1],
                dim_in=input_tensor.shape[2],
                encoder_dim=encoder_dim,
                T=T,
                dim_out=dim_out,
                residual=residual,
                final_dropout=final_dropout,
            ) for T in [input_tensor.shape[0], None]
        ]

        workspace.FeedBlob(net['input_blob'], input_tensor)
        workspace.FeedBlob(net['encoder_outputs'], encoder_tensor)
        workspace.FeedBlob(
            net['weighted_encoder_outputs'],
            np.random.random(encoder_tensor.shape).astype('float32'),
        )

        # Both builds must share the same external blobs so one FeedBlob
        # feeds both nets.
        for input_name in [
            'input_blob',
            'encoder_outputs',
            'weighted_encoder_outputs',
        ]:
            assert net[input_name] == unrolled[input_name]
        for state_name, unrolled_state_name in zip(
            net['initial_states'],
            unrolled['initial_states'],
        ):
            assert state_name == unrolled_state_name

        inputs_with_grads = net['initial_states'] + [
            net['input_blob'],
            net['encoder_outputs'],
            net['weighted_encoder_outputs'],
        ]

        gradient_checker.NetGradientChecker.CompareNets(
            [net['net'], unrolled['net']],
            [[net['final_output']], [unrolled['final_output']]],
            [0],
            inputs_with_grads=inputs_with_grads,
            threshold=0.000001,
        )
1056
1057@given(
1058input_tensor=hu.tensor(min_dim=3, max_dim=3),
1059forget_bias=st.floats(-10.0, 10.0),
1060forward_only=st.booleans(),
1061drop_states=st.booleans(),
1062)
1063@ht_settings(max_examples=10, deadline=None)
1064def test_layered_lstm(self, input_tensor, **kwargs):
1065for outputs_with_grads in [[0], [1], [0, 1, 2, 3]]:
1066for memory_optim in [False, True]:
1067_, net, inputs = _prepare_rnn(
1068*input_tensor.shape,
1069create_rnn=rnn_cell.LSTM,
1070outputs_with_grads=outputs_with_grads,
1071memory_optim=memory_optim,
1072**kwargs
1073)
1074workspace.FeedBlob(inputs[-1], input_tensor)
1075workspace.RunNetOnce(net)
1076workspace.ResetWorkspace()
1077
    def test_lstm(self):
        """Run the shared LSTM checks against the plain LSTM cell."""
        self.lstm_base(lstm_type=(rnn_cell.LSTM, lstm_reference))
1080
    def test_milstm(self):
        """Run the shared LSTM checks against the MILSTM cell."""
        self.lstm_base(lstm_type=(rnn_cell.MILSTM, milstm_reference))
1083
1084@unittest.skip("This is currently numerically unstable")
    def test_norm_lstm(self):
        """Layer-norm LSTM against its NumPy reference (skipped above as
        numerically unstable)."""
        self.lstm_base(
            lstm_type=(rnn_cell.LayerNormLSTM, layer_norm_lstm_reference),
        )
1089
1090@unittest.skip("This is currently numerically unstable")
    def test_norm_milstm(self):
        """Layer-norm MILSTM against its NumPy reference (skipped above as
        numerically unstable)."""
        self.lstm_base(
            lstm_type=(rnn_cell.LayerNormMILSTM, layer_norm_milstm_reference)
        )
1095
1096@given(
1097seed=st.integers(0, 2**32 - 1),
1098input_tensor=lstm_input(),
1099forget_bias=st.floats(-10.0, 10.0),
1100fwd_only=st.booleans(),
1101drop_states=st.booleans(),
1102memory_optim=st.booleans(),
1103outputs_with_grads=st.sampled_from([[0], [1], [0, 1, 2, 3]]),
1104)
1105@ht_settings(max_examples=10, deadline=None)
1106def lstm_base(self, seed, lstm_type, outputs_with_grads, memory_optim,
1107input_tensor, forget_bias, fwd_only, drop_states):
1108np.random.seed(seed)
1109create_lstm, ref = lstm_type
1110ref = partial(ref, forget_bias=forget_bias)
1111
1112t, n, d = input_tensor.shape
1113assert d % 4 == 0
1114d = d // 4
1115ref = partial(ref, forget_bias=forget_bias, drop_states=drop_states)
1116
1117net = _prepare_rnn(t, n, d, create_lstm,
1118outputs_with_grads=outputs_with_grads,
1119memory_optim=memory_optim,
1120forget_bias=forget_bias,
1121forward_only=fwd_only,
1122drop_states=drop_states)[1]
1123# here we don't provide a real input for the net but just for one of
1124# its ops (RecurrentNetworkOp). So have to hardcode this name
1125workspace.FeedBlob("test_name_scope/external/recurrent/i2h",
1126input_tensor)
1127op = net._net.op[-1]
1128inputs = [workspace.FetchBlob(name) for name in op.input]
1129
1130# Validate forward only mode is in effect
1131if fwd_only:
1132for arg in op.arg:
1133self.assertFalse(arg.name == 'backward_step_net')
1134
1135self.assertReferenceChecks(
1136hu.cpu_do,
1137op,
1138inputs,
1139ref,
1140outputs_to_check=list(range(4)),
1141)
1142
1143# Checking for input, gates_t_w and gates_t_b gradients
1144if not fwd_only:
1145for param in range(5):
1146self.assertGradientChecks(
1147device_option=hu.cpu_do,
1148op=op,
1149inputs=inputs,
1150outputs_to_check=param,
1151outputs_with_grads=outputs_with_grads,
1152threshold=0.01,
1153stepsize=0.005,
1154)
1155
    def test_lstm_extract_predictor_net(self):
        """Build an LSTM on CPU device 0, extract a predictor net pinned to
        CPU device 1, and verify the extracted net is runnable and that its
        step nets carry the requested device option."""
        model = ModelHelper(name="lstm_extract_test")

        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 0)):
            output, _, _, _ = rnn_cell.LSTM(
                model=model,
                input_blob="input",
                seq_lengths="seqlengths",
                initial_states=("hidden_init", "cell_init"),
                dim_in=20,
                dim_out=40,
                scope="test",
                drop_states=True,
                return_last_layer_only=True,
            )
            # Run param init net to get the shapes for all inputs
            shapes = {}
            workspace.RunNetOnce(model.param_init_net)
            for b in workspace.Blobs():
                shapes[b] = workspace.FetchBlob(b).shape

        # But export in CPU (device id 1, distinct from the build device)
        (predict_net, export_blobs) = ExtractPredictorNet(
            net_proto=model.net.Proto(),
            input_blobs=["input"],
            output_blobs=[output],
            device=core.DeviceOption(caffe2_pb2.CPU, 1),
        )

        # Create the net and run once to see it is valid
        # Populate external inputs with correctly shaped random input
        # and also ensure that the export_blobs was constructed correctly.
        workspace.ResetWorkspace()
        # Shapes for the non-parameter inputs: seq len 10, batch 4.
        shapes['input'] = [10, 4, 20]
        shapes['cell_init'] = [1, 4, 40]
        shapes['hidden_init'] = [1, 4, 40]

        print(predict_net.Proto().external_input)
        self.assertTrue('seqlengths' in predict_net.Proto().external_input)
        for einp in predict_net.Proto().external_input:
            if einp == 'seqlengths':
                workspace.FeedBlob(
                    "seqlengths",
                    np.array([10] * 4, dtype=np.int32)
                )
            else:
                workspace.FeedBlob(
                    einp,
                    np.zeros(shapes[einp]).astype(np.float32),
                )
                # Every external input except the declared net input must
                # have been captured in export_blobs.
                if einp != 'input':
                    self.assertTrue(einp in export_blobs)

        print(str(predict_net.Proto()))
        self.assertTrue(workspace.CreateNet(predict_net.Proto()))
        self.assertTrue(workspace.RunNet(predict_net.Proto().name))

        # Validate device options set correctly for the RNNs
        for op in predict_net.Proto().op:
            if op.type == 'RecurrentNetwork':
                for arg in op.arg:
                    if arg.name == "step_net":
                        for step_op in arg.n.op:
                            # CPU device type is 0; exported device id is 1.
                            self.assertEqual(0, step_op.device_option.device_type)
                            self.assertEqual(1, step_op.device_option.device_id)
                    elif arg.name == 'backward_step_net':
                        # Predictor nets must not carry a backward step net.
                        self.assertEqual(caffe2_pb2.NetDef(), arg.n)
1223
1224def test_lstm_params(self):
1225model = ModelHelper(name="lstm_params_test")
1226
1227with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 0)):
1228output, _, _, _ = rnn_cell.LSTM(
1229model=model,
1230input_blob="input",
1231seq_lengths="seqlengths",
1232initial_states=None,
1233dim_in=20,
1234dim_out=40,
1235scope="test",
1236drop_states=True,
1237return_last_layer_only=True,
1238)
1239for param in model.GetParams():
1240self.assertNotEqual(model.get_param_info(param), None)
1241
1242def test_milstm_params(self):
1243model = ModelHelper(name="milstm_params_test")
1244
1245with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 0)):
1246output, _, _, _ = rnn_cell.MILSTM(
1247model=model,
1248input_blob="input",
1249seq_lengths="seqlengths",
1250initial_states=None,
1251dim_in=20,
1252dim_out=[40, 20],
1253scope="test",
1254drop_states=True,
1255return_last_layer_only=True,
1256)
1257for param in model.GetParams():
1258self.assertNotEqual(model.get_param_info(param), None)
1259
1260def test_layer_norm_lstm_params(self):
1261model = ModelHelper(name="layer_norm_lstm_params_test")
1262
1263with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 0)):
1264output, _, _, _ = rnn_cell.LayerNormLSTM(
1265model=model,
1266input_blob="input",
1267seq_lengths="seqlengths",
1268initial_states=None,
1269dim_in=20,
1270dim_out=40,
1271scope="test",
1272drop_states=True,
1273return_last_layer_only=True,
1274)
1275for param in model.GetParams():
1276self.assertNotEqual(model.get_param_info(param), None)
1277
1278@given(encoder_output_length=st.integers(1, 3),
1279encoder_output_dim=st.integers(1, 3),
1280decoder_input_length=st.integers(1, 3),
1281decoder_state_dim=st.integers(1, 3),
1282batch_size=st.integers(1, 3),
1283**hu.gcs)
1284@ht_settings(max_examples=10, deadline=None)
1285def test_lstm_with_regular_attention(
1286self,
1287encoder_output_length,
1288encoder_output_dim,
1289decoder_input_length,
1290decoder_state_dim,
1291batch_size,
1292gc,
1293dc,
1294):
1295self.lstm_with_attention(
1296partial(
1297rnn_cell.LSTMWithAttention,
1298attention_type=AttentionType.Regular,
1299),
1300encoder_output_length,
1301encoder_output_dim,
1302decoder_input_length,
1303decoder_state_dim,
1304batch_size,
1305lstm_with_regular_attention_reference,
1306gc,
1307)
1308
1309@given(encoder_output_length=st.integers(1, 3),
1310encoder_output_dim=st.integers(1, 3),
1311decoder_input_length=st.integers(1, 3),
1312decoder_state_dim=st.integers(1, 3),
1313batch_size=st.integers(1, 3),
1314**hu.gcs)
1315@ht_settings(max_examples=10, deadline=None)
1316def test_lstm_with_recurrent_attention(
1317self,
1318encoder_output_length,
1319encoder_output_dim,
1320decoder_input_length,
1321decoder_state_dim,
1322batch_size,
1323gc,
1324dc,
1325):
1326self.lstm_with_attention(
1327partial(
1328rnn_cell.LSTMWithAttention,
1329attention_type=AttentionType.Recurrent,
1330),
1331encoder_output_length,
1332encoder_output_dim,
1333decoder_input_length,
1334decoder_state_dim,
1335batch_size,
1336lstm_with_recurrent_attention_reference,
1337gc,
1338)
1339
1340@given(encoder_output_length=st.integers(2, 2),
1341encoder_output_dim=st.integers(4, 4),
1342decoder_input_length=st.integers(3, 3),
1343decoder_state_dim=st.integers(4, 4),
1344batch_size=st.integers(5, 5),
1345**hu.gcs)
1346@ht_settings(max_examples=2, deadline=None)
1347def test_lstm_with_dot_attention_same_dim(
1348self,
1349encoder_output_length,
1350encoder_output_dim,
1351decoder_input_length,
1352decoder_state_dim,
1353batch_size,
1354gc,
1355dc,
1356):
1357self.lstm_with_attention(
1358partial(
1359rnn_cell.LSTMWithAttention,
1360attention_type=AttentionType.Dot,
1361),
1362encoder_output_length,
1363encoder_output_dim,
1364decoder_input_length,
1365decoder_state_dim,
1366batch_size,
1367lstm_with_dot_attention_reference_same_dim,
1368gc,
1369)
1370
1371@given(encoder_output_length=st.integers(1, 3),
1372encoder_output_dim=st.integers(4, 4),
1373decoder_input_length=st.integers(1, 3),
1374decoder_state_dim=st.integers(5, 5),
1375batch_size=st.integers(1, 3),
1376**hu.gcs)
1377@ht_settings(max_examples=2, deadline=None)
1378def test_lstm_with_dot_attention_different_dim(
1379self,
1380encoder_output_length,
1381encoder_output_dim,
1382decoder_input_length,
1383decoder_state_dim,
1384batch_size,
1385gc,
1386dc,
1387):
1388self.lstm_with_attention(
1389partial(
1390rnn_cell.LSTMWithAttention,
1391attention_type=AttentionType.Dot,
1392),
1393encoder_output_length,
1394encoder_output_dim,
1395decoder_input_length,
1396decoder_state_dim,
1397batch_size,
1398lstm_with_dot_attention_reference_different_dim,
1399gc,
1400)
1401
1402@given(encoder_output_length=st.integers(2, 3),
1403encoder_output_dim=st.integers(1, 3),
1404decoder_input_length=st.integers(1, 3),
1405decoder_state_dim=st.integers(1, 3),
1406batch_size=st.integers(1, 3),
1407**hu.gcs)
1408@ht_settings(max_examples=5, deadline=None)
1409def test_lstm_with_coverage_attention(
1410self,
1411encoder_output_length,
1412encoder_output_dim,
1413decoder_input_length,
1414decoder_state_dim,
1415batch_size,
1416gc,
1417dc,
1418):
1419self.lstm_with_attention(
1420partial(
1421rnn_cell.LSTMWithAttention,
1422attention_type=AttentionType.SoftCoverage,
1423),
1424encoder_output_length,
1425encoder_output_dim,
1426decoder_input_length,
1427decoder_state_dim,
1428batch_size,
1429lstm_with_coverage_attention_reference,
1430gc,
1431)
1432
    def lstm_with_attention(
        self,
        create_lstm_with_attention,
        encoder_output_length,
        encoder_output_dim,
        decoder_input_length,
        decoder_state_dim,
        batch_size,
        ref,
        gc,
    ):
        """Shared driver for the attention-variant tests.

        Builds an LSTMWithAttention cell via ``create_lstm_with_attention``,
        feeds random inputs to the RecurrentNetwork op, compares its six
        outputs to the numpy reference ``ref``, and gradient-checks every
        input except the integer sequence lengths.
        """
        model = ModelHelper(name='external')
        with core.DeviceScope(gc):
            (
                encoder_outputs,
                decoder_inputs,
                decoder_input_lengths,
                initial_decoder_hidden_state,
                initial_decoder_cell_state,
                initial_attention_weighted_encoder_context,
            ) = model.net.AddExternalInputs(
                'encoder_outputs',
                'decoder_inputs',
                'decoder_input_lengths',
                'initial_decoder_hidden_state',
                'initial_decoder_cell_state',
                'initial_attention_weighted_encoder_context',
            )
            create_lstm_with_attention(
                model=model,
                decoder_inputs=decoder_inputs,
                decoder_input_lengths=decoder_input_lengths,
                initial_decoder_hidden_state=initial_decoder_hidden_state,
                initial_decoder_cell_state=initial_decoder_cell_state,
                initial_attention_weighted_encoder_context=(
                    initial_attention_weighted_encoder_context
                ),
                encoder_output_dim=encoder_output_dim,
                encoder_outputs=encoder_outputs,
                encoder_lengths=None,
                decoder_input_dim=decoder_state_dim,
                decoder_state_dim=decoder_state_dim,
                scope='external/LSTMWithAttention',
            )
            # NOTE(review): the RecurrentNetwork op is assumed to be the
            # second-to-last op the cell appended — order-sensitive.
            op = model.net._net.op[-2]
        workspace.RunNetOnce(model.param_init_net)

        # This is original decoder_inputs after linear layer
        decoder_input_blob = op.input[0]

        # Gate projection of the decoder input: 4 gates per state unit.
        workspace.FeedBlob(
            decoder_input_blob,
            np.random.randn(
                decoder_input_length,
                batch_size,
                decoder_state_dim * 4,
            ).astype(np.float32))
        workspace.FeedBlob(
            'external/LSTMWithAttention/encoder_outputs_transposed',
            np.random.randn(
                batch_size,
                encoder_output_dim,
                encoder_output_length,
            ).astype(np.float32),
        )
        workspace.FeedBlob(
            'external/LSTMWithAttention/weighted_encoder_outputs',
            np.random.randn(
                encoder_output_length,
                batch_size,
                encoder_output_dim,
            ).astype(np.float32),
        )
        # Only consumed by the SoftCoverage variant; harmless otherwise.
        workspace.FeedBlob(
            'external/LSTMWithAttention/coverage_weights',
            np.random.randn(
                encoder_output_length,
                batch_size,
                encoder_output_dim,
            ).astype(np.float32),
        )
        workspace.FeedBlob(
            decoder_input_lengths,
            np.random.randint(
                0,
                decoder_input_length + 1,
                size=(batch_size,)
            ).astype(np.int32))
        workspace.FeedBlob(
            initial_decoder_hidden_state,
            np.random.randn(1, batch_size, decoder_state_dim).astype(np.float32)
        )
        workspace.FeedBlob(
            initial_decoder_cell_state,
            np.random.randn(1, batch_size, decoder_state_dim).astype(np.float32)
        )
        workspace.FeedBlob(
            initial_attention_weighted_encoder_context,
            np.random.randn(
                1, batch_size, encoder_output_dim).astype(np.float32)
        )
        workspace.FeedBlob(
            'external/LSTMWithAttention/initial_coverage',
            np.zeros((1, batch_size, encoder_output_length)).astype(np.float32),
        )
        inputs = [workspace.FetchBlob(name) for name in op.input]
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=inputs,
            reference=ref,
            grad_reference=None,
            output_to_grad=None,
            outputs_to_check=list(range(6)),
        )
        # Integer lengths are not differentiable; skip them.
        gradients_to_check = [
            index for (index, input_name) in enumerate(op.input)
            if input_name != 'decoder_input_lengths'
        ]
        for param in gradients_to_check:
            self.assertGradientChecks(
                device_option=gc,
                op=op,
                inputs=inputs,
                outputs_to_check=param,
                outputs_with_grads=[0, 4],
                threshold=0.01,
                stepsize=0.001,
            )
1562
    @given(seed=st.integers(0, 2**32 - 1),
           n=st.integers(1, 10),
           d=st.integers(1, 10),
           t=st.integers(1, 10),
           dtype=st.sampled_from([np.float32, np.float16]),
           use_sequence_lengths=st.booleans(),
           **hu.gcs)
    @ht_settings(max_examples=10, deadline=None)
    def test_lstm_unit_recurrent_network(
            self, seed, n, d, t, dtype, dc, use_sequence_lengths, gc):
        """Check the standalone LSTMUnit op (device, reference, gradient)
        with and without sequence lengths, in fp32 and fp16.

        n = batch, d = hidden dim, t = max timestep.
        """
        np.random.seed(seed)
        if dtype == np.float16:
            # only supported with CUDA/HIP
            assume(gc.device_type == workspace.GpuDeviceType)
            dc = [do for do in dc if do.device_type == workspace.GpuDeviceType]

        if use_sequence_lengths:
            op_inputs = ['hidden_t_prev', 'cell_t_prev', 'gates_t',
                         'seq_lengths', 'timestep']
        else:
            op_inputs = ['hidden_t_prev', 'cell_t_prev', 'gates_t', 'timestep']
        op = core.CreateOperator(
            'LSTMUnit',
            op_inputs,
            ['hidden_t', 'cell_t'],
            sequence_lengths=use_sequence_lengths,
        )
        cell_t_prev = np.random.randn(1, n, d).astype(dtype)
        hidden_t_prev = np.random.randn(1, n, d).astype(dtype)
        # All four gates concatenated along the feature axis.
        gates = np.random.randn(1, n, 4 * d).astype(dtype)
        seq_lengths = np.random.randint(1, t + 1, size=(n,)).astype(np.int32)
        timestep = np.random.randint(0, t, size=(1,)).astype(np.int32)
        if use_sequence_lengths:
            inputs = [hidden_t_prev, cell_t_prev, gates, seq_lengths, timestep]
        else:
            inputs = [hidden_t_prev, cell_t_prev, gates, timestep]
        # The timestep blob must always live on CPU.
        input_device_options = {'timestep': hu.cpu_do}
        self.assertDeviceChecks(
            dc, op, inputs, [0],
            input_device_options=input_device_options)

        # fp16 needs a far looser tolerance for the reference check.
        kwargs = {}
        if dtype == np.float16:
            kwargs['threshold'] = 1e-1  # default is 1e-4

        def lstm_unit_reference(*args, **kwargs):
            # Forward the sequence-length mode into the shared numpy model.
            return lstm_unit(*args, sequence_lengths=use_sequence_lengths, **kwargs)

        self.assertReferenceChecks(
            gc, op, inputs, lstm_unit_reference,
            input_device_options=input_device_options,
            **kwargs)

        # fp16 also needs a looser gradient-check tolerance.
        kwargs = {}
        if dtype == np.float16:
            kwargs['threshold'] = 0.5  # default is 0.005

        # Gradients w.r.t. hidden_t_prev and cell_t_prev only.
        for i in range(2):
            self.assertGradientChecks(
                gc, op, inputs, i, [0, 1],
                input_device_options=input_device_options,
                **kwargs)
1625
    @given(input_length=st.integers(2, 5),
           dim_in=st.integers(1, 3),
           max_num_units=st.integers(1, 3),
           num_layers=st.integers(2, 3),
           batch_size=st.integers(1, 3))
    @ht_settings(max_examples=10, deadline=None)
    def test_multi_lstm(
        self,
        input_length,
        dim_in,
        max_num_units,
        num_layers,
        batch_size,
    ):
        """Compare a multi-layer LSTM net (random per-layer widths) against
        multi_lstm_reference, then gradient-check every parameter through a
        scalar fake loss."""
        model = ModelHelper(name='external')
        (
            input_sequence,
            seq_lengths,
        ) = model.net.AddExternalInputs(
            'input_sequence',
            'seq_lengths',
        )
        # Random hidden width per layer, in [1, max_num_units].
        dim_out = [
            np.random.randint(1, max_num_units + 1)
            for _ in range(num_layers)
        ]
        h_all, h_last, c_all, c_last = rnn_cell.LSTM(
            model=model,
            input_blob=input_sequence,
            seq_lengths=seq_lengths,
            initial_states=None,
            dim_in=dim_in,
            dim_out=dim_out,
            # scope='test',
            outputs_with_grads=(0,),
            return_params=False,
            memory_optimization=False,
            forget_bias=0.0,
            forward_only=False,
            return_last_layer_only=True,
        )

        workspace.RunNetOnce(model.param_init_net)

        seq_lengths_val = np.random.randint(
            1,
            input_length + 1,
            size=(batch_size),
        ).astype(np.int32)
        input_sequence_val = np.random.randn(
            input_length,
            batch_size,
            dim_in,
        ).astype(np.float32)
        workspace.FeedBlob(seq_lengths, seq_lengths_val)
        workspace.FeedBlob(input_sequence, input_sequence_val)

        # Collect the initialized states and weights per layer so the numpy
        # reference starts from the exact same values.
        hidden_input_list = []
        cell_input_list = []
        i2h_w_list = []
        i2h_b_list = []
        gates_w_list = []
        gates_b_list = []

        for i in range(num_layers):
            hidden_input_list.append(
                workspace.FetchBlob(
                    'layer_{}/initial_hidden_state'.format(i)),
            )
            cell_input_list.append(
                workspace.FetchBlob(
                    'layer_{}/initial_cell_state'.format(i)),
            )
            # Input projection for the first layer is produced outside
            # of the cell and thus not scoped
            prefix = 'layer_{}/'.format(i) if i > 0 else ''
            i2h_w_list.append(
                workspace.FetchBlob('{}i2h_w'.format(prefix)),
            )
            i2h_b_list.append(
                workspace.FetchBlob('{}i2h_b'.format(prefix)),
            )
            gates_w_list.append(
                workspace.FetchBlob('layer_{}/gates_t_w'.format(i)),
            )
            gates_b_list.append(
                workspace.FetchBlob('layer_{}/gates_t_b'.format(i)),
            )

        workspace.RunNetOnce(model.net)
        h_all_calc = workspace.FetchBlob(h_all)
        h_last_calc = workspace.FetchBlob(h_last)
        c_all_calc = workspace.FetchBlob(c_all)
        c_last_calc = workspace.FetchBlob(c_last)

        h_all_ref, h_last_ref, c_all_ref, c_last_ref = multi_lstm_reference(
            input_sequence_val,
            hidden_input_list,
            cell_input_list,
            i2h_w_list,
            i2h_b_list,
            gates_w_list,
            gates_b_list,
            seq_lengths_val,
            forget_bias=0.0,
        )

        # Compare via summed absolute error against the reference.
        h_all_delta = np.abs(h_all_ref - h_all_calc).sum()
        h_last_delta = np.abs(h_last_ref - h_last_calc).sum()
        c_all_delta = np.abs(c_all_ref - c_all_calc).sum()
        c_last_delta = np.abs(c_last_ref - c_last_calc).sum()

        self.assertAlmostEqual(h_all_delta, 0.0, places=5)
        self.assertAlmostEqual(h_last_delta, 0.0, places=5)
        self.assertAlmostEqual(c_all_delta, 0.0, places=5)
        self.assertAlmostEqual(c_last_delta, 0.0, places=5)

        # Snapshot all inputs for the net-level gradient checker.
        input_values = {
            'input_sequence': input_sequence_val,
            'seq_lengths': seq_lengths_val,
        }
        for param in model.GetParams():
            value = workspace.FetchBlob(param)
            input_values[str(param)] = value

        # Reduce the output to a scalar "loss" so gradients are defined.
        output_sum = model.net.SumElements(
            [h_all],
            'output_sum',
            average=True,
        )
        fake_loss = model.net.Tanh(
            output_sum,
        )
        for param in model.GetParams():
            gradient_checker.NetGradientChecker.Check(
                model.net,
                outputs_with_grad=[fake_loss],
                input_values=input_values,
                input_to_check=str(param),
                print_net=False,
                step_size=0.0001,
                threshold=0.05,
            )
1769
1770
if __name__ == "__main__":
    # Initialize caffe2 globals (quiet logging) before running the suite.
    init_args = ['caffe2', '--caffe2_log_level=0']
    workspace.GlobalInit(init_args)
    unittest.main()
1777