pytorch

Форк
0
/
attention.py 
424 строки · 12.1 Кб
1
## @package attention
2
# Module caffe2.python.attention
3

4

5

6

7

8
from caffe2.python import brew
9

10

11
class AttentionType:
    # Integer identifiers for the attention variants implemented in this
    # module (consecutive, starting at 0). The names match the
    # apply_*_attention builders below; presumably callers use these to pick
    # one of them — confirm against callers.
    Regular = 0
    Recurrent = 1
    Dot = 2
    SoftCoverage = 3
13

14

15
def s(scope, name):
    """Return the blob name ``"<scope>/<name>"``.

    Blob names are scoped by hand rather than through a name-scope context
    because of this module's internal/external blob relationships.
    Both arguments are converted with ``str()`` before joining.
    """
    return "/".join([str(scope), str(name)])
19

20

21
# c_i = \sum_j w_{ij}\textbf{s}_j
22
def _calc_weighted_context(
    model,
    encoder_outputs_transposed,
    encoder_output_dim,
    attention_weights_3d,
    scope,
):
    """Build the attention-weighted encoder context c_i = sum_j w_ij * s_j.

    Args:
        model: model helper whose ``net`` receives the new operators.
        encoder_outputs_transposed: encoder outputs, presumably
            [batch_size, encoder_output_dim, encoder_length] — TODO confirm
            against callers.
        encoder_output_dim: size of the last axis of the returned blob.
        attention_weights_3d: attention weights, presumably
            [batch_size, encoder_length, 1].
        scope: prefix used (via ``s``) to name the blobs created here.

    Returns:
        Blob reference '<scope>/attention_weighted_encoder_context', reshaped
        in place to [1, batch_size, encoder_output_dim].
    """
    # Batched matmul: [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = brew.batch_mat_mul(
        model,
        [encoder_outputs_transposed, attention_weights_3d],
        s(scope, 'attention_weighted_encoder_context'),
    )
    # In-place reshape to [1, batch_size, encoder_output_dim]; the -1 is
    # inferred (presumably batch_size). Reshape's second output, the old
    # shape, is written to its own blob and discarded here.
    attention_weighted_encoder_context, _ = model.net.Reshape(
        attention_weighted_encoder_context,
        [
            attention_weighted_encoder_context,
            s(scope, 'attention_weighted_encoder_context_old_shape'),
        ],
        shape=[1, -1, encoder_output_dim],
    )
    return attention_weighted_encoder_context
45

46

47
# Calculate a softmax over the passed-in attention energy logits
48
def _calc_attention_weights(
    model,
    attention_logits_transposed,
    scope,
    encoder_lengths=None,
):
    """Softmax the attention logits over axis 1 (the encoder positions).

    Args:
        model: model helper whose ``net`` receives the new operators.
        attention_logits_transposed: attention logits, presumably
            [batch_size, encoder_length, 1].
        scope: prefix used (via ``s``) to name the blobs created here.
        encoder_lengths: optional blob of per-example encoder lengths. When
            given, the logits first go through SequenceMask (mode='sequence'),
            presumably masking positions past each example's length.

    Returns:
        Blob reference '<scope>/attention_weights_3d' holding the softmax
        (CUDNN engine, axis=1) of the (optionally masked) logits.
    """
    if encoder_lengths is not None:
        # Scoped like every other intermediate blob in this module, so that
        # two attention mechanisms built in the same net do not both write a
        # blob named 'masked_attention_logits'.
        attention_logits_transposed = model.net.SequenceMask(
            [attention_logits_transposed, encoder_lengths],
            [s(scope, 'masked_attention_logits')],
            mode='sequence',
        )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = brew.softmax(
        model,
        attention_logits_transposed,
        s(scope, 'attention_weights_3d'),
        engine='CUDNN',
        axis=1,
    )
    return attention_weights_3d
70

71

