from caffe2.proto import caffe2_pb2
from caffe2.python.attention import (
    apply_dot_attention,
    apply_recurrent_attention,
    apply_regular_attention,
    apply_soft_coverage_attention,
    AttentionType,
)
from caffe2.python import core, recurrent, workspace, brew, scope, utils
from caffe2.python.modeling.parameter_sharing import ParameterSharing
from caffe2.python.modeling.parameter_info import ParameterTags
from caffe2.python.modeling.initializers import Initializer
from caffe2.python.model_helper import ModelHelper


def _RectifyName(blob_reference_or_name):
    if blob_reference_or_name is None:
        return None
    if isinstance(blob_reference_or_name, str):
        return core.ScopedBlobReference(blob_reference_or_name)
    if not isinstance(blob_reference_or_name, core.BlobReference):
        raise Exception("Unknown blob reference type")
    return blob_reference_or_name


def _RectifyNames(blob_references_or_names):
    if blob_references_or_names is None:
        return None
    return [_RectifyName(i) for i in blob_references_or_names]


class RNNCell(object):
    '''
    Base class for writing recurrent / stateful operations.

    One needs to implement 2 methods: apply_override
    and get_state_names_override.

    As a result, the base class will provide an apply_over_sequence method,
    which allows you to apply recurrent operations over a sequence of any
    length.

    Optionally, you can add input and output preparation steps by overriding
    the corresponding methods.
    '''
    def __init__(self, name=None, forward_only=False, initializer=None):
        self.name = name
        self.recompute_blobs = []
        self.forward_only = forward_only
        self._initializer = initializer

    @property
    def initializer(self):
        return self._initializer

    @initializer.setter
    def initializer(self, value):
        self._initializer = value

    def scope(self, name):
        return self.name + '/' + name if self.name is not None else name

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths=None,
        initial_states=None,
        outputs_with_grads=None,
    ):
        if initial_states is None:
            with scope.NameScope(self.name):
                if self.initializer is None:
                    raise Exception("Either initial states "
                                    "or initializer have to be set")
                initial_states = self.initializer.create_states(model)

        preprocessed_inputs = self.prepare_input(model, inputs)
        step_model = ModelHelper(name=self.name, param_model=model)
        input_t, timestep = step_model.net.AddScopedExternalInputs(
            'input_t',
            'timestep',
        )
        utils.raiseIfNotEqual(
            len(initial_states), len(self.get_state_names()),
            "Number of initial state values provided doesn't match the number "
            "of states"
        )
        states_prev = step_model.net.AddScopedExternalInputs(*[
            s + '_prev' for s in self.get_state_names()
        ])
        states = self._apply(
            model=step_model,
            input_t=input_t,
            seq_lengths=seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

        external_outputs = set(step_model.net.Proto().external_output)
        for state in states:
            if state not in external_outputs:
                step_model.net.AddExternalOutput(state)

        if outputs_with_grads is None:
            outputs_with_grads = [self.get_output_state_index() * 2]

        # states_for_all_steps alternates between the full-sequence output and
        # the final output of each state:
        # (state_1_all, state_1_final, state_2_all, state_2_final, ...)
        states_for_all_steps = recurrent.recurrent_net(
            net=model.net,
            cell_net=step_model.net,
            inputs=[(input_t, preprocessed_inputs)],
            initial_cell_inputs=list(zip(states_prev, initial_states)),
            links=dict(zip(states_prev, states)),
            timestep=timestep,
            scope=self.name,
            forward_only=self.forward_only,
            outputs_with_grads=outputs_with_grads,
            recompute_blobs_on_backward=self.recompute_blobs,
        )

        output = self._prepare_output_sequence(
            model,
            states_for_all_steps,
        )
        return output, states_for_all_steps

    def apply(self, model, input_t, seq_lengths, states, timestep):
        input_t = self.prepare_input(model, input_t)
        states = self._apply(
            model, input_t, seq_lengths, states, timestep)
        output = self._prepare_output(model, states)
        return output, states

    def _apply(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None
    ):
        '''
        This method uses apply_override provided by a custom cell.
        On top of that, it takes care of applying self.scope() to all the
        outputs, while all the inputs stay within the scope this function
        was called from.
        '''
        args = self._rectify_apply_inputs(
            input_t, seq_lengths, states, timestep, extra_inputs)
        with core.NameScope(self.name):
            return self.apply_override(model, *args)

    def _rectify_apply_inputs(
            self, input_t, seq_lengths, states, timestep, extra_inputs):
        '''
        Before applying a scope we make sure that all external blob names
        are converted to blob references, so further scoping doesn't affect
        them.
        '''
        input_t, seq_lengths, timestep = _RectifyNames(
            [input_t, seq_lengths, timestep])
        states = _RectifyNames(states)
        if extra_inputs is not None:
            extra_input_names, extra_input_sizes = zip(*extra_inputs)
            extra_input_names = _RectifyNames(extra_input_names)
            extra_inputs = zip(extra_input_names, extra_input_sizes)

        # Inspect the custom cell's apply_override signature to decide whether
        # extra_inputs should be forwarded to it.
        arg_names = inspect.getfullargspec(self.apply_override).args
        rectified = [input_t, seq_lengths, states, timestep]
        if 'extra_inputs' in arg_names:
            rectified.append(extra_inputs)
        return rectified

    def apply_override(
        self,
        model, input_t, seq_lengths, timestep, extra_inputs=None,
    ):
        '''
        A single step of a recurrent network to be implemented by each custom
        RNNCell.

        model: ModelHelper object new operators would be added to

        input_t: single input with shape (1, batch_size, input_dim)

        seq_lengths: blob containing sequence lengths which would be passed to
        LSTMUnit operator

        states: previous recurrent states

        timestep: current recurrent iteration. Could be used together with
        seq_lengths in order to determine if some shorter sequences
        in the batch have already ended.

        extra_inputs: list of tuples (input, dim). Specifies additional inputs
        which are not subject to prepare_input(). (Useful when a cell is a
        component of a larger recurrent structure, e.g., attention.)
        '''
        raise NotImplementedError('Abstract method')

    def prepare_input(self, model, input_blob):
        '''
        If some operations in the _apply method depend only on the input,
        not on recurrent states, they could be computed in advance.

        model: ModelHelper object new operators would be added to

        input_blob: either the whole input sequence with shape
        (sequence_length, batch_size, input_dim) or a single input with shape
        (1, batch_size, input_dim).
        '''
        return input_blob

    def get_output_state_index(self):
        '''
        Return index into state list of the "primary" step-wise output.
        '''
        return 0

    def get_state_names(self):
        '''
        Returns recurrent state names with self.name scoping applied
        '''
        return [self.scope(name) for name in self.get_state_names_override()]

    def get_state_names_override(self):
        '''
        Override this function in your custom cell.
        It should return the names of the recurrent states.

        It's required by the apply_over_sequence method in order to allocate
        recurrent states for all steps with meaningful names.
        '''
        raise NotImplementedError('Abstract method')

    def get_output_dim(self):
        '''
        Specifies the dimension (number of units) of stepwise output.
        '''
        raise NotImplementedError('Abstract method')

    def _prepare_output(self, model, states):
        '''
        Allows arbitrary post-processing of primary output.
        '''
        return states[self.get_output_state_index()]

    def _prepare_output_sequence(self, model, state_outputs):
        '''
        Allows arbitrary post-processing of primary sequence output.

        (Note that state_outputs alternates between full-sequence and final
        output for each state, thus the index multiplier 2.)
        '''
        output_sequence_index = 2 * self.get_output_state_index()
        return state_outputs[output_sequence_index]
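

# --- Added illustrative sketch (not part of the original caffe2 source) ---
# A minimal example of how a custom cell is built on top of RNNCell:
# implement apply_override and get_state_names_override, and the base class
# provides apply_over_sequence / apply. Everything below (class name, blob
# names, the Sum/Tanh update rule) is an assumption for illustration only.
class _ExampleAdditiveCell(RNNCell):
    """Toy cell: state_t = tanh(prepared_input_t + state_{t-1})."""

    def __init__(self, size, **kwargs):
        super().__init__(**kwargs)
        # Assumes the (prepared) input dimension equals the state dimension.
        self.size = size

    def apply_override(self, model, input_t, seq_lengths, states, timestep,
                       extra_inputs=None):
        state_t_prev = states[0]
        # Blobs created here end up under self.name because _apply wraps this
        # call in core.NameScope(self.name).
        state_sum = model.net.Sum([input_t, state_t_prev], 'state_sum')
        state_t = model.net.Tanh(state_sum, 'state_t')
        model.net.AddExternalOutputs(state_t)
        return (state_t,)

    def get_state_names_override(self):
        return ['state_t']

    def get_output_dim(self):
        return self.size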


class LSTMInitializer:
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

    def create_states(self, model):
        return [
            model.create_param(
                param_name='initial_hidden_state',
                initializer=Initializer(operator_name='ConstantFill',
                                        value=0.0),
                shape=[self.hidden_size],
            ),
            model.create_param(
                param_name='initial_cell_state',
                initializer=Initializer(operator_name='ConstantFill',
                                        value=0.0),
                shape=[self.hidden_size],
            ),
        ]


class BasicRNNCell(RNNCell):
    def __init__(self, input_size, hidden_size, forget_bias,
                 memory_optimization, drop_states=False, initializer=None,
                 activation=None, **kwargs):
        super().__init__(**kwargs)
        self.drop_states = drop_states
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = activation

        if self.activation not in ['relu', 'tanh']:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

    def apply_override(self, model, input_t, seq_lengths, states, timestep,
                       extra_inputs=None):
        hidden_t_prev = states[0]

        gates_t = brew.fc(
            model,
            hidden_t_prev,
            'gates_t',
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        brew.sum(model, [gates_t, input_t], gates_t)
        if self.activation == 'tanh':
            hidden_t = model.net.Tanh(gates_t, 'hidden_t')
        elif self.activation == 'relu':
            hidden_t = model.net.Relu(gates_t, 'hidden_t')
        else:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

        if seq_lengths is not None:
            # Mask out steps that are past the end of each sequence.
            timestep = model.net.CopyFromCPUInput(
                timestep, 'timestep_gpu')
            valid_b = model.net.GT(
                [seq_lengths, timestep], 'valid_b', broadcast=1)
            invalid_b = model.net.LE(
                [seq_lengths, timestep], 'invalid_b', broadcast=1)
            valid = model.net.Cast(valid_b, 'valid', to='float')
            invalid = model.net.Cast(invalid_b, 'invalid', to='float')

            hidden_valid = model.net.Mul(
                [hidden_t, valid],
                'hidden_valid',
                broadcast=1,
                axis=1,
            )
            if self.drop_states:
                hidden_t = hidden_valid
            else:
                hidden_invalid = model.net.Mul(
                    [hidden_t_prev, invalid],
                    'hidden_invalid',
                    broadcast=1, axis=1)
                hidden_t = model.net.Add(
                    [hidden_valid, hidden_invalid], hidden_t)
        return (hidden_t,)

    def prepare_input(self, model, input_blob):
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.hidden_size,
            axis=2,
        )

    def get_state_names(self):
        return (self.scope('hidden_t'),)

    def get_output_dim(self):
        return self.hidden_size


class LSTMCell(RNNCell):
    def __init__(self, input_size, hidden_size, forget_bias,
                 memory_optimization, drop_states=False, initializer=None,
                 **kwargs):
        super().__init__(initializer=initializer, **kwargs)
        self.initializer = initializer or LSTMInitializer(
            hidden_size=hidden_size)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
        self.gates_size = 4 * self.hidden_size

    def apply_override(self, model, input_t, seq_lengths, states, timestep,
                       extra_inputs=None):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                'gates_concatenated_input_t',
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        gates_t = brew.fc(
            model,
            fc_input,
            'gates_t',
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )
        brew.sum(model, [gates_t, input_t], gates_t)

        if seq_lengths is not None:
            inputs = [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep]
        else:
            inputs = [hidden_t_prev, cell_t_prev, gates_t, timestep]

        hidden_t, cell_t = model.net.LSTMUnit(
            inputs,
            ['hidden_state', 'cell_state'],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=(seq_lengths is not None),
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]

        return hidden_t, cell_t

    def get_input_params(self):
        return {
            'weights': self.scope('i2h') + '_w',
            'biases': self.scope('i2h') + '_b',
        }

    def get_recurrent_params(self):
        return {
            'weights': self.scope('gates_t') + '_w',
            'biases': self.scope('gates_t') + '_b',
        }

    def prepare_input(self, model, input_blob):
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.gates_size,
            axis=2,
        )

    def get_state_names_override(self):
        return ['hidden_t', 'cell_t']

    def get_output_dim(self):
        return self.hidden_size


class LayerNormLSTMCell(RNNCell):
    def __init__(self, input_size, hidden_size, forget_bias,
                 memory_optimization, drop_states=False, initializer=None,
                 **kwargs):
        super().__init__(initializer=initializer, **kwargs)
        self.initializer = initializer or LSTMInitializer(
            hidden_size=hidden_size
        )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
        self.gates_size = 4 * self.hidden_size

    def apply_override(self, model, input_t, seq_lengths, states, timestep,
                       extra_inputs=None):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        gates_t = brew.fc(
            model,
            fc_input,
            self.scope('gates_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )
        brew.sum(model, [gates_t, input_t], gates_t)

        # Layer-normalize the pre-activation gates before the LSTM unit.
        gates_t, _, _ = brew.layer_norm(
            model,
            self.scope('gates_t'),
            self.scope('gates_t_norm'),
            dim_in=self.gates_size,
            axis=-1,
        )

        hidden_t, cell_t = model.net.LSTMUnit(
            [
                hidden_t_prev,
                cell_t_prev,
                gates_t,
                seq_lengths,
                timestep,
            ],
            self.get_state_names(),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]

        return hidden_t, cell_t

    def get_input_params(self):
        return {
            'weights': self.scope('i2h') + '_w',
            'biases': self.scope('i2h') + '_b',
        }

    def prepare_input(self, model, input_blob):
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.gates_size,
            axis=2,
        )

    def get_state_names(self):
        return (self.scope('hidden_t'), self.scope('cell_t'))


class MILSTMCell(LSTMCell):

    def apply_override(self, model, input_t, seq_lengths, states, timestep,
                       extra_inputs=None):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        prev_t = brew.fc(
            model,
            fc_input,
            self.scope('prev_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )

        # Multiplicative-integration (MI) parameters.
        alpha = model.create_param(
            self.scope('alpha'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_h = model.create_param(
            self.scope('beta1'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_i = model.create_param(
            self.scope('beta2'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        b = model.create_param(
            self.scope('b'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=0.0),
        )

        # alpha * input_t + beta_h
        alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
            [input_t, alpha, beta_h],
            self.scope('alpha_by_input_t_plus_beta_h'),
            axis=2,
        )
        # (alpha * input_t + beta_h) * prev_t
        alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
            [alpha_by_input_t_plus_beta_h, prev_t],
            self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
        )
        # beta_i * input_t + b
        beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
            [input_t, beta_i, b],
            self.scope('beta_i_by_input_t_plus_b'),
            axis=2,
        )
        # Sum the multiplicative and additive terms into the gate input.
        gates_t = brew.sum(
            model,
            [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
            self.scope('gates_t')
        )
        hidden_t, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(
            hidden_t,
            cell_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t


class LayerNormMILSTMCell(LSTMCell):

    def apply_override(self, model, input_t, seq_lengths, states, timestep,
                       extra_inputs=None):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        prev_t = brew.fc(
            model,
            fc_input,
            self.scope('prev_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )

        # Multiplicative-integration (MI) parameters.
        alpha = model.create_param(
            self.scope('alpha'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_h = model.create_param(
            self.scope('beta1'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_i = model.create_param(
            self.scope('beta2'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        b = model.create_param(
            self.scope('b'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=0.0),
        )

        # alpha * input_t + beta_h
        alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
            [input_t, alpha, beta_h],
            self.scope('alpha_by_input_t_plus_beta_h'),
            axis=2,
        )
        # (alpha * input_t + beta_h) * prev_t
        alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
            [alpha_by_input_t_plus_beta_h, prev_t],
            self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
        )
        # beta_i * input_t + b
        beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
            [input_t, beta_i, b],
            self.scope('beta_i_by_input_t_plus_b'),
            axis=2,
        )
        # Sum the multiplicative and additive terms into the gate input.
        gates_t = brew.sum(
            model,
            [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
            self.scope('gates_t')
        )
        # Layer-normalize the pre-activation gates before the LSTM unit.
        gates_t, _, _ = brew.layer_norm(
            model,
            self.scope('gates_t'),
            self.scope('gates_t_norm'),
            dim_in=self.gates_size,
            axis=-1,
        )
        hidden_t, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(
            hidden_t,
            cell_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t


class DropoutCell(RNNCell):
    '''
    Wraps arbitrary RNNCell, applying dropout to its output (but not to the
    recurrent connection for the corresponding state).
    '''

    def __init__(self, internal_cell, dropout_ratio=None, use_cudnn=False,
                 **kwargs):
        self.internal_cell = internal_cell
        self.dropout_ratio = dropout_ratio
        assert 'is_test' in kwargs, "Argument 'is_test' is required"
        self.is_test = kwargs.pop('is_test')
        self.use_cudnn = use_cudnn
        super().__init__(**kwargs)

        self.prepare_input = internal_cell.prepare_input
        self.get_output_state_index = internal_cell.get_output_state_index
        self.get_state_names = internal_cell.get_state_names
        self.get_output_dim = internal_cell.get_output_dim

        # Counter used to give each dropout mask a distinct blob name.
        self.mask = 0

    def _apply(self, model, input_t, seq_lengths, states, timestep,
               extra_inputs=None):
        return self.internal_cell._apply(
            model, input_t, seq_lengths, states, timestep, extra_inputs)

    def _prepare_output(self, model, states):
        output = self.internal_cell._prepare_output(
            model, states)
        if self.dropout_ratio is not None:
            output = self._apply_dropout(model, output)
        return output

    def _prepare_output_sequence(self, model, state_outputs):
        output = self.internal_cell._prepare_output_sequence(
            model, state_outputs)
        if self.dropout_ratio is not None:
            output = self._apply_dropout(model, output)
        return output

    def _apply_dropout(self, model, output):
        if self.dropout_ratio and not self.forward_only:
            with core.NameScope(self.name or ''):
                output = brew.dropout(
                    model,
                    output,
                    str(output) + '_with_dropout_mask{}'.format(self.mask),
                    ratio=float(self.dropout_ratio),
                    is_test=self.is_test,
                    use_cudnn=self.use_cudnn,
                )
                self.mask += 1
        return output
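

# --- Added illustrative sketch (not part of the original caffe2 source) ---
# Wrapping a cell in DropoutCell: dropout is applied to the wrapped cell's
# output only, not to its recurrent state. The sizes and names below are
# assumptions for the example.
def _example_dropout_wrapped_cell():
    inner = LSTMCell(
        input_size=16,
        hidden_size=32,
        forget_bias=0.0,
        memory_optimization=False,
        name='layer_0',
    )
    return DropoutCell(
        internal_cell=inner,
        dropout_ratio=0.5,
        is_test=False,
        name='layer_0_dropout',
    )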


class MultiRNNCellInitializer:
    def __init__(self, cells):
        self.cells = cells

    def create_states(self, model):
        states = []
        for i, cell in enumerate(self.cells):
            if cell.initializer is None:
                raise Exception("Either initial states "
                                "or initializer have to be set")

            with core.NameScope("layer_{}".format(i)),\
                    core.NameScope(cell.name):
                states.extend(cell.initializer.create_states(model))
        return states


class MultiRNNCell(RNNCell):
    '''
    Multilayer RNN via the composition of RNNCell instances.

    It is the responsibility of calling code to ensure the compatibility
    of the successive layers in terms of input/output dimensionality, etc.,
    and to ensure that their blobs do not have name conflicts, typically by
    creating the cells with names that specify layer number.

    Assumes first state (recurrent output) for each layer should be the input
    to the next layer.
    '''

    def __init__(self, cells, residual_output_layers=None, **kwargs):
        '''
        cells: list of RNNCell instances, from input to output side.

        name: string designating network component (for scoping)

        residual_output_layers: list of indices of layers whose input will
        be added elementwise to their output. (It is the
        responsibility of the client code to ensure shape compatibility.)
        Note that layer 0 (zero) cannot have residual output because of the
        timing of prepare_input().

        forward_only: used to construct inference-only network.
        '''
        super().__init__(**kwargs)
        self.cells = cells

        if residual_output_layers is None:
            self.residual_output_layers = []
        else:
            self.residual_output_layers = residual_output_layers

        output_index_per_layer = []
        base_index = 0
        for cell in self.cells:
            output_index_per_layer.append(
                base_index + cell.get_output_state_index(),
            )
            base_index += len(cell.get_state_names())

        self.output_connected_layers = []
        self.output_indices = []
        for i in range(len(self.cells) - 1):
            if (i + 1) in self.residual_output_layers:
                self.output_connected_layers.append(i)
                self.output_indices.append(output_index_per_layer[i])
            else:
                self.output_connected_layers = []
                self.output_indices = []
        self.output_connected_layers.append(len(self.cells) - 1)
        self.output_indices.append(output_index_per_layer[-1])

        self.state_names = []
        for i, cell in enumerate(self.cells):
            self.state_names.extend(
                map(self.layer_scoper(i), cell.get_state_names())
            )

        self.initializer = MultiRNNCellInitializer(cells)

    def layer_scoper(self, layer_id):
        def helper(name):
            return "{}/layer_{}/{}".format(self.name, layer_id, name)
        return helper

    def prepare_input(self, model, input_blob):
        input_blob = _RectifyName(input_blob)
        with core.NameScope(self.name or ''):
            return self.cells[0].prepare_input(model, input_blob)

    def _apply(self, model, input_t, seq_lengths, states, timestep,
               extra_inputs=None):
        '''
        Because below we will do scoping across layers, we need
        to make sure that string blob names are converted to BlobReferences.
        '''
        input_t, seq_lengths, states, timestep, extra_inputs = \
            self._rectify_apply_inputs(
                input_t, seq_lengths, states, timestep, extra_inputs)

        states_per_layer = [len(cell.get_state_names()) for cell in self.cells]
        assert len(states) == sum(states_per_layer)

        next_states = []
        states_index = 0

        layer_input = input_t
        for i, layer_cell in enumerate(self.cells):
            with core.NameScope(self.name), core.NameScope("layer_{}".format(i)):
                num_states = states_per_layer[i]
                layer_states = states[states_index:(states_index + num_states)]
                states_index += num_states

                if i > 0:
                    prepared_input = layer_cell.prepare_input(
                        model, layer_input)
                else:
                    prepared_input = layer_input

                layer_next_states = layer_cell._apply(
                    model,
                    prepared_input,
                    seq_lengths,
                    layer_states,
                    timestep,
                    extra_inputs=(None if i > 0 else extra_inputs),
                )
                # Since we're using the non-public method _apply instead of
                # apply, we have to manually extract the output from the
                # states of this layer.
                if i != len(self.cells) - 1:
                    layer_output = layer_cell._prepare_output(
                        model,
                        layer_next_states,
                    )
                    if i > 0 and i in self.residual_output_layers:
                        layer_input = brew.sum(
                            model,
                            [layer_output, layer_input],
                            self.scope('residual_output_{}'.format(i)),
                        )
                    else:
                        layer_input = layer_output

                next_states.extend(layer_next_states)
        return next_states

    def get_state_names(self):
        return self.state_names

    def get_output_state_index(self):
        index = 0
        for cell in self.cells[:-1]:
            index += len(cell.get_state_names())
        index += self.cells[-1].get_output_state_index()
        return index

    def _prepare_output(self, model, states):
        connected_outputs = []
        state_index = 0
        for i, cell in enumerate(self.cells):
            num_states = len(cell.get_state_names())
            if i in self.output_connected_layers:
                layer_states = states[state_index:state_index + num_states]
                layer_output = cell._prepare_output(
                    model,
                    layer_states,
                )
                connected_outputs.append(layer_output)
            state_index += num_states
        if len(connected_outputs) > 1:
            output = brew.sum(
                model,
                connected_outputs,
                self.scope('residual_output'),
            )
        else:
            output = connected_outputs[0]
        return output

    def _prepare_output_sequence(self, model, states):
        connected_outputs = []
        state_index = 0
        for i, cell in enumerate(self.cells):
            num_states = 2 * len(cell.get_state_names())
            if i in self.output_connected_layers:
                layer_states = states[state_index:state_index + num_states]
                layer_output = cell._prepare_output_sequence(
                    model,
                    layer_states,
                )
                connected_outputs.append(layer_output)
            state_index += num_states
        if len(connected_outputs) > 1:
            output = brew.sum(
                model,
                connected_outputs,
                self.scope('residual_output_sequence'),
            )
        else:
            output = connected_outputs[0]
        return output
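

# --- Added illustrative sketch (not part of the original caffe2 source) ---
# Stacking two LSTMCells with MultiRNNCell; layer 1 gets a residual
# connection, which requires matching input/output dimensions. Names and
# sizes are assumptions for the example (the _LSTM helper below builds an
# equivalent stack when dim_out is a list).
def _example_stacked_cell():
    cells = [
        LSTMCell(input_size=16, hidden_size=32, forget_bias=0.0,
                 memory_optimization=False, name='layer_0'),
        LSTMCell(input_size=32, hidden_size=32, forget_bias=0.0,
                 memory_optimization=False, name='layer_1'),
    ]
    return MultiRNNCell(
        cells,
        name='example_stack',
        residual_output_layers=[1],
    )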


class AttentionCell(RNNCell):
    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_cell,
        decoder_state_dim,
        attention_type,
        weighted_encoder_outputs,
        attention_memory_optimization,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder_output_dim = encoder_output_dim
        self.encoder_outputs = encoder_outputs
        self.encoder_lengths = encoder_lengths
        self.decoder_cell = decoder_cell
        self.decoder_state_dim = decoder_state_dim
        self.weighted_encoder_outputs = weighted_encoder_outputs
        self.encoder_outputs_transposed = None
        assert attention_type in [
            AttentionType.Regular,
            AttentionType.Recurrent,
            AttentionType.Dot,
            AttentionType.SoftCoverage,
        ]
        self.attention_type = attention_type
        self.attention_memory_optimization = attention_memory_optimization

    def apply_override(self, model, input_t, seq_lengths, states, timestep,
                       extra_inputs=None):
        if self.attention_type == AttentionType.SoftCoverage:
            decoder_prev_states = states[:-2]
            attention_weighted_encoder_context_t_prev = states[-2]
            coverage_t_prev = states[-1]
        else:
            decoder_prev_states = states[:-1]
            attention_weighted_encoder_context_t_prev = states[-1]

        assert extra_inputs is None

        decoder_states = self.decoder_cell._apply(
            model,
            input_t,
            seq_lengths,
            decoder_prev_states,
            timestep,
            extra_inputs=[(
                attention_weighted_encoder_context_t_prev,
                self.encoder_output_dim,
            )],
        )

        self.hidden_t_intermediate = self.decoder_cell._prepare_output(
            model,
            decoder_states,
        )

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.Regular:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.Dot:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_dot_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.SoftCoverage:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
                coverage_t,
            ) = apply_soft_coverage_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
                coverage_t_prev=coverage_t_prev,
                coverage_weights=self.coverage_weights,
            )
        else:
            raise Exception('Attention type {} not implemented'.format(
                self.attention_type
            ))

        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)

        output = list(decoder_states) + [attention_weighted_encoder_context_t]
        if self.attention_type == AttentionType.SoftCoverage:
            output.append(coverage_t)

        output[self.decoder_cell.get_output_state_index()] = model.Copy(
            output[self.decoder_cell.get_output_state_index()],
            self.scope('hidden_t_external'),
        )
        model.net.AddExternalOutputs(*output)

        return output

    def get_attention_weights(self):
        return self.attention_weights_3d

    def prepare_input(self, model, input_blob):
        if self.encoder_outputs_transposed is None:
            self.encoder_outputs_transposed = brew.transpose(
                model,
                self.encoder_outputs,
                self.scope('encoder_outputs_transposed'),
                axes=[1, 2, 0],
            )
        if (
            self.weighted_encoder_outputs is None and
            self.attention_type != AttentionType.Dot
        ):
            self.weighted_encoder_outputs = brew.fc(
                model,
                self.encoder_outputs,
                self.scope('weighted_encoder_outputs'),
                dim_in=self.encoder_output_dim,
                dim_out=self.encoder_output_dim,
                axis=2,
            )

        return self.decoder_cell.prepare_input(model, input_blob)

    def build_initial_coverage(self, model):
        '''
        initial_coverage is always zeros of shape [encoder_length],
        which shape must be determined programmatically during network
        computation.

        This method also sets self.coverage_weights, a separate transform
        of encoder_outputs which is used to determine coverage contribution
        to attention.
        '''
        assert self.attention_type == AttentionType.SoftCoverage

        self.coverage_weights = brew.fc(
            model,
            self.encoder_outputs,
            self.scope('coverage_weights'),
            dim_in=self.encoder_output_dim,
            dim_out=self.encoder_output_dim,
            axis=2,
        )

        encoder_length = model.net.Slice(
            model.net.Shape(self.encoder_outputs),
            starts=[0],
            ends=[1],
        )
        if (
            scope.CurrentDeviceScope() is not None and
            core.IsGPUDeviceType(scope.CurrentDeviceScope().device_type)
        ):
            encoder_length = model.net.CopyGPUToCPU(
                encoder_length,
                'encoder_length_cpu',
            )
        initial_coverage = model.net.ConstantFill(
            encoder_length,
            self.scope('initial_coverage'),
            value=0.0,
            input_as_shape=1,
        )
        return initial_coverage

    def get_state_names(self):
        state_names = list(self.decoder_cell.get_state_names())
        state_names[self.get_output_state_index()] = self.scope(
            'hidden_t_external',
        )
        state_names.append(self.scope('attention_weighted_encoder_context_t'))
        if self.attention_type == AttentionType.SoftCoverage:
            state_names.append(self.scope('coverage_t'))
        return state_names

    def get_output_dim(self):
        return self.decoder_state_dim + self.encoder_output_dim

    def get_output_state_index(self):
        return self.decoder_cell.get_output_state_index()

    def _prepare_output(self, model, states):
        if self.attention_type == AttentionType.SoftCoverage:
            attention_context = states[-2]
        else:
            attention_context = states[-1]

        with core.NameScope(self.name or ''):
            output = brew.concat(
                model,
                [self.hidden_t_intermediate, attention_context],
                'states_and_context_combination',
                axis=2,
            )
        return output

    def _prepare_output_sequence(self, model, state_outputs):
        if self.attention_type == AttentionType.SoftCoverage:
            decoder_state_outputs = state_outputs[:-4]
        else:
            decoder_state_outputs = state_outputs[:-2]

        decoder_output = self.decoder_cell._prepare_output_sequence(
            model,
            decoder_state_outputs,
        )

        if self.attention_type == AttentionType.SoftCoverage:
            attention_context_index = 2 * (len(self.get_state_names()) - 2)
        else:
            attention_context_index = 2 * (len(self.get_state_names()) - 1)

        with core.NameScope(self.name or ''):
            output = brew.concat(
                model,
                [
                    decoder_output,
                    state_outputs[attention_context_index],
                ],
                'states_and_context_combination',
                axis=2,
            )
        return output


class LSTMWithAttentionCell(AttentionCell):
    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
    ):
        decoder_cell = LSTMCell(
            input_size=decoder_input_dim,
            hidden_size=decoder_state_dim,
            forget_bias=forget_bias,
            memory_optimization=lstm_memory_optimization,
            name='{}/decoder'.format(name),
            forward_only=False,
            drop_states=False,
        )
        super().__init__(
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=encoder_lengths,
            decoder_cell=decoder_cell,
            decoder_state_dim=decoder_state_dim,
            name=name,
            attention_type=attention_type,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=attention_memory_optimization,
            forward_only=forward_only,
        )


class MILSTMWithAttentionCell(AttentionCell):
    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
    ):
        decoder_cell = MILSTMCell(
            input_size=decoder_input_dim,
            hidden_size=decoder_state_dim,
            forget_bias=forget_bias,
            memory_optimization=lstm_memory_optimization,
            name='{}/decoder'.format(name),
            forward_only=False,
            drop_states=False,
        )
        super().__init__(
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            decoder_cell=decoder_cell,
            decoder_state_dim=decoder_state_dim,
            name=name,
            attention_type=attention_type,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=attention_memory_optimization,
            forward_only=forward_only,
        )


def _LSTM(
    cell_class,
    model,
    input_blob,
    seq_lengths,
    initial_states,
    dim_in,
    dim_out,
    scope,
    outputs_with_grads=(0,),
    return_params=False,
    memory_optimization=False,
    forget_bias=0.0,
    forward_only=False,
    drop_states=False,
    return_last_layer_only=True,
    static_rnn_unroll_size=None,
    **cell_kwargs
):
    '''
    Adds a standard LSTM recurrent network operator to a model.

    cell_class: LSTMCell or compatible subclass

    model: ModelHelper object new operators would be added to

    input_blob: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    seq_lengths: blob containing sequence lengths which would be passed to
    LSTMUnit operator

    initial_states: a list of (2 * num_layers) blobs representing the initial
    hidden and cell states of each layer. If this argument is None,
    these states will be added to the model as network parameters.

    dim_in: input dimension

    dim_out: number of units per LSTM layer
    (use int for single-layer LSTM, list of ints for multi-layer)

    outputs_with_grads : position indices of output blobs for LAST LAYER which
    will receive external error gradient during backpropagation.
    These outputs are: (h_all, h_last, c_all, c_last)

    return_params: if True, will return a dictionary of parameters of the LSTM

    memory_optimization: if enabled, the LSTM step is recomputed on backward
    step so that we don't need to store forward activations for each
    timestep. Saves memory with cost of computation.

    forget_bias: forget gate bias (default 0.0)

    forward_only: whether to create a backward pass

    drop_states: drop invalid states, passed through to LSTMUnit operator

    return_last_layer_only: only return outputs from final layer
    (so that length of results does not depend on number of layers)

    static_rnn_unroll_size: if not None, we will use static RNN which is
    unrolled into Caffe2 graph. The size of the unroll is the value of
    this parameter.
    '''
    if type(dim_out) is not list and type(dim_out) is not tuple:
        dim_out = [dim_out]
    num_layers = len(dim_out)

    cells = []
    for i in range(num_layers):
        cell = cell_class(
            input_size=(dim_in if i == 0 else dim_out[i - 1]),
            hidden_size=dim_out[i],
            forget_bias=forget_bias,
            memory_optimization=memory_optimization,
            name=scope if num_layers == 1 else None,
            forward_only=forward_only,
            drop_states=drop_states,
            **cell_kwargs
        )
        cells.append(cell)

    cell = MultiRNNCell(
        cells,
        name=scope,
        forward_only=forward_only,
    ) if num_layers > 1 else cells[0]

    cell = (
        cell if static_rnn_unroll_size is None
        else UnrolledCell(cell, static_rnn_unroll_size))

    # outputs_with_grads indexes into the outputs of the final layer.
    outputs_with_grads = [4 * (num_layers - 1) + i for i in outputs_with_grads]
    _, result = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        seq_lengths=seq_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )

    if return_last_layer_only:
        result = result[4 * (num_layers - 1):]
    if return_params:
        result = list(result) + [{
            'input': cell.get_input_params(),
            'recurrent': cell.get_recurrent_params(),
        }]
    return tuple(result)


LSTM = functools.partial(_LSTM, LSTMCell)
BasicRNN = functools.partial(_LSTM, BasicRNNCell)
MILSTM = functools.partial(_LSTM, MILSTMCell)
LayerNormLSTM = functools.partial(_LSTM, LayerNormLSTMCell)
LayerNormMILSTM = functools.partial(_LSTM, LayerNormMILSTMCell)
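

# --- Added illustrative sketch (not part of the original caffe2 source) ---
# Typical wiring of the LSTM helper defined above into a ModelHelper. Blob
# names, shapes, and the random input data are assumptions for the example;
# the returned blobs are (hidden_all, hidden_last, cell_all, cell_last).
def _example_lstm_usage():
    import numpy as np
    from caffe2.python import model_helper

    T, N, D, H = 5, 2, 4, 8
    workspace.FeedBlob('input', np.random.randn(T, N, D).astype(np.float32))
    workspace.FeedBlob('seq_lengths', np.full((N,), T, dtype=np.int32))
    workspace.FeedBlob('hidden_init', np.zeros((1, N, H), dtype=np.float32))
    workspace.FeedBlob('cell_init', np.zeros((1, N, H), dtype=np.float32))

    model = model_helper.ModelHelper(name='lstm_example')
    hidden_all, hidden_last, cell_all, cell_last = LSTM(
        model, 'input', 'seq_lengths', ['hidden_init', 'cell_init'],
        dim_in=D, dim_out=H, scope='lstm_example',
    )
    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)
    workspace.RunNet(model.net)
    return workspace.FetchBlob(hidden_all)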


class UnrolledCell(RNNCell):
    def __init__(self, cell, T):
        self.T = T
        self.cell = cell

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths,
        initial_states,
        outputs_with_grads=None,
    ):
        inputs = self.cell.prepare_input(model, inputs)

        # Split the input sequence into one blob per timestep.
        split_inputs = model.net.Split(
            inputs,
            [str(inputs) + "_timestep_{}".format(i)
             for i in range(self.T)],
            axis=0)
        if self.T == 1:
            split_inputs = [split_inputs]

        states = initial_states
        all_states = []
        for t in range(0, self.T):
            scope_name = "timestep_{}".format(t)
            # Parameters of all timesteps are shared.
            with ParameterSharing({scope_name: ''}),\
                    scope.NameScope(scope_name):
                timestep = model.param_init_net.ConstantFill(
                    [], "timestep", value=t, shape=[1],
                    dtype=core.DataType.INT32,
                    device_option=core.DeviceOption(caffe2_pb2.CPU))
                states = self.cell._apply(
                    model=model,
                    input_t=split_inputs[t],
                    seq_lengths=seq_lengths,
                    states=states,
                    timestep=timestep,
                )
            all_states.append(states)

        all_states = zip(*all_states)
        all_states = [
            model.net.Concat(
                list(full_output),
                [
                    str(full_output[0])[len("timestep_0/"):] + "_concat",
                    str(full_output[0])[len("timestep_0/"):] + "_concat_info"
                ],
                axis=0)[0]
            for full_output in all_states
        ]
        # Interleave full-sequence outputs with final-step states so the
        # result matches the layout produced by recurrent_net.
        outputs = tuple(
            state for state_pair in zip(all_states, states) for state in state_pair
        )
        outputs_without_grad = set(range(len(outputs))) - set(
            outputs_with_grads)
        for i in outputs_without_grad:
            model.net.ZeroGradient(outputs[i], [])
        logging.debug("Added 0 gradients for blobs: %s",
                      [outputs[i] for i in outputs_without_grad])

        final_output = self.cell._prepare_output_sequence(model, outputs)

        return final_output, outputs


def GetLSTMParamNames():
    weight_params = ["input_gate_w", "forget_gate_w", "output_gate_w", "cell_w"]
    bias_params = ["input_gate_b", "forget_gate_b", "output_gate_b", "cell_b"]
    return {'weights': weight_params, 'biases': bias_params}


def InitFromLSTMParams(lstm_pblobs, param_values):
    '''
    Set the parameters of LSTM based on predefined values
    '''
    weight_params = GetLSTMParamNames()['weights']
    bias_params = GetLSTMParamNames()['biases']
    for input_type in param_values.keys():
        weight_values = [
            param_values[input_type][w].flatten()
            for w in weight_params
        ]
        wmat = np.array([])
        for w in weight_values:
            wmat = np.append(wmat, w)
        bias_values = [
            param_values[input_type][b].flatten()
            for b in bias_params
        ]
        bm = np.array([])
        for b in bias_values:
            bm = np.append(bm, b)

        weights_blob = lstm_pblobs[input_type]['weights']
        bias_blob = lstm_pblobs[input_type]['biases']
        cur_weight = workspace.FetchBlob(weights_blob)
        cur_biases = workspace.FetchBlob(bias_blob)

        workspace.FeedBlob(
            weights_blob,
            wmat.reshape(cur_weight.shape).astype(np.float32))
        workspace.FeedBlob(
            bias_blob,
            bm.reshape(cur_biases.shape).astype(np.float32))
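

# --- Added illustrative sketch (not part of the original caffe2 source) ---
# Shape of the param_values dict expected by InitFromLSTMParams, keyed the
# same way as the dict returned by LSTM(..., return_params=True). The zero
# arrays are placeholders; real values would come from a pretrained model.
def _example_init_from_lstm_params(lstm_param_blobs, dim_in, dim_out):
    import numpy as np

    names = GetLSTMParamNames()
    param_values = {}
    for input_type, in_dim in [('input', dim_in), ('recurrent', dim_out)]:
        param_values[input_type] = {}
        for w in names['weights']:
            param_values[input_type][w] = np.zeros(
                (dim_out, in_dim), dtype=np.float32)
        for b in names['biases']:
            param_values[input_type][b] = np.zeros(
                (dim_out,), dtype=np.float32)
    InitFromLSTMParams(lstm_param_blobs, param_values)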


def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
               scope, recurrent_params=None, input_params=None,
               num_layers=1, return_params=False):
    '''
    CuDNN version of LSTM for GPUs.
    input_blob          Blob containing the input. Will need to be available
                        when param_init_net is run, because the sequence lengths
                        and batch sizes will be inferred from the size of this
                        blob.
    initial_states      tuple of (hidden_init, cell_init) blobs
    dim_in              input dimensions
    dim_out             output/hidden dimension
    scope               namescope to apply
    recurrent_params    dict of blobs containing values for recurrent
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    input_params        dict of blobs containing values for input
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    num_layers          number of LSTM layers
    return_params       if True, returns (param_extract_net, param_mapping)
                        where param_extract_net is a net that when run, will
                        populate the blobs specified in param_mapping with the
                        current gate weights and biases (input/recurrent).
                        Useful for assigning the values back to non-cuDNN
                        LSTM.
    '''
    with core.NameScope(scope):
        weight_params = GetLSTMParamNames()['weights']
        bias_params = GetLSTMParamNames()['biases']

        input_weight_size = dim_out * dim_in
        upper_layer_input_weight_size = dim_out * dim_out
        recurrent_weight_size = dim_out * dim_out
        input_bias_size = dim_out
        recurrent_bias_size = dim_out

        def init(layer, pname, input_type):
            input_weight_size_for_layer = input_weight_size if layer == 0 else \
                upper_layer_input_weight_size
            if pname in weight_params:
                sz = input_weight_size_for_layer if input_type == 'input' \
                    else recurrent_weight_size
            elif pname in bias_params:
                sz = input_bias_size if input_type == 'input' \
                    else recurrent_bias_size
            else:
                assert False, "unknown parameter type {}".format(pname)
            return model.param_init_net.UniformFill(
                [],
                "lstm_init_{}_{}_{}".format(input_type, pname, layer),
                shape=[sz])

        # Multiply by 4 since we have 4 gates per LSTM unit.
        first_layer_sz = input_weight_size + recurrent_weight_size + \
            input_bias_size + recurrent_bias_size
        upper_layer_sz = upper_layer_input_weight_size + \
            recurrent_weight_size + input_bias_size + \
            recurrent_bias_size
        total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)

        weights = model.create_param(
            'lstm_weight',
            shape=[total_sz],
            initializer=Initializer('UniformFill'),
            tags=ParameterTags.WEIGHT,
        )

        lstm_args = {
            'hidden_size': dim_out,
            'rnn_mode': 'lstm',
            'bidirectional': 0,
            'dropout': 1.0,
            'input_mode': 'linear',
            'num_layers': num_layers,
            'engine': 'CUDNN',
        }

        param_extract_net = core.Net("lstm_param_extractor")
        param_extract_net.AddExternalInputs([input_blob, weights])
        param_extract_mapping = {}

        # Populate the packed weights blob from the per-gate parameter blobs
        # (input/recurrent weights and biases for each layer).
        for input_type in ['input', 'recurrent']:
            param_extract_mapping[input_type] = {}
            p = recurrent_params if input_type == 'recurrent' else input_params
            if p is None:
                p = {}
            for pname in weight_params + bias_params:
                for j in range(0, num_layers):
                    values = p[pname] if pname in p else init(j, pname, input_type)
                    model.param_init_net.RecurrentParamSet(
                        [input_blob, weights, values],
                        weights,
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    if pname not in param_extract_mapping[input_type]:
                        param_extract_mapping[input_type][pname] = {}
                    b = param_extract_net.RecurrentParamGet(
                        [input_blob, weights],
                        ["lstm_{}_{}_{}".format(input_type, pname, j)],
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    param_extract_mapping[input_type][pname][j] = b

        (hidden_input_blob, cell_input_blob) = initial_states
        output, hidden_output, cell_output, rnn_scratch, dropout_states = \
            model.net.Recurrent(
                [input_blob, hidden_input_blob, cell_input_blob, weights],
                ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
                 "lstm_rnn_scratch", "lstm_dropout_states"],
                seed=random.randint(0, 100000),
                **lstm_args
            )
        model.net.AddExternalOutputs(
            hidden_output, cell_output, rnn_scratch, dropout_states)

    if return_params:
        param_extract = param_extract_net, param_extract_mapping
        return output, hidden_output, cell_output, param_extract
    else:
        return output, hidden_output, cell_output
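

# --- Added illustrative sketch (not part of the original caffe2 source) ---
# Calling cudnn_LSTM under a GPU device scope. Blob names, shapes, and the
# device index are assumptions; the Recurrent op requires the CUDNN engine
# and a CUDA-enabled build.
def _example_cudnn_lstm(model, T=5, N=2, D=4, H=8):
    import numpy as np

    gpu_opt = core.DeviceOption(caffe2_pb2.CUDA, 0)
    with core.DeviceScope(gpu_opt):
        workspace.FeedBlob(
            'input', np.random.randn(T, N, D).astype(np.float32),
            device_option=gpu_opt)
        workspace.FeedBlob(
            'hidden_init', np.zeros((1, N, H), dtype=np.float32),
            device_option=gpu_opt)
        workspace.FeedBlob(
            'cell_init', np.zeros((1, N, H), dtype=np.float32),
            device_option=gpu_opt)
        output, hidden_out, cell_out = cudnn_LSTM(
            model, 'input', ('hidden_init', 'cell_init'),
            dim_in=D, dim_out=H, scope='cudnn_lstm',
        )
    return output, hidden_out, cell_out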


def LSTMWithAttention(
    model,
    decoder_inputs,
    decoder_input_lengths,
    initial_decoder_hidden_state,
    initial_decoder_cell_state,
    initial_attention_weighted_encoder_context,
    encoder_output_dim,
    encoder_outputs,
    encoder_lengths,
    decoder_input_dim,
    decoder_state_dim,
    scope,
    attention_type=AttentionType.Regular,
    outputs_with_grads=(0, 4),
    weighted_encoder_outputs=None,
    lstm_memory_optimization=False,
    attention_memory_optimization=False,
    forget_bias=0.0,
    forward_only=False,
):
    '''
    Adds an LSTM with attention mechanism to a model.

    The implementation is based on https://arxiv.org/abs/1409.0473, with
    a small difference in the order
    how we compute new attention context and new hidden state, similarly to
    https://arxiv.org/abs/1508.04025.

    The model uses encoder-decoder naming conventions,
    where the decoder is the sequence the op is iterating over,
    while computing the attention context over the encoder.

    model: ModelHelper object new operators would be added to

    decoder_inputs: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    decoder_input_lengths: blob containing sequence lengths
    which would be passed to LSTMUnit operator

    initial_decoder_hidden_state: initial hidden state of LSTM

    initial_decoder_cell_state: initial cell state of LSTM

    initial_attention_weighted_encoder_context: initial attention context

    encoder_output_dim: dimension of encoder outputs

    encoder_outputs: the sequence, on which we compute the attention context

    encoder_lengths: a tensor with lengths of each encoder sequence in batch
    (may be None, meaning all encoder sequences are of same length)

    decoder_input_dim: input dimension (last dimension on decoder_inputs)

    decoder_state_dim: size of hidden states of LSTM

    attention_type: One of: AttentionType.Regular, AttentionType.Recurrent.
    Determines which type of attention mechanism to use.

    outputs_with_grads : position indices of output blobs which will receive
    external error gradient during backpropagation

    weighted_encoder_outputs: encoder outputs to be used to compute attention
    weights. In the basic case it's just a linear transformation of
    encoder outputs (that is the default, when weighted_encoder_outputs is
    None). However, it can be something more complicated - like a separate
    encoder network (for example, in case of convolutional encoder)

    lstm_memory_optimization: recompute LSTM activations on backward pass, so
    we don't need to store their values in forward passes

    attention_memory_optimization: recompute attention for backward pass

    forward_only: whether to create only forward pass
    '''
    cell = LSTMWithAttentionCell(
        encoder_output_dim=encoder_output_dim,
        encoder_outputs=encoder_outputs,
        encoder_lengths=encoder_lengths,
        decoder_input_dim=decoder_input_dim,
        decoder_state_dim=decoder_state_dim,
        name=scope,
        attention_type=attention_type,
        weighted_encoder_outputs=weighted_encoder_outputs,
        forget_bias=forget_bias,
        lstm_memory_optimization=lstm_memory_optimization,
        attention_memory_optimization=attention_memory_optimization,
        forward_only=forward_only,
    )
    initial_states = [
        initial_decoder_hidden_state,
        initial_decoder_cell_state,
        initial_attention_weighted_encoder_context,
    ]
    if attention_type == AttentionType.SoftCoverage:
        initial_states.append(cell.build_initial_coverage(model))
    _, result = cell.apply_over_sequence(
        model=model,
        inputs=decoder_inputs,
        seq_lengths=decoder_input_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )
    return result
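

# --- Added illustrative sketch (not part of the original caffe2 source) ---
# A typical call into LSTMWithAttention for a seq2seq decoder. All blob names
# and dimensions are assumptions for the example; the first returned blob is
# the per-step decoder output sequence.
def _example_lstm_with_attention(model):
    return LSTMWithAttention(
        model=model,
        decoder_inputs='decoder_inputs',        # T x N x decoder_input_dim
        decoder_input_lengths='decoder_input_lengths',
        initial_decoder_hidden_state='initial_decoder_hidden_state',
        initial_decoder_cell_state='initial_decoder_cell_state',
        initial_attention_weighted_encoder_context=(
            'initial_attention_weighted_encoder_context'),
        encoder_output_dim=64,
        encoder_outputs='encoder_outputs',      # T_enc x N x 64
        encoder_lengths='encoder_lengths',
        decoder_input_dim=32,
        decoder_state_dim=64,
        scope='decoder',
        attention_type=AttentionType.Regular,
    )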


def _layered_LSTM(
        model, input_blob, seq_lengths, initial_states,
        dim_in, dim_out, scope, outputs_with_grads=(0,), return_params=False,
        memory_optimization=False, forget_bias=0.0, forward_only=False,
        drop_states=False, create_lstm=None):
    params = locals()  # leave this as the first line to grab all params
    params.pop('create_lstm')
    if not isinstance(dim_out, list):
        return create_lstm(**params)
    elif len(dim_out) == 1:
        params['dim_out'] = dim_out[0]
        return create_lstm(**params)

    assert len(dim_out) != 0, "dim_out list can't be empty"
    assert return_params is False, "return_params not supported for layering"
    for i, output_dim in enumerate(dim_out):
        params.update({
            'dim_out': output_dim
        })
        output, last_output, all_states, last_state = create_lstm(**params)
        params.update({
            'input_blob': output,
            'dim_in': output_dim,
            'initial_states': (last_output, last_state),
            'scope': scope + '_layer_{}'.format(i + 1)
        })
    return output, last_output, all_states, last_state


layered_LSTM = functools.partial(_layered_LSTM, create_lstm=LSTM)