7
from caffe2.python import brew, rnn_cell
10
class GRUCell(rnn_cell.RNNCell):
19
linear_before_reset=False,
22
super().__init__(**kwargs)
23
self.input_size = input_size
24
self.hidden_size = hidden_size
25
self.forget_bias = float(forget_bias)
26
self.memory_optimization = memory_optimization
27
self.drop_states = drop_states
28
self.linear_before_reset = linear_before_reset
44
hidden_t_prev = states[0]
47
input_t_reset, input_t_update, input_t_output = model.net.Split(
52
self.scope('input_t_reset'),
53
self.scope('input_t_update'),
54
self.scope('input_t_output'),
60
reset_gate_t = brew.fc(
63
self.scope('reset_gate_t'),
64
dim_in=self.hidden_size,
65
dim_out=self.hidden_size,
68
update_gate_t = brew.fc(
71
self.scope('update_gate_t'),
72
dim_in=self.hidden_size,
73
dim_out=self.hidden_size,
78
reset_gate_t = model.net.Sum(
79
[reset_gate_t, input_t_reset],
80
self.scope('reset_gate_t')
82
reset_gate_t_sigmoid = model.net.Sigmoid(
84
self.scope('reset_gate_t_sigmoid')
88
if self.linear_before_reset:
89
output_gate_fc = brew.fc(
92
self.scope('output_gate_t'),
93
dim_in=self.hidden_size,
94
dim_out=self.hidden_size,
97
output_gate_t = model.net.Mul(
98
[reset_gate_t_sigmoid, output_gate_fc],
99
self.scope('output_gate_t_mul')
102
modified_hidden_t_prev = model.net.Mul(
103
[reset_gate_t_sigmoid, hidden_t_prev],
104
self.scope('modified_hidden_t_prev')
106
output_gate_t = brew.fc(
108
modified_hidden_t_prev,
109
self.scope('output_gate_t'),
110
dim_in=self.hidden_size,
111
dim_out=self.hidden_size,
117
update_gate_t = model.net.Sum(
118
[update_gate_t, input_t_update],
119
self.scope('update_gate_t'),
121
output_gate_t = model.net.Sum(
122
[output_gate_t, input_t_output],
123
self.scope('output_gate_t_summed'),
127
gates_t, _gates_t_concat_dims = model.net.Concat(
134
self.scope('gates_t'),
135
self.scope('_gates_t_concat_dims'),
140
if seq_lengths is not None:
141
inputs = [hidden_t_prev, gates_t, seq_lengths, timestep]
143
inputs = [hidden_t_prev, gates_t, timestep]
145
hidden_t = model.net.GRUUnit(
147
list(self.get_state_names()),
148
forget_bias=self.forget_bias,
149
drop_states=self.drop_states,
150
sequence_lengths=(seq_lengths is not None),
152
model.net.AddExternalOutputs(hidden_t)
155
def prepare_input(self, model, input_blob):
160
dim_in=self.input_size,
161
dim_out=3 * self.hidden_size,
165
def get_state_names(self):
    """Return the scoped names of the recurrent state blobs.

    Returns:
        A 1-tuple whose only element is ``self.scope('hidden_t')``.
    """
    # The trailing comma keeps the result a tuple rather than a bare name.
    return (self.scope('hidden_t'),)
168
def get_output_dim(self):
    """Return the cell's output dimension, which is ``self.hidden_size``."""
    return self.hidden_size
172
# GRU is rnn_cell._LSTM with GRUCell pre-bound as its first positional
# argument, so callers supply only the remaining arguments.
# NOTE(review): rnn_cell._LSTM is not visible from here — presumably it is a
# shared network-building helper that takes the cell class first; confirm.
GRU = functools.partial(rnn_cell._LSTM, GRUCell)