pytorch

Форк
0
/
gru_cell.py 
167 строк · 5.0 Кб
1

2

3

4

5

6
import functools
7
from caffe2.python import brew, rnn_cell
8

9

10
class GRUCell(rnn_cell.RNNCell):
    """Gated recurrent unit (GRU) cell for caffe2 recurrent nets.

    For each time step, ``_apply`` assembles the reset/update/output gate
    pre-activations in the graph. It then hands them, together with the
    previous hidden state, to the ``GRUUnit`` operator, which produces the
    new hidden state.

    Args:
        input_size: Feature size of the raw input given to ``prepare_input``.
        hidden_size: Size of the hidden state and of each of the three gates.
        forget_bias: Converted to ``float`` and forwarded to ``GRUUnit`` as an
            argument. Per the original note, its value is currently ignored.
        memory_optimization: Stored on the instance. It is not read anywhere
            in this class (presumably consumed by the ``RNNCell`` base --
            verify).
        drop_states: Forwarded to ``GRUUnit`` as ``drop_states``.
        linear_before_reset: If True, the reset gate multiplies the output of
            the hidden-state FC (commented as matching cuDNN semantics).
            Otherwise the reset gate multiplies the previous hidden state
            before that FC.
        **kwargs: Passed through unchanged to ``RNNCell.__init__``.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,  # Currently unused!  Values here will be ignored.
        memory_optimization,
        drop_states=False,
        linear_before_reset=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
        self.linear_before_reset = linear_before_reset

    def _hidden_fc(self, model, blob_in, name):
        """Add a hidden_size -> hidden_size FC (over axis 2) named scope(name).

        Shared by the reset, update and output gate projections of the
        recurrent state. Parameter blob names are derived from ``scope(name)``
        exactly as in a direct ``brew.fc`` call.
        """
        return brew.fc(
            model,
            blob_in,
            self.scope(name),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

    # Unlike LSTMCell, GRUCell needs the output of one gate to feed into another.
    # (reset gate -> output_gate)
    # So, much of the logic to calculate the reset gate output and modified
    # output gate input is set here, in the graph definition.
    # The remaining logic lives in gru_unit_op.{h,cc}.
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        """Build the graph for one GRU time step.

        Args:
            model: Model helper whose ``net`` receives the ops.
            input_t: Prepared input for this step (see ``prepare_input``). It
                is split along axis 2 into reset/update/output contributions.
            seq_lengths: Optional sequence-lengths blob. When not None, it is
                passed to ``GRUUnit`` and ``sequence_lengths=True`` is set.
            states: Recurrent states; only ``states[0]`` (previous hidden
                state) is used.
            timestep: Timestep blob forwarded to ``GRUUnit``.
            extra_inputs: Not used by this method.

        Returns:
            A 1-tuple containing the new hidden-state blob. That blob is also
            registered as an external output of ``model.net``.
        """
        hidden_t_prev = states[0]

        # Split input tensors to get inputs for each gate.
        input_t_reset, input_t_update, input_t_output = model.net.Split(
            [input_t],
            [
                self.scope('input_t_reset'),
                self.scope('input_t_update'),
                self.scope('input_t_output'),
            ],
            axis=2,
        )

        # Recurrent (hidden-state) projections for the reset and update gates.
        reset_gate_t = self._hidden_fc(model, hidden_t_prev, 'reset_gate_t')
        update_gate_t = self._hidden_fc(model, hidden_t_prev, 'update_gate_t')

        # Reset gate pre-activation. Written back to the same blob name
        # (in place), then squashed with a sigmoid for use in the output gate.
        reset_gate_t = model.net.Sum(
            [reset_gate_t, input_t_reset],
            self.scope('reset_gate_t'),
        )
        reset_gate_t_sigmoid = model.net.Sigmoid(
            reset_gate_t,
            self.scope('reset_gate_t_sigmoid'),
        )

        # `self.linear_before_reset = True` matches cudnn semantics
        if self.linear_before_reset:
            # r * FC(h_prev): reset applied after the linear transform.
            output_gate_fc = self._hidden_fc(
                model, hidden_t_prev, 'output_gate_t'
            )
            output_gate_t = model.net.Mul(
                [reset_gate_t_sigmoid, output_gate_fc],
                self.scope('output_gate_t_mul'),
            )
        else:
            # FC(r * h_prev): reset applied before the linear transform.
            modified_hidden_t_prev = model.net.Mul(
                [reset_gate_t_sigmoid, hidden_t_prev],
                self.scope('modified_hidden_t_prev'),
            )
            output_gate_t = self._hidden_fc(
                model, modified_hidden_t_prev, 'output_gate_t'
            )

        # Add input contributions to update and output gate.
        # We already (in-place) added input contributions to the reset gate.
        update_gate_t = model.net.Sum(
            [update_gate_t, input_t_update],
            self.scope('update_gate_t'),
        )
        output_gate_t = model.net.Sum(
            [output_gate_t, input_t_output],
            self.scope('output_gate_t_summed'),
        )

        # Join the three gate pre-activations along axis 2 for GRUUnit.
        gates_t, _gates_t_concat_dims = model.net.Concat(
            [reset_gate_t, update_gate_t, output_gate_t],
            [
                self.scope('gates_t'),
                self.scope('_gates_t_concat_dims'),
            ],
            axis=2,
        )

        # seq_lengths is an optional input of GRUUnit; its presence is also
        # signalled through the `sequence_lengths` argument below.
        if seq_lengths is not None:
            inputs = [hidden_t_prev, gates_t, seq_lengths, timestep]
        else:
            inputs = [hidden_t_prev, gates_t, timestep]

        hidden_t = model.net.GRUUnit(
            inputs,
            list(self.get_state_names()),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=(seq_lengths is not None),
        )
        model.net.AddExternalOutputs(hidden_t)
        return (hidden_t,)

    def prepare_input(self, model, input_blob):
        """Project the raw input to 3 * hidden_size features (one per gate).

        The FC runs over axis 2, and its output is later split into the
        reset/update/output parts by ``_apply``.
        """
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=3 * self.hidden_size,
            axis=2,
        )

    def get_state_names(self):
        """Return the scoped names of the recurrent states (hidden state only)."""
        return (self.scope('hidden_t'),)

    def get_output_dim(self):
        """Return the output feature size, which equals hidden_size."""
        return self.hidden_size
170

171

172
# Public GRU builder: functools.partial binds GRUCell as the first positional
# argument of rnn_cell._LSTM, so GRU(*args, **kw) calls
# rnn_cell._LSTM(GRUCell, *args, **kw).
# NOTE(review): this presumes _LSTM is a builder parameterized by the cell
# class despite its name -- confirm against rnn_cell.
GRU = functools.partial(rnn_cell._LSTM, GRUCell)
173

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.