pytorch

Форк
0
/
rnn_cell.py 
1978 строк · 66.4 Кб
1
## @package rnn_cell
2
# Module caffe2.python.rnn_cell
3

4

5

6

7

8
import functools
9
import inspect
10
import logging
11
import numpy as np
12
import random
13

14
from caffe2.proto import caffe2_pb2
15
from caffe2.python.attention import (
16
    apply_dot_attention,
17
    apply_recurrent_attention,
18
    apply_regular_attention,
19
    apply_soft_coverage_attention,
20
    AttentionType,
21
)
22
from caffe2.python import core, recurrent, workspace, brew, scope, utils
23
from caffe2.python.modeling.parameter_sharing import ParameterSharing
24
from caffe2.python.modeling.parameter_info import ParameterTags
25
from caffe2.python.modeling.initializers import Initializer
26
from caffe2.python.model_helper import ModelHelper
27

28

29
def _RectifyName(blob_reference_or_name):
30
    if blob_reference_or_name is None:
31
        return None
32
    if isinstance(blob_reference_or_name, str):
33
        return core.ScopedBlobReference(blob_reference_or_name)
34
    if not isinstance(blob_reference_or_name, core.BlobReference):
35
        raise Exception("Unknown blob reference type")
36
    return blob_reference_or_name
37

38

39
def _RectifyNames(blob_references_or_names):
40
    if blob_references_or_names is None:
41
        return None
42
    return [_RectifyName(i) for i in blob_references_or_names]
43

44