72
# e_{ij} = \textbf{v}^T \tanh(\alpha(\textbf{h}_{i-1}, \textbf{s}_j))
73
def _calc_attention_logits_from_sum_match(
    model,
    decoder_hidden_encoder_outputs_sum,
    encoder_output_dim,
    scope,
):
    """Turn the additive-match sum into per-position attention logits.

    Three operators are appended, in this order:
      1. Tanh, applied in place (it overwrites the input blob);
      2. an FC along axis 2 mapping encoder_output_dim -> 1, built with
         ``freeze_bias=True``;
      3. a transpose swapping the first two axes.

    Args:
        model: model helper whose ``net`` receives the new operators.
        decoder_hidden_encoder_outputs_sum: summed match terms, presumably
            [encoder_length, batch_size, encoder_output_dim].
        encoder_output_dim: input size of the scoring FC.
        scope: prefix used (via ``s``) to name the blobs created here.

    Returns:
        Blob reference '<scope>/attention_logits_transposed', presumably
        [batch_size, encoder_length, 1].
    """
    # In-place tanh: the output blob is the input blob.
    activated_sum = model.net.Tanh(
        decoder_hidden_encoder_outputs_sum,
        decoder_hidden_encoder_outputs_sum,
    )

    # Score each position: [encoder_length, batch_size, 1]
    logits = brew.fc(
        model,
        activated_sum,
        s(scope, 'attention_logits'),
        dim_in=encoder_output_dim,
        dim_out=1,
        axis=2,
        freeze_bias=True,
    )

    # Batch-major layout: [batch_size, encoder_length, 1]
    return brew.transpose(
        model,
        logits,
        s(scope, 'attention_logits_transposed'),
        axes=[1, 0, 2],
    )
104

105

106
# \textbf{W}^\alpha used in the context of \alpha_{sum}(a,b)
107
def _apply_fc_weight_for_sum_match(
    model,
    input,
    dim_in,
    dim_out,
    scope,
    name,
):
    """Apply the W^alpha projection of alpha_sum(a, b), then drop axis 0.

    Args:
        model: model helper whose ``net`` receives the new operators.
        input: blob to project along axis 2, presumably
            [1, batch_size, dim_in].
        dim_in: size of ``input``'s projected axis.
        dim_out: output size of the projection.
        scope: prefix used (via ``s``) to name the output blob.
        name: blob name within ``scope``.

    Returns:
        Blob reference '<scope>/<name>' after an in-place Squeeze of axis 0
        (presumably [batch_size, dim_out]).
    """
    projected = brew.fc(
        model,
        input,
        s(scope, name),
        dim_in=dim_in,
        dim_out=dim_out,
        axis=2,
    )
    # Squeeze in place: the FC output blob is reused as the squeezed result.
    return model.net.Squeeze(projected, projected, dims=[0])
129

130

131
# Implements RecAtt as described in section 4.1 of http://arxiv.org/abs/1601.03317
132
def apply_recurrent_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    attention_weighted_encoder_context_t_prev,
    scope,
    encoder_lengths=None,
):
    """Build recurrent (RecAtt) additive attention for one decoder step.

    Like regular additive attention, but the match also includes a
    projection of the previous step's attention context:
        e_ij = v^T tanh(W_c c_prev + W_h h_t + weighted_encoder_outputs_j)

    Args:
        model: model helper whose ``net`` receives the new operators.
        encoder_output_dim: encoder output size; every projection here maps
            into this size.
        encoder_outputs_transposed: encoder outputs, presumably
            [batch_size, encoder_output_dim, encoder_length].
        weighted_encoder_outputs: pre-projected encoder outputs, presumably
            [encoder_length, batch_size, encoder_output_dim].
        decoder_hidden_state_t: decoder hidden state, presumably
            [1, batch_size, decoder_hidden_state_dim].
        decoder_hidden_state_dim: last-axis size of decoder_hidden_state_t.
        attention_weighted_encoder_context_t_prev: previous step's context,
            presumably [1, batch_size, encoder_output_dim].
        scope: prefix used (via ``s``) to name the blobs created here.
        encoder_lengths: optional per-example encoder lengths used to mask
            the logits before the softmax.

    Returns:
        Tuple ``(context, attention_weights_3d, [sum_blob])``:
        context is [1, batch_size, encoder_output_dim];
        attention_weights_3d is presumably [batch_size, encoder_length, 1];
        sum_blob is '<scope>/decoder_hidden_encoder_outputs_sum', which holds
        the tanh-activated sum once the in-place Tanh has run.
    """
    # [batch_size, encoder_output_dim] (the helper squeezes axis 0)
    weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
        model=model,
        input=attention_weighted_encoder_context_t_prev,
        dim_in=encoder_output_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_prev_attention_context',
    )

    # [batch_size, encoder_output_dim]
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [
            weighted_prev_attention_context,
            weighted_decoder_hidden_state,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
    )
    # Broadcast the per-step term over every encoder position:
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [
            weighted_encoder_outputs,
            decoder_hidden_encoder_outputs_sum_tmp,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )
    # [batch_size, encoder_length, 1]
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]
203

