## @package attention
# Module caffe2.python.attention
from caffe2.python import brew


class AttentionType:
    Regular, Recurrent, Dot, SoftCoverage = tuple(range(4))


def s(scope, name):
    # We have to manually scope due to our internal/external blob
    # relationships.
    return "{}/{}".format(str(scope), str(name))


# c_i = \sum_j w_{ij}\textbf{s}_j
def _calc_weighted_context(
    model,
    encoder_outputs_transposed,
    encoder_output_dim,
    attention_weights_3d,
    scope,
):
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = brew.batch_mat_mul(
        model,
        [encoder_outputs_transposed, attention_weights_3d],
        s(scope, 'attention_weighted_encoder_context'),
    )
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context, _ = model.net.Reshape(
        attention_weighted_encoder_context,
        [
            attention_weighted_encoder_context,
            s(scope, 'attention_weighted_encoder_context_old_shape'),
        ],
        shape=[1, -1, encoder_output_dim],
    )
    return attention_weighted_encoder_context
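

# Illustrative only: a numpy sketch of what _calc_weighted_context computes,
# under the shape assumptions noted above. The helper name and all values
# here are hypothetical and not part of the original module.
def _example_weighted_context_numpy():
    import numpy as np
    batch_size, encoder_output_dim, encoder_length = 2, 4, 5
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed = np.random.randn(
        batch_size, encoder_output_dim, encoder_length)
    # [batch_size, encoder_length, 1], with weights summing to 1 per row
    attention_weights_3d = np.random.rand(batch_size, encoder_length, 1)
    attention_weights_3d /= attention_weights_3d.sum(axis=1, keepdims=True)
    # The batched matmul realizes c_i = \sum_j w_{ij} s_j per batch item.
    context = np.matmul(encoder_outputs_transposed, attention_weights_3d)
    # Reshape to the [1, batch_size, encoder_output_dim] recurrent layout.
    return context.reshape(1, batch_size, encoder_output_dim)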


# Calculate a softmax over the passed-in attention energy logits.
def _calc_attention_weights(
    model,
    attention_logits_transposed,
    scope,
    encoder_lengths=None,
):
    if encoder_lengths is not None:
        attention_logits_transposed = model.net.SequenceMask(
            [attention_logits_transposed, encoder_lengths],
            ['masked_attention_logits'],
            mode='sequence',
        )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = brew.softmax(
        model,
        attention_logits_transposed,
        s(scope, 'attention_weights_3d'),
        engine='CUDNN',
        axis=1,
    )
    return attention_weights_3d
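

# Illustrative only: a numpy sketch of the masked softmax above, assuming
# logits of shape [batch_size, encoder_length, 1] and one length per batch
# row. The helper name and values are hypothetical.
def _example_masked_softmax_numpy():
    import numpy as np
    logits = np.random.randn(2, 5, 1)
    lengths = np.array([3, 5])
    # SequenceMask(mode='sequence') pushes padded positions to -inf so the
    # softmax assigns them zero weight.
    padded = np.arange(5)[None, :, None] >= lengths[:, None, None]
    logits = np.where(padded, -np.inf, logits)
    weights = np.exp(logits - logits.max(axis=1, keepdims=True))
    return weights / weights.sum(axis=1, keepdims=True)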


# e_{ij} = \textbf{v}^T tanh \alpha(\textbf{h}_{i-1}, \textbf{s}_j)
def _calc_attention_logits_from_sum_match(
    model,
    decoder_hidden_encoder_outputs_sum,
    encoder_output_dim,
    scope,
):
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Tanh(
        decoder_hidden_encoder_outputs_sum,
        decoder_hidden_encoder_outputs_sum,
    )

    # [encoder_length, batch_size, 1]
    attention_logits = brew.fc(
        model,
        decoder_hidden_encoder_outputs_sum,
        s(scope, 'attention_logits'),
        dim_in=encoder_output_dim,
        dim_out=1,
        axis=2,
        freeze_bias=True,
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = brew.transpose(
        model,
        attention_logits,
        s(scope, 'attention_logits_transposed'),
        axes=[1, 0, 2],
    )
    return attention_logits_transposed
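

# Illustrative only: a numpy sketch of the additive ("sum match") energy
# above, e_j = v^T tanh(sum_j), computed per source position and batch item.
# The helper name and values are hypothetical.
def _example_sum_match_logits_numpy():
    import numpy as np
    encoder_length, batch_size, encoder_output_dim = 5, 2, 4
    sum_match = np.random.randn(encoder_length, batch_size, encoder_output_dim)
    v = np.random.randn(encoder_output_dim)  # plays the role of the 1-d FC
    # The FC with axis=2 contracts the feature dim; the transpose then moves
    # to the [batch_size, encoder_length, 1] layout the softmax expects.
    logits = np.tanh(sum_match) @ v
    return logits.transpose(1, 0)[:, :, None]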


# \textbf{W}^{\alpha} used in the context of \alpha_{sum}(a, b)
def _apply_fc_weight_for_sum_match(
    model,
    input,
    dim_in,
    dim_out,
    scope,
    name,
):
    output = brew.fc(
        model,
        input,
        s(scope, name),
        dim_in=dim_in,
        dim_out=dim_out,
        axis=2,
    )
    output = model.net.Squeeze(
        output,
        output,
        dims=[0],
    )
    return output


# Implements RecAtt as described in section 4.1 of
# http://arxiv.org/abs/1601.03317
def apply_recurrent_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    attention_weighted_encoder_context_t_prev,
    scope,
    encoder_lengths=None,
):
    weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
        model=model,
        input=attention_weighted_encoder_context_t_prev,
        dim_in=encoder_output_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_prev_attention_context',
    )

    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )

    # [1, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [
            weighted_prev_attention_context,
            weighted_decoder_hidden_state,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [
            weighted_encoder_outputs,
            decoder_hidden_encoder_outputs_sum_tmp,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )

    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]
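

# Illustrative only: a minimal sketch of wiring apply_recurrent_attention
# into a ModelHelper graph. All blob names and dimensions below are
# hypothetical; in practice the blobs come from the encoder net and the
# decoder's step net.
def _example_apply_recurrent_attention():
    from caffe2.python import model_helper
    model = model_helper.ModelHelper(name='recurrent_attention_example')
    context_t, weights_3d, deps = apply_recurrent_attention(
        model=model,
        encoder_output_dim=8,
        encoder_outputs_transposed='encoder_outputs_transposed',
        weighted_encoder_outputs='weighted_encoder_outputs',
        decoder_hidden_state_t='decoder_hidden_state_t',
        decoder_hidden_state_dim=8,
        attention_weighted_encoder_context_t_prev='context_t_prev',
        scope='decoder',
        encoder_lengths='encoder_lengths',
    )
    return context_t, weights_3d, deps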


def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
        use_grad_hack=1,
    )

    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]
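

# Illustrative only: the regular (additive) attention computation end to end
# in numpy, mirroring the graph built above. Names, weights, and shapes are
# hypothetical, not the module's actual parameters.
def _example_regular_attention_numpy():
    import numpy as np
    encoder_length, batch_size, dim = 5, 2, 4
    encoder_outputs = np.random.randn(encoder_length, batch_size, dim)
    decoder_hidden_state_t = np.random.randn(1, batch_size, dim)
    W_enc = np.random.randn(dim, dim)
    W_dec = np.random.randn(dim, dim)
    v = np.random.randn(dim)
    # weighted_encoder_outputs is typically precomputed once per sequence.
    weighted_encoder_outputs = encoder_outputs @ W_enc.T
    weighted_decoder_hidden_state = decoder_hidden_state_t @ W_dec.T
    # Broadcast the single decoder step across all source positions.
    sum_match = weighted_encoder_outputs + weighted_decoder_hidden_state
    logits = np.tanh(sum_match) @ v          # [encoder_length, batch_size]
    e = np.exp(logits - logits.max(axis=0))
    weights = e / e.sum(axis=0)              # softmax over source positions
    # c = \sum_j w_j s_j -> [batch_size, dim]
    return np.einsum('jb,jbd->bd', weights, encoder_outputs)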


def apply_dot_attention(
    model,
    encoder_output_dim,
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed,
    # [1, batch_size, decoder_hidden_state_dim]
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    if decoder_hidden_state_dim != encoder_output_dim:
        weighted_decoder_hidden_state = brew.fc(
            model,
            decoder_hidden_state_t,
            s(scope, 'weighted_decoder_hidden_state'),
            dim_in=decoder_hidden_state_dim,
            dim_out=encoder_output_dim,
            axis=2,
        )
    else:
        weighted_decoder_hidden_state = decoder_hidden_state_t

    # [batch_size, encoder_output_dim]
    squeezed_weighted_decoder_hidden_state = model.net.Squeeze(
        weighted_decoder_hidden_state,
        s(scope, 'squeezed_weighted_decoder_hidden_state'),
        dims=[0],
    )

    # [batch_size, encoder_output_dim, 1]
    expanddims_squeezed_weighted_decoder_hidden_state = model.net.ExpandDims(
        squeezed_weighted_decoder_hidden_state,
        squeezed_weighted_decoder_hidden_state,
        dims=[2],
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = model.net.BatchMatMul(
        [
            encoder_outputs_transposed,
            expanddims_squeezed_weighted_decoder_hidden_state,
        ],
        s(scope, 'attention_logits'),
        trans_a=1,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, []
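

# Illustrative only: a numpy sketch of the dot-product energies above,
# e_j = s_j . h per batch item. The helper name and shapes are hypothetical.
def _example_dot_logits_numpy():
    import numpy as np
    batch_size, dim, encoder_length = 2, 4, 5
    encoder_outputs_transposed = np.random.randn(
        batch_size, dim, encoder_length)
    decoder_state = np.random.randn(batch_size, dim, 1)
    # trans_a=1 in the BatchMatMul above transposes the encoder operand,
    # yielding [batch_size, encoder_length, 1] logits.
    return np.matmul(
        encoder_outputs_transposed.transpose(0, 2, 1), decoder_state)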


def apply_soft_coverage_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths,
    coverage_t_prev,
    coverage_weights,
):
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
        broadcast=1,
    )

    # [batch_size, encoder_length]
    coverage_t_prev_2d = model.net.Squeeze(
        coverage_t_prev,
        s(scope, 'coverage_t_prev_2d'),
        dims=[0],
    )
    # [encoder_length, batch_size]
    coverage_t_prev_transposed = brew.transpose(
        model,
        coverage_t_prev_2d,
        s(scope, 'coverage_t_prev_transposed'),
    )

    # [encoder_length, batch_size, encoder_output_dim]
    scaled_coverage_weights = model.net.Mul(
        [coverage_weights, coverage_t_prev_transposed],
        s(scope, 'scaled_coverage_weights'),
        broadcast=1,
        axis=0,
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [decoder_hidden_encoder_outputs_sum_tmp, scaled_coverage_weights],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )

    # [batch_size, encoder_length]
    attention_weights_2d = model.net.Squeeze(
        attention_weights_3d,
        s(scope, 'attention_weights_2d'),
        dims=[2],
    )
    # [1, batch_size, encoder_length]
    coverage_t = model.net.Add(
        [coverage_t_prev, attention_weights_2d],
        s(scope, 'coverage_t'),
        broadcast=1,
    )

    return (
        attention_weighted_encoder_context,
        attention_weights_3d,
        [decoder_hidden_encoder_outputs_sum],
        coverage_t,
    )
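

# Illustrative only: a numpy sketch of the coverage update above. Coverage
# accumulates attention mass per source position across decoding steps, so
# positions already attended to can be down-weighted at later steps. The
# helper name and values are hypothetical.
def _example_coverage_update_numpy():
    import numpy as np
    batch_size, encoder_length = 2, 5
    coverage_t_prev = np.zeros((1, batch_size, encoder_length))
    attention_weights_2d = np.full((batch_size, encoder_length), 0.2)
    # Broadcast add over the leading singleton axis, as in the graph above.
    coverage_t = coverage_t_prev + attention_weights_2d
    return coverage_t  # [1, batch_size, encoder_length]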