45
class RNNCell:
    '''
    Base class for writing recurrent / stateful operations.

    One needs to implement 2 methods: apply_override
    and get_state_names_override.

    As a result the base class will provide the apply_over_sequence method,
    which allows you to apply recurrent operations over a sequence of any
    length.

    Optionally you can add input and output preparation steps by overriding
    the corresponding methods (prepare_input, _prepare_output and
    _prepare_output_sequence).
    '''
    def __init__(self, name=None, forward_only=False, initializer=None):
        # Optional prefix used by scope() and as the NameScope in _apply.
        self.name = name
        # Passed to recurrent_net as recompute_blobs_on_backward; subclasses
        # fill it (e.g. when memory_optimization is enabled).
        self.recompute_blobs = []
        self.forward_only = forward_only
        self._initializer = initializer

    @property
    def initializer(self):
        '''
        Object whose create_states(model) builds the initial recurrent
        states when apply_over_sequence is called without initial_states.
        '''
        return self._initializer

    @initializer.setter
    def initializer(self, value):
        self._initializer = value

    def scope(self, name):
        '''Return name prefixed with "<self.name>/" (unchanged if unnamed).'''
        return self.name + '/' + name if self.name is not None else name

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths=None,
        initial_states=None,
        outputs_with_grads=None,
    ):
        '''
        Unroll the cell over a whole input sequence with
        recurrent.recurrent_net.

        model: ModelHelper the recurrent op is added to.

        inputs: the whole input sequence blob; it goes through
        prepare_input() once before the recurrence.

        seq_lengths: optional blob of per-example sequence lengths.

        initial_states: one initial blob per state name. When None, they are
        created by self.initializer under NameScope(self.name).

        outputs_with_grads: indices into recurrent_net's outputs that
        receive gradients. Defaults to the full-sequence output of the
        primary state (2 * get_output_state_index()).

        Returns (output, states_for_all_steps), where output is the
        post-processed primary output sequence.

        Raises Exception if neither initial_states nor an initializer is
        available.
        '''
        if initial_states is None:
            with scope.NameScope(self.name):
                if self.initializer is None:
                    raise Exception("Either initial states "
                                    "or initializer have to be set")
                initial_states = self.initializer.create_states(model)

        preprocessed_inputs = self.prepare_input(model, inputs)
        step_model = ModelHelper(name=self.name, param_model=model)
        input_t, timestep = step_model.net.AddScopedExternalInputs(
            'input_t',
            'timestep',
        )
        utils.raiseIfNotEqual(
            len(initial_states), len(self.get_state_names()),
            "Number of initial state values provided doesn't match the number "
            "of states"
        )
        states_prev = step_model.net.AddScopedExternalInputs(*[
            s + '_prev' for s in self.get_state_names()
        ])
        states = self._apply(
            model=step_model,
            input_t=input_t,
            seq_lengths=seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

        # Expose every returned state as an external output of the step net;
        # each one is linked to its matching *_prev input below.
        external_outputs = set(step_model.net.Proto().external_output)
        for state in states:
            if state not in external_outputs:
                step_model.net.AddExternalOutput(state)

        if outputs_with_grads is None:
            # Full-sequence output of the primary state (layout below).
            outputs_with_grads = [self.get_output_state_index() * 2]

        # states_for_all_steps consists of combination of
        # states gather for all steps and final states. It looks like this:
        # (state_1_all, state_1_final, state_2_all, state_2_final, ...)
        states_for_all_steps = recurrent.recurrent_net(
            net=model.net,
            cell_net=step_model.net,
            inputs=[(input_t, preprocessed_inputs)],
            initial_cell_inputs=list(zip(states_prev, initial_states)),
            links=dict(zip(states_prev, states)),
            timestep=timestep,
            scope=self.name,
            forward_only=self.forward_only,
            outputs_with_grads=outputs_with_grads,
            recompute_blobs_on_backward=self.recompute_blobs,
        )

        output = self._prepare_output_sequence(
            model,
            states_for_all_steps,
        )
        return output, states_for_all_steps

    def apply(self, model, input_t, seq_lengths, states, timestep):
        '''
        Run a single step of the cell.

        Prepares input_t, applies the cell and returns (output, states),
        where output is the post-processed primary state.
        '''
        input_t = self.prepare_input(model, input_t)
        states = self._apply(
            model, input_t, seq_lengths, states, timestep)
        output = self._prepare_output(model, states)
        return output, states

    def _apply(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None
    ):
        '''
        This method uses apply_override provided by a custom cell.
        On top of that it takes care of applying self.scope() to all the
        outputs, while all the inputs stay within the scope this function
        was called from.
        '''
        args = self._rectify_apply_inputs(
            input_t, seq_lengths, states, timestep, extra_inputs)
        with core.NameScope(self.name):
            return self.apply_override(model, *args)

    def _rectify_apply_inputs(
            self, input_t, seq_lengths, states, timestep, extra_inputs):
        '''
        Before applying a scope we make sure that all external blob names
        are converted to blob references, so further scoping doesn't affect
        them.

        Returns the positional arguments for apply_override:
        [input_t, seq_lengths, states, timestep], plus extra_inputs when
        apply_override declares an 'extra_inputs' parameter.
        '''
        input_t, seq_lengths, timestep = _RectifyNames(
            [input_t, seq_lengths, timestep])
        states = _RectifyNames(states)
        if extra_inputs:
            extra_input_names, extra_input_sizes = zip(*extra_inputs)
            # Pair the *rectified* references (not the raw names) with their
            # sizes, and materialize a list so it can be iterated repeatedly.
            extra_inputs = list(zip(
                _RectifyNames(extra_input_names), extra_input_sizes))

        rectified = [input_t, seq_lengths, states, timestep]
        if self._apply_override_accepts_extra_inputs():
            rectified.append(extra_inputs)
        return rectified

    def _apply_override_accepts_extra_inputs(self):
        '''
        Return True when apply_override declares an 'extra_inputs' positional
        parameter. Uses inspect.getfullargspec, since inspect.getargspec was
        removed in Python 3.11.
        '''
        arg_names = inspect.getfullargspec(self.apply_override).args
        return 'extra_inputs' in arg_names

    def apply_override(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None,
    ):
        '''
        A single step of a recurrent network to be implemented by each custom
        RNNCell.

        model: ModelHelper object new operators would be added to

        input_t: single input with shape (1, batch_size, input_dim)

        seq_lengths: blob containing sequence lengths which would be passed to
        LSTMUnit operator

        states: previous recurrent states

        timestep: current recurrent iteration. Could be used together with
        seq_lengths in order to determine, if some shorter sequences
        in the batch have already ended.

        extra_inputs: list of tuples (input, dim). specifies additional input
        which is not subject to prepare_input(). (useful when a cell is a
        component of a larger recurrent structure, e.g., attention)
        '''
        raise NotImplementedError('Abstract method')

    def prepare_input(self, model, input_blob):
        '''
        If some operations in _apply method depend only on the input,
        not on recurrent states, they could be computed in advance.

        model: ModelHelper object new operators would be added to

        input_blob: either the whole input sequence with shape
        (sequence_length, batch_size, input_dim) or a single input with shape
        (1, batch_size, input_dim).
        '''
        return input_blob

    def get_output_state_index(self):
        '''
        Return index into state list of the "primary" step-wise output.
        '''
        return 0

    def get_state_names(self):
        '''
        Returns recurrent state names with self.name scoping applied
        '''
        return [self.scope(name) for name in self.get_state_names_override()]

    def get_state_names_override(self):
        '''
        Override this function in your custom cell.
        It should return the names of the recurrent states.

        It's required by apply_over_sequence method in order to allocate
        recurrent states for all steps with meaningful names.
        '''
        raise NotImplementedError('Abstract method')

    def get_output_dim(self):
        '''
        Specifies the dimension (number of units) of stepwise output.
        '''
        raise NotImplementedError('Abstract method')

    def _prepare_output(self, model, states):
        '''
        Allows arbitrary post-processing of primary output.
        '''
        return states[self.get_output_state_index()]

    def _prepare_output_sequence(self, model, state_outputs):
        '''
        Allows arbitrary post-processing of primary sequence output.

        (Note that state_outputs alternates between full-sequence and final
        output for each state, thus the index multiplier 2.)
        '''
        output_sequence_index = 2 * self.get_output_state_index()
        return state_outputs[output_sequence_index]
269

270

271
class LSTMInitializer:
    '''
    Builds zero-filled initial states for LSTM-style cells.

    create_states() adds two parameters of shape [hidden_size], filled via
    a ConstantFill initializer with value 0.0: 'initial_hidden_state' and
    'initial_cell_state' (in that order).
    '''
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

    def create_states(self, model):
        '''Return [initial_hidden_state, initial_cell_state] from model.'''
        def _zero_state(param_name):
            return model.create_param(
                param_name=param_name,
                initializer=Initializer(operator_name='ConstantFill',
                                        value=0.0),
                shape=[self.hidden_size],
            )

        state_names = ('initial_hidden_state', 'initial_cell_state')
        return [_zero_state(state_name) for state_name in state_names]
290

291

292
# based on https://pytorch.org/docs/master/nn.html#torch.nn.RNNCell
293
class BasicRNNCell(RNNCell):
    '''
    Simple recurrent cell:

        hidden_t = activation(i2h(input_t) + gates_t(hidden_t_prev))

    where i2h and gates_t are fully connected layers (i2h is applied to the
    whole input in prepare_input) and activation is 'tanh' or 'relu'.

    forget_bias and memory_optimization are accepted for signature
    compatibility with the LSTM cells but are not used by this cell.
    '''
    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        drop_states=False,
        initializer=None,
        activation=None,
        **kwargs
    ):
        # Forward initializer to the base class so apply_over_sequence() can
        # build initial states from it; without this self.initializer would
        # always be None regardless of what the caller passed.
        super().__init__(initializer=initializer, **kwargs)
        self.drop_states = drop_states
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = activation

        # The default activation=None is rejected here: callers must pass
        # 'relu' or 'tanh' explicitly.
        if self.activation not in ['relu', 'tanh']:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

    def apply_override(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        One recurrent step. states holds the single previous hidden state.

        When seq_lengths is given, rows whose sequence has ended
        (seq_lengths <= timestep) get a zero hidden state if drop_states is
        set, otherwise they carry hidden_t_prev forward unchanged.

        extra_inputs is accepted but not used by this cell.
        Returns a 1-tuple (hidden_t,).
        '''
        hidden_t_prev = states[0]

        gates_t = brew.fc(
            model,
            hidden_t_prev,
            'gates_t',
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        # input_t already holds the i2h projection from prepare_input.
        brew.sum(model, [gates_t, input_t], gates_t)
        if self.activation == 'tanh':
            hidden_t = model.net.Tanh(gates_t, 'hidden_t')
        elif self.activation == 'relu':
            hidden_t = model.net.Relu(gates_t, 'hidden_t')
        else:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

        if seq_lengths is not None:
            # TODO If this codepath becomes popular, it may be worth
            # taking a look at optimizing it - for now a simple
            # implementation is used to round out compatibility with
            # ONNX.
            timestep = model.net.CopyFromCPUInput(
                timestep, 'timestep_gpu')
            valid_b = model.net.GT(
                [seq_lengths, timestep], 'valid_b', broadcast=1)
            invalid_b = model.net.LE(
                [seq_lengths, timestep], 'invalid_b', broadcast=1)
            valid = model.net.Cast(valid_b, 'valid', to='float')
            invalid = model.net.Cast(invalid_b, 'invalid', to='float')

            hidden_valid = model.net.Mul(
                [hidden_t, valid],
                'hidden_valid',
                broadcast=1,
                axis=1,
            )
            if self.drop_states:
                hidden_t = hidden_valid
            else:
                hidden_invalid = model.net.Mul(
                    [hidden_t_prev, invalid],
                    'hidden_invalid',
                    broadcast=1, axis=1)
                hidden_t = model.net.Add(
                    [hidden_valid, hidden_invalid], hidden_t)
        return (hidden_t,)

    def prepare_input(self, model, input_blob):
        '''Project the input to hidden_size with the i2h FC layer (axis=2).'''
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.hidden_size,
            axis=2,
        )

    def get_state_names(self):
        '''Single scoped state name: '<name>/hidden_t'.'''
        return (self.scope('hidden_t'),)

    def get_output_dim(self):
        return self.hidden_size
392

393

394
class LSTMCell(RNNCell):
    '''
    LSTM cell built on the LSTMUnit operator.

    prepare_input applies the i2h FC layer (to 4 * hidden_size gate
    pre-activations) once to the input; each step adds an FC projection of
    the previous hidden state (optionally concatenated with extra_inputs)
    and feeds the sum to LSTMUnit.
    '''

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        drop_states=False,
        initializer=None,
        **kwargs
    ):
        super().__init__(initializer=initializer, **kwargs)
        # Fall back to zero-filled initial states when none is supplied.
        self.initializer = (
            initializer if initializer
            else LSTMInitializer(hidden_size=hidden_size)
        )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.gates_size = 4 * self.hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states

    def apply_override(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        One LSTM step; states is (hidden_t_prev, cell_t_prev).

        extra_inputs, when given, is a sequence of (blob, dim) pairs that is
        concatenated (axis=2) with hidden_t_prev before the recurrent FC.
        Returns (hidden_t, cell_t) from LSTMUnit.
        '''
        hidden_t_prev, cell_t_prev = states

        if extra_inputs is None:
            recurrent_in = hidden_t_prev
            recurrent_in_dim = self.hidden_size
        else:
            extra_blobs, extra_dims = zip(*extra_inputs)
            recurrent_in = brew.concat(
                model,
                [hidden_t_prev, *extra_blobs],
                'gates_concatenated_input_t',
                axis=2,
            )
            recurrent_in_dim = self.hidden_size + sum(extra_dims)

        gates_t = brew.fc(
            model,
            recurrent_in,
            'gates_t',
            dim_in=recurrent_in_dim,
            dim_out=self.gates_size,
            axis=2,
        )
        # input_t already holds the i2h projection from prepare_input.
        brew.sum(model, [gates_t, input_t], gates_t)

        has_lengths = seq_lengths is not None
        unit_inputs = [hidden_t_prev, cell_t_prev, gates_t]
        if has_lengths:
            unit_inputs.append(seq_lengths)
        unit_inputs.append(timestep)

        hidden_t, cell_t = model.net.LSTMUnit(
            unit_inputs,
            ['hidden_state', 'cell_state'],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=has_lengths,
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]

        return hidden_t, cell_t

    def get_input_params(self):
        '''Names of the i2h FC weight and bias blobs.'''
        prefix = self.scope('i2h')
        return {'weights': prefix + '_w', 'biases': prefix + '_b'}

    def get_recurrent_params(self):
        '''Names of the recurrent (gates_t) FC weight and bias blobs.'''
        prefix = self.scope('gates_t')
        return {'weights': prefix + '_w', 'biases': prefix + '_b'}

    def prepare_input(self, model, input_blob):
        '''Project the input to 4 * hidden_size with the i2h FC (axis=2).'''
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.gates_size,
            axis=2,
        )

    def get_state_names_override(self):
        return ['hidden_t', 'cell_t']

    def get_output_dim(self):
        return self.hidden_size