204

205
def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    """Build additive (sum-match) attention for one decoder step.

        e_ij = v^T tanh(W_h h_t + weighted_encoder_outputs_j)

    Args:
        model: model helper whose ``net`` receives the new operators.
        encoder_output_dim: encoder output size; the decoder state is
            projected into this size.
        encoder_outputs_transposed: encoder outputs, presumably
            [batch_size, encoder_output_dim, encoder_length].
        weighted_encoder_outputs: pre-projected encoder outputs, presumably
            [encoder_length, batch_size, encoder_output_dim].
        decoder_hidden_state_t: decoder hidden state, presumably
            [1, batch_size, decoder_hidden_state_dim].
        decoder_hidden_state_dim: last-axis size of decoder_hidden_state_t.
        scope: prefix used (via ``s``) to name the blobs created here.
        encoder_lengths: optional per-example encoder lengths used to mask
            the logits before the softmax.

    Returns:
        Tuple ``(context, attention_weights_3d, [sum_blob])``:
        context is [1, batch_size, encoder_output_dim];
        attention_weights_3d is presumably [batch_size, encoder_length, 1];
        sum_blob is '<scope>/decoder_hidden_encoder_outputs_sum', which holds
        the tanh-activated sum once the in-place Tanh has run.
    """
    # [batch_size, encoder_output_dim] (the helper squeezes axis 0)
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )

    # Broadcast the decoder term over every encoder position:
    # [encoder_length, batch_size, encoder_output_dim]
    # NOTE(review): use_grad_hack=1 is passed only here, not in the other
    # attention variants; its gradient effect is not visible in this file.
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
        use_grad_hack=1,
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]
258

259