496

497

498
class LayerNormLSTMCell(RNNCell):
    '''
    LSTM cell that layer-normalizes the gate pre-activations before the
    LSTMUnit operator.

    Unlike LSTMCell it overrides _apply directly, so blob names are scoped
    explicitly with self.scope() instead of through NameScope(self.name).
    '''

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        drop_states=False,
        initializer=None,
        **kwargs
    ):
        super().__init__(initializer=initializer, **kwargs)
        self.initializer = initializer or LSTMInitializer(
            hidden_size=hidden_size
        )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
        self.gates_size = 4 * self.hidden_size

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        One step; states is (hidden_t_prev, cell_t_prev).

        extra_inputs, when given, is a sequence of (blob, dim) pairs that is
        concatenated (axis=2) with hidden_t_prev before the recurrent FC.
        Returns (hidden_t, cell_t) from LSTMUnit.
        '''
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        gates_t = brew.fc(
            model,
            fc_input,
            self.scope('gates_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )
        # input_t already holds the i2h projection from prepare_input.
        brew.sum(model, [gates_t, input_t], gates_t)

        # brew.layer_norm call is the main difference from LSTMCell
        gates_t, _, _ = brew.layer_norm(
            model,
            self.scope('gates_t'),
            self.scope('gates_t_norm'),
            dim_in=self.gates_size,
            axis=-1,
        )

        # Only feed seq_lengths to LSTMUnit when it is given, as
        # LSTMCell.apply_override does; otherwise a None input would be
        # placed in the op's input list.
        unit_inputs = [hidden_t_prev, cell_t_prev, gates_t]
        if seq_lengths is not None:
            unit_inputs.append(seq_lengths)
        unit_inputs.append(timestep)

        hidden_t, cell_t = model.net.LSTMUnit(
            unit_inputs,
            self.get_state_names(),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=(seq_lengths is not None),
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]

        return hidden_t, cell_t

    def get_input_params(self):
        return {
            'weights': self.scope('i2h') + '_w',
            'biases': self.scope('i2h') + '_b',
        }

    def prepare_input(self, model, input_blob):
        '''Project the input to 4 * hidden_size with the i2h FC (axis=2).'''
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.gates_size,
            axis=2,
        )

    def get_state_names(self):
        return (self.scope('hidden_t'), self.scope('cell_t'))

    def get_output_dim(self):
        '''Stepwise output is the hidden state, of size hidden_size.'''
        return self.hidden_size
601

602

603
class MILSTMCell(LSTMCell):
    '''
    LSTM cell with multiplicative integration (MI) of the input and
    recurrent projections. Gate pre-activations are

        gates_t = alpha * input_t * prev_t + beta_h * prev_t
                  + beta_i * input_t + b

    where input_t is the prepared (i2h) input, prev_t is an FC projection of
    the previous hidden state (optionally concatenated with extra_inputs),
    and alpha, beta_h (param 'beta1'), beta_i (param 'beta2') and b are
    parameters of shape [4 * hidden_size].
    '''

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        One MI-LSTM step; states is (hidden_t_prev, cell_t_prev).

        Overrides RNNCell._apply directly, so blob names are scoped
        explicitly with self.scope(). Returns (hidden_t, cell_t).
        '''
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        prev_t = brew.fc(
            model,
            fc_input,
            self.scope('prev_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )

        # defining initializers for MI parameters
        alpha = model.create_param(
            self.scope('alpha'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_h = model.create_param(
            self.scope('beta1'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_i = model.create_param(
            self.scope('beta2'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        b = model.create_param(
            self.scope('b'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=0.0),
        )

        # alpha * input_t + beta_h
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
            [input_t, alpha, beta_h],
            self.scope('alpha_by_input_t_plus_beta_h'),
            axis=2,
        )
        # (alpha * input_t + beta_h) * prev_t =
        # alpha * input_t * prev_t + beta_h * prev_t
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
            [alpha_by_input_t_plus_beta_h, prev_t],
            self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
        )
        # beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
            [input_t, beta_i, b],
            self.scope('beta_i_by_input_t_plus_b'),
            axis=2,
        )
        # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t = brew.sum(
            model,
            [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
            self.scope('gates_t')
        )

        # Only feed seq_lengths to LSTMUnit when it is given, as
        # LSTMCell.apply_override does; otherwise a None input would be
        # placed in the op's input list.
        unit_inputs = [hidden_t_prev, cell_t_prev, gates_t]
        if seq_lengths is not None:
            unit_inputs.append(seq_lengths)
        unit_inputs.append(timestep)

        hidden_t, cell_t = model.net.LSTMUnit(
            unit_inputs,
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=(seq_lengths is not None),
        )
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t
701

702

703
class LayerNormMILSTMCell(LSTMCell):
    '''
    Multiplicative-integration LSTM cell (see MILSTMCell) whose gate
    pre-activations are additionally layer-normalized before LSTMUnit:

        gates_t = LayerNorm(alpha * input_t * prev_t + beta_h * prev_t
                            + beta_i * input_t + b)
    '''

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        One step; states is (hidden_t_prev, cell_t_prev).

        Overrides RNNCell._apply directly, so blob names are scoped
        explicitly with self.scope(). Returns (hidden_t, cell_t).
        '''
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        prev_t = brew.fc(
            model,
            fc_input,
            self.scope('prev_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )

        # defining initializers for MI parameters
        alpha = model.create_param(
            self.scope('alpha'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_h = model.create_param(
            self.scope('beta1'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_i = model.create_param(
            self.scope('beta2'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        b = model.create_param(
            self.scope('b'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=0.0),
        )

        # alpha * input_t + beta_h
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
            [input_t, alpha, beta_h],
            self.scope('alpha_by_input_t_plus_beta_h'),
            axis=2,
        )
        # (alpha * input_t + beta_h) * prev_t =
        # alpha * input_t * prev_t + beta_h * prev_t
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
            [alpha_by_input_t_plus_beta_h, prev_t],
            self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
        )
        # beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
            [input_t, beta_i, b],
            self.scope('beta_i_by_input_t_plus_b'),
            axis=2,
        )
        # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t = brew.sum(
            model,
            [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
            self.scope('gates_t')
        )
        # brew.layer_norm call is the main difference from MILSTMCell._apply
        gates_t, _, _ = brew.layer_norm(
            model,
            self.scope('gates_t'),
            self.scope('gates_t_norm'),
            dim_in=self.gates_size,
            axis=-1,
        )

        # Only feed seq_lengths to LSTMUnit when it is given, as
        # LSTMCell.apply_override does; otherwise a None input would be
        # placed in the op's input list.
        unit_inputs = [hidden_t_prev, cell_t_prev, gates_t]
        if seq_lengths is not None:
            unit_inputs.append(seq_lengths)
        unit_inputs.append(timestep)

        hidden_t, cell_t = model.net.LSTMUnit(
            unit_inputs,
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=(seq_lengths is not None),
        )
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t
809

810

811
class DropoutCell(RNNCell):
    '''
    Wraps arbitrary RNNCell, applying dropout to its output (but not to the
    recurrent connection for the corresponding state).
    '''

    def __init__(
        self,
        internal_cell,
        dropout_ratio=None,
        use_cudnn=False,
        **kwargs
    ):
        '''
        internal_cell: the RNNCell whose output gets dropout applied.

        dropout_ratio: dropout ratio passed to brew.dropout; None or 0
        disables dropout.

        use_cudnn: forwarded to brew.dropout.

        is_test (required keyword argument): forwarded to brew.dropout.

        Remaining kwargs (name, forward_only, initializer) go to RNNCell.

        Raises ValueError if is_test is not supplied.
        '''
        self.internal_cell = internal_cell
        self.dropout_ratio = dropout_ratio
        # Explicit check rather than `assert`: asserts are stripped under -O,
        # which would turn a clear message into a KeyError from pop().
        if 'is_test' not in kwargs:
            raise ValueError("Argument 'is_test' is required")
        self.is_test = kwargs.pop('is_test')
        self.use_cudnn = use_cudnn
        super().__init__(**kwargs)

        # Delegate input preparation and state bookkeeping to the wrapped
        # cell by shadowing these methods with its bound methods.
        self.prepare_input = internal_cell.prepare_input
        self.get_output_state_index = internal_cell.get_output_state_index
        self.get_state_names = internal_cell.get_state_names
        self.get_output_dim = internal_cell.get_output_dim

        # Counter appended to dropout output names so each call is unique.
        self.mask = 0

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''Run one step of the wrapped cell; states are not dropped out.'''
        return self.internal_cell._apply(
            model,
            input_t,
            seq_lengths,
            states,
            timestep,
            extra_inputs,
        )

    def _prepare_output(self, model, states):
        '''Wrapped cell's step output, with dropout when configured.'''
        output = self.internal_cell._prepare_output(
            model,
            states,
        )
        if self.dropout_ratio is not None:
            output = self._apply_dropout(model, output)
        return output

    def _prepare_output_sequence(self, model, state_outputs):
        '''Wrapped cell's sequence output, with dropout when configured.'''
        output = self.internal_cell._prepare_output_sequence(
            model,
            state_outputs,
        )
        if self.dropout_ratio is not None:
            output = self._apply_dropout(model, output)
        return output

    def _apply_dropout(self, model, output):
        '''
        Add a dropout op on output unless dropout_ratio is falsy or the cell
        is forward_only; the result is named
        "<output>_with_dropout_mask<N>" under NameScope(self.name or '').
        '''
        if self.dropout_ratio and not self.forward_only:
            with core.NameScope(self.name or ''):
                output = brew.dropout(
                    model,
                    output,
                    str(output) + '_with_dropout_mask{}'.format(self.mask),
                    ratio=float(self.dropout_ratio),
                    is_test=self.is_test,
                    use_cudnn=self.use_cudnn,
                )
                self.mask += 1
        return output
887

888

889
class MultiRNNCellInitializer:
    '''
    Builds initial states for a stack of cells by delegating to each cell's
    own initializer.

    Each layer's states are created under NameScope("layer_<i>") and then
    NameScope(cell.name), and all states are returned as one flat list in
    layer order.
    '''
    def __init__(self, cells):
        self.cells = cells

    def create_states(self, model):
        '''
        Return the concatenated initial states of all cells.

        Raises Exception if any cell has no initializer.
        '''
        all_states = []
        for layer_idx, layer_cell in enumerate(self.cells):
            cell_initializer = layer_cell.initializer
            if cell_initializer is None:
                raise Exception("Either initial states "
                                "or initializer have to be set")

            layer_scope = "layer_{}".format(layer_idx)
            with core.NameScope(layer_scope), \
                    core.NameScope(layer_cell.name):
                all_states += cell_initializer.create_states(model)
        return all_states
904

905

906
class MultiRNNCell(RNNCell):
    '''
    Multilayer RNN via the composition of RNNCell instances.

    It is the responsibility of calling code to ensure the compatibility
    of the successive layers in terms of input/output dimensionality, etc.,
    and to ensure that their blobs do not have name conflicts, typically by
    creating the cells with names that specify layer number.

    Assumes first state (recurrent output) for each layer should be the input
    to the next layer.
    '''

    def __init__(self, cells, residual_output_layers=None, **kwargs):
        '''
        cells: list of RNNCell instances, from input to output side.

        name: string designating network component (for scoping); passed
        through **kwargs to RNNCell.__init__.

        residual_output_layers: list of indices of layers whose input will
        be added elementwise to their output. (It is the responsibility of
        the client code to ensure shape compatibility.)
        Note that layer 0 (zero) cannot have residual output because of the
        timing of prepare_input().

        forward_only: used to construct inference-only network; passed
        through **kwargs to RNNCell.__init__.
        '''
        super().__init__(**kwargs)
        self.cells = cells

        if residual_output_layers is None:
            self.residual_output_layers = []
        else:
            self.residual_output_layers = residual_output_layers

        # Index of each layer's output state within the concatenated state
        # list of all layers.
        output_index_per_layer = []
        base_index = 0
        for cell in self.cells:
            output_index_per_layer.append(
                base_index + cell.get_output_state_index(),
            )
            base_index += len(cell.get_state_names())

        # Layers whose outputs are summed into the final output (see
        # _prepare_output): the last layer, plus the trailing run of layers i
        # immediately before it for which layer i + 1 is residual. A
        # non-residual successor resets the run.
        self.output_connected_layers = []
        self.output_indices = []
        for i in range(len(self.cells) - 1):
            if (i + 1) in self.residual_output_layers:
                self.output_connected_layers.append(i)
                self.output_indices.append(output_index_per_layer[i])
            else:
                self.output_connected_layers = []
                self.output_indices = []
        self.output_connected_layers.append(len(self.cells) - 1)
        self.output_indices.append(output_index_per_layer[-1])

        # Each layer's state names are prefixed "<self.name>/layer_<i>/".
        self.state_names = []
        for i, cell in enumerate(self.cells):
            self.state_names.extend(
                map(self.layer_scoper(i), cell.get_state_names())
            )

        self.initializer = MultiRNNCellInitializer(cells)

    def layer_scoper(self, layer_id):
        # Returns a function that prefixes a name with
        # "<self.name>/layer_<layer_id>/".
        def helper(name):
            return "{}/layer_{}/{}".format(self.name, layer_id, name)
        return helper

    def prepare_input(self, model, input_blob):
        # Only the first layer's input preparation happens here; layers
        # i > 0 prepare their inputs inside _apply.
        input_blob = _RectifyName(input_blob)
        with core.NameScope(self.name or ''):
            return self.cells[0].prepare_input(model, input_blob)

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        Run one timestep through every layer in order and return the
        concatenated next states of all layers.

        Because below we will do scoping across layers, we need
        to make sure that string blob names are converted to BlobReference
        objects.
        '''

        input_t, seq_lengths, states, timestep, extra_inputs = \
            self._rectify_apply_inputs(
                input_t, seq_lengths, states, timestep, extra_inputs)

        states_per_layer = [len(cell.get_state_names()) for cell in self.cells]
        assert len(states) == sum(states_per_layer)

        next_states = []
        states_index = 0

        layer_input = input_t
        for i, layer_cell in enumerate(self.cells):
            # If cells don't have different names we still
            # take care of scoping
            with core.NameScope(self.name), core.NameScope("layer_{}".format(i)):
                num_states = states_per_layer[i]
                layer_states = states[states_index:(states_index + num_states)]
                states_index += num_states

                # Layer 0's input is used as given; it is presumably already
                # prepared via self.prepare_input() (which delegates to
                # cells[0]) by the caller.
                if i > 0:
                    prepared_input = layer_cell.prepare_input(
                        model, layer_input)
                else:
                    prepared_input = layer_input

                # extra_inputs are forwarded to the first layer only.
                layer_next_states = layer_cell._apply(
                    model,
                    prepared_input,
                    seq_lengths,
                    layer_states,
                    timestep,
                    extra_inputs=(None if i > 0 else extra_inputs),
                )
                # Since we're using here non-public method _apply,
                # instead of apply, we have to manually extract output
                # from states
                if i != len(self.cells) - 1:
                    layer_output = layer_cell._prepare_output(
                        model,
                        layer_next_states,
                    )
                    # Residual connection: the next layer's input is this
                    # layer's output plus this layer's input.
                    if i > 0 and i in self.residual_output_layers:
                        layer_input = brew.sum(
                            model,
                            [layer_output, layer_input],
                            self.scope('residual_output_{}'.format(i)),
                        )
                    else:
                        layer_input = layer_output

                next_states.extend(layer_next_states)
        return next_states

    def get_state_names(self):
        return self.state_names

    def get_output_state_index(self):
        # Output state of the last layer, offset by the number of states of
        # all preceding layers.
        index = 0
        for cell in self.cells[:-1]:
            index += len(cell.get_state_names())
        index += self.cells[-1].get_output_state_index()
        return index

    def _prepare_output(self, model, states):
        # Sum the single-step outputs of all output-connected layers (see
        # __init__); a single connected layer's output is returned as is.
        connected_outputs = []
        state_index = 0
        for i, cell in enumerate(self.cells):
            num_states = len(cell.get_state_names())
            if i in self.output_connected_layers:
                layer_states = states[state_index:state_index + num_states]
                layer_output = cell._prepare_output(
                    model,
                    layer_states
                )
                connected_outputs.append(layer_output)
            state_index += num_states
        if len(connected_outputs) > 1:
            output = brew.sum(
                model,
                connected_outputs,
                self.scope('residual_output'),
            )
        else:
            output = connected_outputs[0]
        return output

    def _prepare_output_sequence(self, model, states):
        # Same as _prepare_output, but over sequence outputs. `states` holds
        # two entries per recurrent state (presumably the all-timesteps and
        # the last-timestep blob), hence the factor of 2.
        connected_outputs = []
        state_index = 0
        for i, cell in enumerate(self.cells):
            num_states = 2 * len(cell.get_state_names())
            if i in self.output_connected_layers:
                layer_states = states[state_index:state_index + num_states]
                layer_output = cell._prepare_output_sequence(
                    model,
                    layer_states
                )
                connected_outputs.append(layer_output)
            state_index += num_states
        if len(connected_outputs) > 1:
            output = brew.sum(
                model,
                connected_outputs,
                self.scope('residual_output_sequence'),
            )
        else:
            output = connected_outputs[0]
        return output
1102

1103

1104
class AttentionCell(RNNCell):
    '''
    Wraps a decoder cell with attention over encoder outputs.

    Recurrent state layout (see get_state_names): the decoder cell's states
    (its output state renamed to "<name>/hidden_t_external"), then the
    attention-weighted encoder context, then, for AttentionType.SoftCoverage
    only, the coverage state.
    '''

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_cell,
        decoder_state_dim,
        attention_type,
        weighted_encoder_outputs,
        attention_memory_optimization,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder_output_dim = encoder_output_dim
        self.encoder_outputs = encoder_outputs
        self.encoder_lengths = encoder_lengths
        self.decoder_cell = decoder_cell
        self.decoder_state_dim = decoder_state_dim
        # May be None; then prepare_input() builds it (except for Dot).
        self.weighted_encoder_outputs = weighted_encoder_outputs
        # Built lazily in prepare_input().
        self.encoder_outputs_transposed = None
        assert attention_type in [
            AttentionType.Regular,
            AttentionType.Recurrent,
            AttentionType.Dot,
            AttentionType.SoftCoverage,
        ]
        self.attention_type = attention_type
        self.attention_memory_optimization = attention_memory_optimization

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        # Split the state list: decoder states, previous attention context
        # and (SoftCoverage only) previous coverage.
        if self.attention_type == AttentionType.SoftCoverage:
            decoder_prev_states = states[:-2]
            attention_weighted_encoder_context_t_prev = states[-2]
            coverage_t_prev = states[-1]
        else:
            decoder_prev_states = states[:-1]
            attention_weighted_encoder_context_t_prev = states[-1]

        assert extra_inputs is None

        # The decoder step also consumes the previous attention context as an
        # extra input of width encoder_output_dim.
        decoder_states = self.decoder_cell._apply(
            model,
            input_t,
            seq_lengths,
            decoder_prev_states,
            timestep,
            extra_inputs=[(
                attention_weighted_encoder_context_t_prev,
                self.encoder_output_dim,
            )],
        )

        # Kept on self because _prepare_output() reads it afterwards.
        self.hidden_t_intermediate = self.decoder_cell._prepare_output(
            model,
            decoder_states,
        )

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.Regular:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.Dot:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_dot_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.SoftCoverage:
            # self.coverage_weights is set by build_initial_coverage().
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
                coverage_t,
            ) = apply_soft_coverage_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
                coverage_t_prev=coverage_t_prev,
                coverage_weights=self.coverage_weights,
            )
        else:
            raise Exception('Attention type {} not implemented'.format(
                self.attention_type
            ))

        # Registers the attention blobs in recompute_blobs (presumably so
        # they are recomputed on the backward pass instead of stored).
        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)

        output = list(decoder_states) + [attention_weighted_encoder_context_t]
        if self.attention_type == AttentionType.SoftCoverage:
            output.append(coverage_t)

        # Copy the decoder's output state to this cell's own
        # 'hidden_t_external' name, matching get_state_names().
        output[self.decoder_cell.get_output_state_index()] = model.Copy(
            output[self.decoder_cell.get_output_state_index()],
            self.scope('hidden_t_external'),
        )
        model.net.AddExternalOutputs(*output)

        return output

    def get_attention_weights(self):
        # [batch_size, encoder_length, 1]
        # Only available after _apply() has run.
        return self.attention_weights_3d

    def prepare_input(self, model, input_blob):
        # Lazily build encoder-side blobs: the transposed encoder outputs
        # (axes [1, 2, 0]) and, unless Dot attention is used or they were
        # provided, the fc-projected weighted encoder outputs (over axis 2).
        # Then delegate input preparation to the decoder cell.
        if self.encoder_outputs_transposed is None:
            self.encoder_outputs_transposed = brew.transpose(
                model,
                self.encoder_outputs,
                self.scope('encoder_outputs_transposed'),
                axes=[1, 2, 0],
            )
        if (
            self.weighted_encoder_outputs is None and
            self.attention_type != AttentionType.Dot
        ):
            self.weighted_encoder_outputs = brew.fc(
                model,
                self.encoder_outputs,
                self.scope('weighted_encoder_outputs'),
                dim_in=self.encoder_output_dim,
                dim_out=self.encoder_output_dim,
                axis=2,
            )

        return self.decoder_cell.prepare_input(model, input_blob)

    def build_initial_coverage(self, model):
        """
        initial_coverage is always zeros of shape [encoder_length],
        which shape must be determined programmatically during network
        computation.

        This method also sets self.coverage_weights, a separate transform
        of encoder_outputs which is used to determine coverage contribution
        to attention.
        """
        assert self.attention_type == AttentionType.SoftCoverage

        # [encoder_length, batch_size, encoder_output_dim]
        self.coverage_weights = brew.fc(
            model,
            self.encoder_outputs,
            self.scope('coverage_weights'),
            dim_in=self.encoder_output_dim,
            dim_out=self.encoder_output_dim,
            axis=2,
        )

        # First dimension of encoder_outputs' shape.
        encoder_length = model.net.Slice(
            model.net.Shape(self.encoder_outputs),
            starts=[0],
            ends=[1],
        )
        # The shape blob is moved to CPU when building under a GPU device
        # scope, before being used as ConstantFill's shape input.
        if (
            scope.CurrentDeviceScope() is not None and
            core.IsGPUDeviceType(scope.CurrentDeviceScope().device_type)
        ):
            encoder_length = model.net.CopyGPUToCPU(
                encoder_length,
                'encoder_length_cpu',
            )
        # total attention weight applied across decoding steps
        # shape: [encoder_length]
        initial_coverage = model.net.ConstantFill(
            encoder_length,
            self.scope('initial_coverage'),
            value=0.0,
            input_as_shape=1,
        )
        return initial_coverage

    def get_state_names(self):
        # Decoder states (output state renamed), attention context and,
        # for SoftCoverage, coverage.
        state_names = list(self.decoder_cell.get_state_names())
        state_names[self.get_output_state_index()] = self.scope(
            'hidden_t_external',
        )
        state_names.append(self.scope('attention_weighted_encoder_context_t'))
        if self.attention_type == AttentionType.SoftCoverage:
            state_names.append(self.scope('coverage_t'))
        return state_names

    def get_output_dim(self):
        # The output concatenates decoder hidden state and attention context.
        return self.decoder_state_dim + self.encoder_output_dim

    def get_output_state_index(self):
        return self.decoder_cell.get_output_state_index()

    def _prepare_output(self, model, states):
        # Concatenate (axis 2) the decoder output recorded by the latest
        # _apply() with the attention context state.
        if self.attention_type == AttentionType.SoftCoverage:
            attention_context = states[-2]
        else:
            attention_context = states[-1]

        with core.NameScope(self.name or ''):
            output = brew.concat(
                model,
                [self.hidden_t_intermediate, attention_context],
                'states_and_context_combination',
                axis=2,
            )

        return output

    def _prepare_output_sequence(self, model, state_outputs):
        # state_outputs holds two entries per state; the trailing 2 (or 4
        # with coverage) belong to the attention states, the rest to the
        # decoder cell.
        if self.attention_type == AttentionType.SoftCoverage:
            decoder_state_outputs = state_outputs[:-4]
        else:
            decoder_state_outputs = state_outputs[:-2]

        decoder_output = self.decoder_cell._prepare_output_sequence(
            model,
            decoder_state_outputs,
        )

        # Index of the first of the two entries for the attention context
        # state (presumably its all-timesteps blob).
        if self.attention_type == AttentionType.SoftCoverage:
            attention_context_index = 2 * (len(self.get_state_names()) - 2)
        else:
            attention_context_index = 2 * (len(self.get_state_names()) - 1)

        with core.NameScope(self.name or ''):
            output = brew.concat(
                model,
                [
                    decoder_output,
                    state_outputs[attention_context_index],
                ],
                'states_and_context_combination',
                axis=2,
            )
        return output
1387

1388

1389
class LSTMWithAttentionCell(AttentionCell):
    '''
    AttentionCell whose decoder is an LSTMCell named "<name>/decoder".
    '''

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
    ):
        decoder_name = '{}/decoder'.format(name)
        # The decoder is always built with forward_only=False and
        # drop_states=False, independently of this cell's forward_only.
        lstm_decoder = LSTMCell(
            name=decoder_name,
            input_size=decoder_input_dim,
            hidden_size=decoder_state_dim,
            forget_bias=forget_bias,
            memory_optimization=lstm_memory_optimization,
            drop_states=False,
            forward_only=False,
        )
        super().__init__(
            name=name,
            forward_only=forward_only,
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=encoder_lengths,
            weighted_encoder_outputs=weighted_encoder_outputs,
            decoder_cell=lstm_decoder,
            decoder_state_dim=decoder_state_dim,
            attention_type=attention_type,
            attention_memory_optimization=attention_memory_optimization,
        )
1427

1428

1429
class MILSTMWithAttentionCell(AttentionCell):
    '''
    AttentionCell whose decoder is an MILSTMCell named "<name>/decoder".

    encoder_lengths: optional blob of encoder sequence lengths, forwarded to
    AttentionCell (None when not given). AttentionCell.__init__ requires this
    argument, so it must always be passed through. It is a trailing keyword
    argument so that the existing positional interface of this class is
    unchanged.
    '''

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
        encoder_lengths=None,
    ):
        # The decoder is always built with forward_only=False and
        # drop_states=False, independently of this cell's forward_only.
        decoder_cell = MILSTMCell(
            input_size=decoder_input_dim,
            hidden_size=decoder_state_dim,
            forget_bias=forget_bias,
            memory_optimization=lstm_memory_optimization,
            name='{}/decoder'.format(name),
            forward_only=False,
            drop_states=False,
        )
        super().__init__(
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=encoder_lengths,
            decoder_cell=decoder_cell,
            decoder_state_dim=decoder_state_dim,
            name=name,
            attention_type=attention_type,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=attention_memory_optimization,
            forward_only=forward_only,
        )
1465

1466

1467
def _LSTM(
1468
    cell_class,
1469
    model,
1470
    input_blob,
1471
    seq_lengths,
1472
    initial_states,
1473
    dim_in,
1474
    dim_out,
1475
    scope=None,
1476
    outputs_with_grads=(0,),
1477
    return_params=False,
1478
    memory_optimization=False,
1479
    forget_bias=0.0,
1480
    forward_only=False,
1481
    drop_states=False,
1482
    return_last_layer_only=True,
1483
    static_rnn_unroll_size=None,
1484
    **cell_kwargs
1485
):
1486
    '''
1487
    Adds a standard LSTM recurrent network operator to a model.
1488

1489
    cell_class: LSTMCell or compatible subclass
1490

1491
    model: ModelHelper object new operators would be added to
1492

1493
    input_blob: the input sequence in a format T x N x D
1494
            where T is sequence size, N - batch size and D - input dimension
1495

1496
    seq_lengths: blob containing sequence lengths which would be passed to
1497
            LSTMUnit operator
1498

1499
    initial_states: a list of (2 * num_layers) blobs representing the initial
1500
            hidden and cell states of each layer. If this argument is None,
1501
            these states will be added to the model as network parameters.
1502

1503
    dim_in: input dimension
1504

1505
    dim_out: number of units per LSTM layer
1506
            (use int for single-layer LSTM, list of ints for multi-layer)
1507

1508
    outputs_with_grads : position indices of output blobs for LAST LAYER which
1509
            will receive external error gradient during backpropagation.
1510
            These outputs are: (h_all, h_last, c_all, c_last)
1511

1512
    return_params: if True, will return a dictionary of parameters of the LSTM
1513

1514
    memory_optimization: if enabled, the LSTM step is recomputed on backward
1515
            step so that we don't need to store forward activations for each
1516
            timestep. Saves memory with cost of computation.
1517

1518
    forget_bias: forget gate bias (default 0.0)
1519

1520
    forward_only: whether to create a backward pass
1521

1522
    drop_states: drop invalid states, passed through to LSTMUnit operator
1523

1524
    return_last_layer_only: only return outputs from final layer
1525
            (so that length of results does depend on number of layers)
1526

1527
    static_rnn_unroll_size: if not None, we will use static RNN which is
1528
    unrolled into Caffe2 graph. The size of the unroll is the value of
1529
    this parameter.
1530
    '''
1531
    if type(dim_out) is not list and type(dim_out) is not tuple:
1532
        dim_out = [dim_out]
1533
    num_layers = len(dim_out)
1534

1535
    cells = []
1536
    for i in range(num_layers):
1537
        cell = cell_class(
1538
            input_size=(dim_in if i == 0 else dim_out[i - 1]),
1539
            hidden_size=dim_out[i],
1540
            forget_bias=forget_bias,
1541
            memory_optimization=memory_optimization,
1542
            name=scope if num_layers == 1 else None,
1543
            forward_only=forward_only,
1544
            drop_states=drop_states,
1545
            **cell_kwargs
1546
        )
1547
        cells.append(cell)
1548

1549
    cell = MultiRNNCell(
1550
        cells,
1551
        name=scope,
1552
        forward_only=forward_only,
1553
    ) if num_layers > 1 else cells[0]
1554

1555
    cell = (
1556
        cell if static_rnn_unroll_size is None
1557
        else UnrolledCell(cell, static_rnn_unroll_size))
1558

1559
    # outputs_with_grads argument indexes into final layer
1560
    outputs_with_grads = [4 * (num_layers - 1) + i for i in outputs_with_grads]
1561
    _, result = cell.apply_over_sequence(
1562
        model=model,
1563
        inputs=input_blob,
1564
        seq_lengths=seq_lengths,
1565
        initial_states=initial_states,
1566
        outputs_with_grads=outputs_with_grads,
1567
    )
1568

1569
    if return_last_layer_only:
1570
        result = result[4 * (num_layers - 1):]
1571
    if return_params:
1572
        result = list(result) + [{
1573
            'input': cell.get_input_params(),
1574
            'recurrent': cell.get_recurrent_params(),
1575
        }]
1576
    return tuple(result)
1577

1578

1579
# Public sequence builders: each is _LSTM with its cell_class pre-bound.
(
    LSTM,
    BasicRNN,
    MILSTM,
    LayerNormLSTM,
    LayerNormMILSTM,
) = (
    functools.partial(_LSTM, bound_cell_class)
    for bound_cell_class in (
        LSTMCell,
        BasicRNNCell,
        MILSTMCell,
        LayerNormLSTMCell,
        LayerNormMILSTMCell,
    )
)
1584

1585

1586
class UnrolledCell(RNNCell):
    '''
    Static RNN: applies the wrapped cell T times, unrolled into the Caffe2
    graph, with parameters shared across all timesteps.

    NOTE: __init__ does not call RNNCell.__init__, so attributes set there
    (e.g. name, recompute_blobs) are not defined on this wrapper.
    '''
    def __init__(self, cell, T):
        self.T = T
        self.cell = cell

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths,
        initial_states,
        outputs_with_grads=None,
    ):
        '''
        Unroll self.cell over exactly self.T timesteps of `inputs`.

        inputs: sequence blob; after the cell's prepare_input it is split
            along axis 0 into self.T per-timestep blobs (so its first
            dimension is presumably T).
        seq_lengths: passed to the cell's _apply at every step.
        initial_states: initial state blobs of the cell.
        outputs_with_grads: indices into the returned `outputs` that receive
            gradients; every other output gets a ZeroGradient op. When None,
            it defaults to the all-timesteps output of the cell's output
            state, i.e. index 2 * cell.get_output_state_index().

        Returns (final_output, outputs). For each state, `outputs` holds the
        concatenation over all timesteps followed by the last-step value.
        '''
        inputs = self.cell.prepare_input(model, inputs)

        # Now they are blob references - outputs of splitting the input sequence
        split_inputs = model.net.Split(
            inputs,
            [str(inputs) + "_timestep_{}".format(i)
             for i in range(self.T)],
            axis=0)
        # Split with a single output yields one blob, not a list.
        if self.T == 1:
            split_inputs = [split_inputs]

        states = initial_states
        all_states = []
        for t in range(0, self.T):
            scope_name = "timestep_{}".format(t)
            # Parameters of all timesteps are shared
            with ParameterSharing({scope_name: ''}),\
                    scope.NameScope(scope_name):
                timestep = model.param_init_net.ConstantFill(
                    [], "timestep", value=t, shape=[1],
                    dtype=core.DataType.INT32,
                    device_option=core.DeviceOption(caffe2_pb2.CPU))
                states = self.cell._apply(
                    model=model,
                    input_t=split_inputs[t],
                    seq_lengths=seq_lengths,
                    states=states,
                    timestep=timestep,
                )
            all_states.append(states)

        # Concatenate each state's per-timestep blobs along axis 0. The
        # output name strips the leading "timestep_0/" scope of the first
        # blob.
        all_states = zip(*all_states)
        all_states = [
            model.net.Concat(
                list(full_output),
                [
                    str(full_output[0])[len("timestep_0/"):] + "_concat",
                    str(full_output[0])[len("timestep_0/"):] + "_concat_info"

                ],
                axis=0)[0]
            for full_output in all_states
        ]
        # Interleave the state values similar to
        #
        #   x = [1, 3, 5]
        #   y = [2, 4, 6]
        #   z = [val for pair in zip(x, y) for val in pair]
        #   # z is [1, 2, 3, 4, 5, 6]
        #
        # and returns it as outputs
        outputs = tuple(
            state for state_pair in zip(all_states, states) for state in state_pair
        )
        if outputs_with_grads is None:
            # Gradient only on the output state's all-timesteps blob; without
            # this default, set(None) below would raise TypeError.
            outputs_with_grads = [self.cell.get_output_state_index() * 2]
        outputs_without_grad = set(range(len(outputs))) - set(
            outputs_with_grads)
        for i in outputs_without_grad:
            model.net.ZeroGradient(outputs[i], [])
        # Lazy %-formatting: the list needs a placeholder, otherwise logging
        # reports a formatting error when debug logging is enabled.
        logging.debug("Added 0 gradients for blobs: %s",
                      [outputs[i] for i in outputs_without_grad])

        final_output = self.cell._prepare_output_sequence(model, outputs)

        return final_output, outputs
1663

1664

1665
def GetLSTMParamNames():
    '''
    Return the LSTM gate parameter names as
    {'weights': [...], 'biases': [...]}, each list ordered
    input gate, forget gate, output gate, cell.
    '''
    gate_prefixes = ("input_gate", "forget_gate", "output_gate", "cell")
    return {
        'weights': [prefix + "_w" for prefix in gate_prefixes],
        'biases': [prefix + "_b" for prefix in gate_prefixes],
    }
1669

1670

1671
def InitFromLSTMParams(lstm_pblobs, param_values):
    '''
    Set the parameters of LSTM based on predefined values.

    For every key of param_values, the per-gate arrays
    param_values[key][name] are flattened and concatenated in
    GetLSTMParamNames() order (separately for weights and biases). Each
    result is reshaped to the shape of the blob currently stored at
    lstm_pblobs[key]['weights'] / ['biases'] in the workspace, cast to
    float32 and fed back into that blob.
    '''
    param_names = GetLSTMParamNames()

    def _flat_concat(values, names):
        # Equivalent to np.append-ing each flattened array in turn onto an
        # empty array.
        return functools.reduce(
            np.append,
            [values[name].flatten() for name in names],
            np.array([]),
        )

    for input_type, values in param_values.items():
        weights_flat = _flat_concat(values, param_names['weights'])
        biases_flat = _flat_concat(values, param_names['biases'])

        target = lstm_pblobs[input_type]
        weights_blob = target['weights']
        bias_blob = target['biases']
        weight_shape = workspace.FetchBlob(weights_blob).shape
        bias_shape = workspace.FetchBlob(bias_blob).shape

        workspace.FeedBlob(
            weights_blob,
            weights_flat.reshape(weight_shape).astype(np.float32))
        workspace.FeedBlob(
            bias_blob,
            biases_flat.reshape(bias_shape).astype(np.float32))
1704

1705

1706
def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
               scope, recurrent_params=None, input_params=None,
               num_layers=1, return_params=False):
    '''
    CuDNN version of LSTM for GPUs.

    input_blob          Blob holding the input. It must be available when
                        param_init_net runs, because sequence lengths and
                        batch sizes are inferred from its size.
    initial_states      tuple of (hidden_init, cell_init) blobs
    dim_in              input dimension
    dim_out             output/hidden dimension
    scope               namescope under which all blobs/ops are created
    recurrent_params    dict mapping names from GetLSTMParamNames() to blobs
                        with recurrent gate weights/biases. Missing entries
                        (or None) get a UniformFill init.
    input_params        same as recurrent_params, for the input gates
    num_layers          number of stacked LSTM layers
    return_params       if True, additionally return
                        (param_extract_net, param_mapping). Running
                        param_extract_net fills the blobs in
                        param_mapping[input_type][param_name][layer] with the
                        current gate weights/biases. This is useful for
                        copying the values back into a non-cuDNN LSTM.

    Returns (output, hidden_output, cell_output), or
    (output, hidden_output, cell_output, (param_extract_net, param_mapping))
    when return_params is True.

    NOTE(review): a value supplied in recurrent_params/input_params is used
    for every layer, since the lookup does not depend on the layer index.
    Confirm this is intended when num_layers > 1.
    '''
    with core.NameScope(scope):
        names = GetLSTMParamNames()
        weight_names = names['weights']
        bias_names = names['biases']

        # Per-gate sizes. Layers above the first read the previous layer's
        # dim_out-wide hidden state, so their input weights are square.
        first_in_w = dim_out * dim_in
        upper_in_w = dim_out * dim_out
        rec_w = dim_out * dim_out
        in_b = dim_out
        rec_b = dim_out

        def _uniform_init(layer, pname, input_type):
            # Adds a UniformFill op sized for one gate parameter.
            is_input = input_type == 'input'
            if pname in weight_names:
                if is_input:
                    sz = first_in_w if layer == 0 else upper_in_w
                else:
                    sz = rec_w
            elif pname in bias_names:
                sz = in_b if is_input else rec_b
            else:
                assert False, "unknown parameter type {}".format(pname)
            return model.param_init_net.UniformFill(
                [],
                "lstm_init_{}_{}_{}".format(input_type, pname, layer),
                shape=[sz],
            )

        # 4 gates per LSTM unit, each with input/recurrent weights + biases.
        per_gate_first = first_in_w + rec_w + in_b + rec_b
        per_gate_upper = upper_in_w + rec_w + in_b + rec_b
        total_sz = 4 * (per_gate_first + (num_layers - 1) * per_gate_upper)

        weights = model.create_param(
            'lstm_weight',
            shape=[total_sz],
            initializer=Initializer('UniformFill'),
            tags=ParameterTags.WEIGHT,
        )

        lstm_args = {
            'hidden_size': dim_out,
            'rnn_mode': 'lstm',
            'bidirectional': 0,  # TODO
            'dropout': 1.0,  # TODO
            'input_mode': 'linear',  # TODO
            'num_layers': num_layers,
            'engine': 'CUDNN'
        }

        param_extract_net = core.Net("lstm_param_extractor")
        param_extract_net.AddExternalInputs([input_blob, weights])
        param_extract_mapping = {}

        # Write every gate parameter into the packed cuDNN weights blob (in
        # param_init_net). Also record a RecurrentParamGet per parameter in
        # param_extract_net, so the values can later be read back out of
        # the opaque blob (see InitFromLSTMParams()).
        for input_type in ['input', 'recurrent']:
            per_type = param_extract_mapping[input_type] = {}
            given = recurrent_params if input_type == 'recurrent' \
                else input_params
            if given is None:
                given = {}
            for pname in weight_names + bias_names:
                for layer in range(num_layers):
                    if pname in given:
                        values = given[pname]
                    else:
                        values = _uniform_init(layer, pname, input_type)
                    model.param_init_net.RecurrentParamSet(
                        [input_blob, weights, values],
                        weights,
                        layer=layer,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    extracted = param_extract_net.RecurrentParamGet(
                        [input_blob, weights],
                        ["lstm_{}_{}_{}".format(input_type, pname, layer)],
                        layer=layer,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    per_type.setdefault(pname, {})[layer] = extracted

        hidden_init, cell_init = initial_states
        (output, hidden_output, cell_output,
         rnn_scratch, dropout_states) = model.net.Recurrent(
            [input_blob, hidden_init, cell_init, weights],
            ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
             "lstm_rnn_scratch", "lstm_dropout_states"],
            seed=random.randint(0, 100000),  # TODO: dropout seed
            **lstm_args
        )
        model.net.AddExternalOutputs(
            hidden_output, cell_output, rnn_scratch, dropout_states)

    outputs = (output, hidden_output, cell_output)
    if not return_params:
        return outputs
    return outputs + ((param_extract_net, param_extract_mapping),)
1838

1839

1840
def LSTMWithAttention(
    model,
    decoder_inputs,
    decoder_input_lengths,
    initial_decoder_hidden_state,
    initial_decoder_cell_state,
    initial_attention_weighted_encoder_context,
    encoder_output_dim,
    encoder_outputs,
    encoder_lengths,
    decoder_input_dim,
    decoder_state_dim,
    scope,
    attention_type=AttentionType.Regular,
    outputs_with_grads=(0, 4),
    weighted_encoder_outputs=None,
    lstm_memory_optimization=False,
    attention_memory_optimization=False,
    forget_bias=0.0,
    forward_only=False,
):
    '''
    Add an LSTM with an attention mechanism to a model.

    Follows https://arxiv.org/abs/1409.0473, except that the new attention
    context and the new hidden state are computed in a different order,
    similar to https://arxiv.org/abs/1508.04025.

    Encoder-decoder naming is used: the op iterates over the decoder
    sequence while computing an attention context over the encoder.

    model: ModelHelper object that new operators are added to

    decoder_inputs: input sequence in T x N x D format, where T is the
    sequence length, N the batch size and D the input dimension

    decoder_input_lengths: blob with the sequence lengths, which is passed
    on to the LSTMUnit operator

    initial_decoder_hidden_state: initial hidden state of the LSTM

    initial_decoder_cell_state: initial cell state of the LSTM

    initial_attention_weighted_encoder_context: initial attention context

    encoder_output_dim: dimension of the encoder outputs

    encoder_outputs: sequence over which the attention context is computed
    at every step

    encoder_lengths: tensor with the length of each encoder sequence in the
    batch (may be None, meaning all encoder sequences have the same length)

    decoder_input_dim: input dimension (last dimension of decoder_inputs)

    decoder_state_dim: size of the LSTM hidden states

    scope: used as the name of the underlying LSTMWithAttentionCell

    attention_type: an AttentionType value (e.g. AttentionType.Regular,
    AttentionType.Recurrent) selecting the attention mechanism. With
    AttentionType.SoftCoverage, an initial coverage state built by the cell
    is appended to the initial states.

    outputs_with_grads: position indices of the output blobs that receive
    external error gradients during backpropagation

    weighted_encoder_outputs: encoder outputs used to compute the attention
    weights. In the basic case (the default, when this is None) it is a
    linear transformation of the encoder outputs. It can also be something
    more elaborate, such as a separate encoder network (for example, a
    convolutional encoder).

    lstm_memory_optimization: recompute LSTM activations on the backward
    pass, so their forward values need not be stored

    attention_memory_optimization: recompute attention on the backward pass

    forget_bias: passed through to LSTMWithAttentionCell

    forward_only: whether to create only the forward pass

    Returns the second element of cell.apply_over_sequence(...), i.e. the
    full list of output blobs.
    '''
    cell_kwargs = dict(
        encoder_output_dim=encoder_output_dim,
        encoder_outputs=encoder_outputs,
        encoder_lengths=encoder_lengths,
        decoder_input_dim=decoder_input_dim,
        decoder_state_dim=decoder_state_dim,
        name=scope,
        attention_type=attention_type,
        weighted_encoder_outputs=weighted_encoder_outputs,
        forget_bias=forget_bias,
        lstm_memory_optimization=lstm_memory_optimization,
        attention_memory_optimization=attention_memory_optimization,
        forward_only=forward_only,
    )
    cell = LSTMWithAttentionCell(**cell_kwargs)

    # Recurrent states: hidden, cell, attention context, plus coverage when
    # soft-coverage attention is requested.
    recurrent_states = [
        initial_decoder_hidden_state,
        initial_decoder_cell_state,
        initial_attention_weighted_encoder_context,
    ]
    uses_coverage = attention_type == AttentionType.SoftCoverage
    if uses_coverage:
        recurrent_states.append(cell.build_initial_coverage(model))

    _final_output, all_outputs = cell.apply_over_sequence(
        model=model,
        inputs=decoder_inputs,
        seq_lengths=decoder_input_lengths,
        initial_states=recurrent_states,
        outputs_with_grads=outputs_with_grads,
    )
    return all_outputs
1947

1948

1949
def _layered_LSTM(
        model, input_blob, seq_lengths, initial_states,
        dim_in, dim_out, scope, outputs_with_grads=(0,), return_params=False,
        memory_optimization=False, forget_bias=0.0, forward_only=False,
        drop_states=False, create_lstm=None):
    '''
    Build one LSTM, or a stack of LSTMs, with the ``create_lstm`` builder.

    All arguments except ``create_lstm`` are forwarded to ``create_lstm`` as
    keyword arguments.

    dim_out         int (anything that is not a list/tuple): one
                    ``create_lstm`` call is made and its result is returned
                    unchanged.
                    list/tuple of length 1: same, with dim_out unwrapped.
                    list/tuple of length > 1: one layer per entry. Each
                    following layer receives the previous layer's output as
                    input_blob (with dim_in set to that layer's dim_out) and
                    its (last_output, last_state) as initial_states. The
                    first layer uses ``scope``; layer k (k >= 2) uses
                    ``scope + '_layer_{k-1}'``.
    create_lstm     callable that, when stacking, must return
                    (output, last_output, all_states, last_state)

    Returns the 4-tuple produced by the last layer when stacking.

    Raises AssertionError for an empty dim_out list/tuple, or when
    return_params is True together with more than one layer.
    '''
    # Must remain the first statement so only the parameters are captured.
    # dict() takes a real copy: before Python 3.13, locals() returns the
    # frame's own dict, which debuggers/tracers re-sync with the frame. That
    # would re-add 'create_lstm' and inject later locals ('params', 'i', ...)
    # into the kwargs forwarded to create_lstm.
    params = dict(locals())
    params.pop('create_lstm')
    if not isinstance(dim_out, (list, tuple)):
        return create_lstm(**params)

    assert len(dim_out) != 0, "dim_out list can't be empty"
    if len(dim_out) == 1:
        params['dim_out'] = dim_out[0]
        return create_lstm(**params)

    assert return_params is False, "return_params not supported for layering"
    for i, output_dim in enumerate(dim_out):
        params['dim_out'] = output_dim
        output, last_output, all_states, last_state = create_lstm(**params)
        # Wire this layer's outputs into the next layer's inputs.
        params.update({
            'input_blob': output,
            'dim_in': output_dim,
            'initial_states': (last_output, last_state),
            'scope': scope + '_layer_{}'.format(i + 1)
        })
    return output, last_output, all_states, last_state
1976

1977

1978
# Public layered-LSTM builder: _layered_LSTM with create_lstm bound to the
# module-level LSTM builder. Pass dim_out as a list to stack several layers.
layered_LSTM = functools.partial(_layered_LSTM, create_lstm=LSTM)
1979

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.