260
def apply_dot_attention(
    model,
    encoder_output_dim,
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed,
    # [1, batch_size, decoder_state_dim]
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    """Build dot-product attention for one decoder step.

        e_ij = s_j . (W h_t)   if decoder_hidden_state_dim != encoder_output_dim
        e_ij = s_j . h_t       otherwise (no projection, no parameters)

    Args:
        model: model helper whose ``net`` receives the new operators.
        encoder_output_dim: encoder output size.
        encoder_outputs_transposed:
            [batch_size, encoder_output_dim, encoder_length].
        decoder_hidden_state_t: [1, batch_size, decoder_state_dim].
        decoder_hidden_state_dim: last-axis size of decoder_hidden_state_t.
        scope: prefix used (via ``s``) to name the blobs created here.
        encoder_lengths: optional per-example encoder lengths used to mask
            the logits before the softmax.

    Returns:
        Tuple ``(context, attention_weights_3d, [])``: context is
        [1, batch_size, encoder_output_dim]; attention_weights_3d is
        [batch_size, encoder_length, 1]. The third element is always an
        empty list.
    """
    # Project the decoder state into the encoder space only when the sizes
    # differ: [1, batch_size, encoder_output_dim]
    if decoder_hidden_state_dim != encoder_output_dim:
        weighted_decoder_hidden_state = brew.fc(
            model,
            decoder_hidden_state_t,
            s(scope, 'weighted_decoder_hidden_state'),
            dim_in=decoder_hidden_state_dim,
            dim_out=encoder_output_dim,
            axis=2,
        )
    else:
        weighted_decoder_hidden_state = decoder_hidden_state_t

    # [batch_size, encoder_output_dim]
    squeezed_weighted_decoder_hidden_state = model.net.Squeeze(
        weighted_decoder_hidden_state,
        s(scope, 'squeezed_weighted_decoder_hidden_state'),
        dims=[0],
    )

    # In place: [batch_size, encoder_output_dim, 1]
    expanddims_squeezed_weighted_decoder_hidden_state = model.net.ExpandDims(
        squeezed_weighted_decoder_hidden_state,
        squeezed_weighted_decoder_hidden_state,
        dims=[2],
    )

    # trans_a=1 transposes the encoder outputs to
    # [batch_size, encoder_length, encoder_output_dim], so the product is
    # [batch_size, encoder_length, 1]
    attention_logits_transposed = model.net.BatchMatMul(
        [
            encoder_outputs_transposed,
            expanddims_squeezed_weighted_decoder_hidden_state,
        ],
        s(scope, 'attention_logits'),
        trans_a=1,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, []
324

325

326
def apply_soft_coverage_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths,
    coverage_t_prev,
    coverage_weights,
):
    """Build additive attention with a soft coverage term for one step.

    The match adds coverage_weights scaled by the accumulated coverage:
        e_ij = v^T tanh(W_h h_t + weighted_encoder_outputs_j
                        + coverage_weights_j * coverage_prev_j)
    and the new coverage is the previous coverage plus this step's weights.

    Args:
        model: model helper whose ``net`` receives the new operators.
        encoder_output_dim: encoder output size; the decoder state is
            projected into this size.
        encoder_outputs_transposed: encoder outputs, presumably
            [batch_size, encoder_output_dim, encoder_length].
        weighted_encoder_outputs: pre-projected encoder outputs, presumably
            [encoder_length, batch_size, encoder_output_dim].
        decoder_hidden_state_t: decoder hidden state, presumably
            [1, batch_size, decoder_hidden_state_dim].
        decoder_hidden_state_dim: last-axis size of decoder_hidden_state_t.
        scope: prefix used (via ``s``) to name the blobs created here.
        encoder_lengths: per-example encoder lengths, or None; used to mask
            the logits before the softmax.
        coverage_t_prev: accumulated coverage, presumably
            [1, batch_size, encoder_length].
        coverage_weights: coverage weights, broadcast-multiplied starting at
            axis 0; presumably [encoder_length, batch_size,
            encoder_output_dim] — TODO confirm against callers.

    Returns:
        Tuple ``(context, attention_weights_3d, [sum_blob], coverage_t)``:
        context is [1, batch_size, encoder_output_dim];
        attention_weights_3d is presumably [batch_size, encoder_length, 1];
        sum_blob is '<scope>/decoder_hidden_encoder_outputs_sum', which holds
        the tanh-activated sum once the in-place Tanh has run;
        coverage_t is '<scope>/coverage_t' (same shape as coverage_t_prev,
        presumably [1, batch_size, encoder_length]).
    """

    # [batch_size, encoder_output_dim] (the helper squeezes axis 0)
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
        broadcast=1,
    )
    # [batch_size, encoder_length]
    coverage_t_prev_2d = model.net.Squeeze(
        coverage_t_prev,
        s(scope, 'coverage_t_prev_2d'),
        dims=[0],
    )
    # Transpose with default axes (reversed): [encoder_length, batch_size]
    coverage_t_prev_transposed = brew.transpose(
        model,
        coverage_t_prev_2d,
        s(scope, 'coverage_t_prev_transposed'),
    )

    # Scale coverage_weights by the coverage, broadcast from axis 0:
    # [encoder_length, batch_size, encoder_output_dim]
    scaled_coverage_weights = model.net.Mul(
        [coverage_weights, coverage_t_prev_transposed],
        s(scope, 'scaled_coverage_weights'),
        broadcast=1,
        axis=0,
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [decoder_hidden_encoder_outputs_sum_tmp, scaled_coverage_weights],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )

    # [batch_size, encoder_length]
    attention_weights_2d = model.net.Squeeze(
        attention_weights_3d,
        s(scope, 'attention_weights_2d'),
        dims=[2],
    )

    # Accumulate coverage; broadcast=1 lets the 2-D weights add onto the
    # (presumably [1, batch_size, encoder_length]) previous coverage.
    coverage_t = model.net.Add(
        [coverage_t_prev, attention_weights_2d],
        s(scope, 'coverage_t'),
        broadcast=1,
    )

    return (
        attention_weighted_encoder_context,
        attention_weights_3d,
        [decoder_hidden_encoder_outputs_sum],
        coverage_t,
    )
425

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.