# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines recurrent network layers that train using l0 regularization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf

from state_of_sparsity.layers.l0_regularization import common
from state_of_sparsity.layers.utils import rnn_checks
from tensorflow.python.framework import ops  # pylint: disable=g-direct-tensorflow-import
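

# For reference, an illustrative sketch of the hard concrete gate of Louizos
# et al. (https://arxiv.org/abs/1712.01312) that `common.hard_concrete_sample`
# implements. The canonical version lives in `common`; the eps placement below
# is an assumption made for numerical stability, not the library's code.
def _hard_concrete_gate_sketch(log_alpha, beta, gamma, zeta, eps):
  """Samples z = clip(s * (zeta - gamma) + gamma, 0, 1).

  Here s = sigmoid((log u - log(1 - u) + log_alpha) / beta) with
  u ~ Uniform(0, 1), so `beta` sets the temperature and `gamma`/`zeta`
  stretch the sample before it is hard-rectified into [0, 1].
  """
  u = tf.random_uniform(tf.shape(log_alpha), minval=0.0, maxval=1.0)
  s = tf.sigmoid((tf.log(u + eps) - tf.log(1.0 - u + eps) + log_alpha) / beta)
  return tf.clip_by_value(s * (zeta - gamma) + gamma, 0.0, 1.0)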


class RNNCell(tf.nn.rnn_cell.BasicRNNCell):
  """RNN cell trained with l0 regularization.

  This class implements an RNN cell trained with l0 regularization following
  the technique from https://arxiv.org/abs/1712.01312.
  """

  def __init__(
      self,
      kernel_weights,
      bias_weights,
      num_units,
      beta=common.BETA,
      gamma=common.GAMMA,
      zeta=common.ZETA,
      training=True,
      eps=common.EPSILON,
      activation=None,
      name=None):
    R"""Initialize the RNN cell.

    Args:
      kernel_weights: 2-tuple of Tensors, where the first tensor holds the
        unscaled weight values and the second holds the log of the alpha
        values for the hard concrete distribution.
      bias_weights: The weights to use for the biases.
      num_units: int, The number of units in the RNN cell.
      beta: The beta parameter, which controls the "temperature" of
        the distribution. Defaults to 2/3 from the above paper.
      gamma: The gamma parameter, which controls the lower bound of the
        stretched distribution. Defaults to -0.1 from the above paper.
      zeta: The zeta parameter, which controls the upper bound of the
        stretched distribution. Defaults to 1.1 from the above paper.
      training: boolean, whether the model is training or being evaluated.
      eps: Small constant value to add to the term inside the square-root
        operation to avoid NaNs.
      activation: Activation function of the inner states. Defaults to `tanh`.
      name: String, the name of the layer.

    Raises:
      RuntimeError: If the input kernel_weights is not a 2-tuple of Tensors
        that have the same shape.
    """
    super(RNNCell, self).__init__(
        num_units=num_units,
        activation=activation,
        reuse=None,
        name=name,
        dtype=None)

    # Verify and save the weight matrices.
    rnn_checks.check_rnn_weight_shapes(kernel_weights, bias_weights, num_units)
    self._weight_parameters = kernel_weights
    self._bias = bias_weights

    self._beta = beta
    self._gamma = gamma
    self._zeta = zeta

    self._training = training
    self._eps = eps

  def build(self, _):
    """Initializes the weights for the RNN."""
    with ops.init_scope():
      theta, log_alpha = self._weight_parameters
      if self._training:
        # At training time, sample a hard concrete gate for each weight.
        weight_noise = common.hard_concrete_sample(
            log_alpha,
            self._beta,
            self._gamma,
            self._zeta,
            self._eps)
      else:
        # At evaluation time, use the mean of each gate so the effective
        # weights are deterministic.
        weight_noise = common.hard_concrete_mean(
            log_alpha,
            self._gamma,
            self._zeta)
      self._weights = weight_noise * theta
    self.built = True

  def call(self, inputs, state):
    # output = activation([inputs, state] @ kernel + bias); the output also
    # serves as the next state for a basic RNN cell.
    gate_inputs = tf.matmul(
        tf.concat([inputs, state], axis=1),
        self._weights)
    gate_inputs = tf.nn.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    return output, output
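

# A minimal usage sketch for `RNNCell`, illustrative only. The shapes follow
# from `call` above: the kernel parameters are
# [input_size + num_units, num_units] and the bias is [num_units]. All names,
# sizes, and initializers here are hypothetical, not a prescribed API.
def _rnn_cell_usage_sketch():
  input_size, num_units = 32, 128
  theta = tf.get_variable(
      "theta", shape=[input_size + num_units, num_units])
  log_alpha = tf.get_variable(
      "log_alpha",
      shape=[input_size + num_units, num_units],
      initializer=tf.constant_initializer(2.0))
  bias = tf.get_variable(
      "bias", shape=[num_units], initializer=tf.zeros_initializer())

  cell = RNNCell((theta, log_alpha), bias, num_units, training=True)
  # Unroll the cell over a [batch, time, features] input.
  inputs = tf.placeholder(tf.float32, shape=[8, 20, input_size])
  outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
  return outputs, final_state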


class LSTMCell(tf.nn.rnn_cell.LSTMCell):
  """LSTM cell trained with l0 regularization.

  This class implements an LSTM cell trained with l0 regularization following
  the technique from https://arxiv.org/abs/1712.01312.
  """

  def __init__(
      self,
      kernel_weights,
      bias_weights,
      num_units,
      beta=common.BETA,
      gamma=common.GAMMA,
      zeta=common.ZETA,
      training=True,
      eps=common.EPSILON,
      forget_bias=1.0,
      activation=None,
      name="lstm_cell"):
    R"""Initialize the LSTM cell.

    Args:
      kernel_weights: 2-tuple of Tensors, where the first tensor holds the
        unscaled weight values and the second holds the log of the alpha
        values for the hard concrete distribution.
      bias_weights: The weights to use for the biases.
      num_units: int, The number of units in the LSTM cell.
      beta: The beta parameter, which controls the "temperature" of
        the distribution. Defaults to 2/3 from the above paper.
      gamma: The gamma parameter, which controls the lower bound of the
        stretched distribution. Defaults to -0.1 from the above paper.
      zeta: The zeta parameter, which controls the upper bound of the
        stretched distribution. Defaults to 1.1 from the above paper.
      training: boolean, whether the model is training or being evaluated.
      eps: Small constant value to add to the term inside the square-root
        operation to avoid NaNs.
      forget_bias: float, The bias added to the forget gate.
      activation: Activation function of the inner states. Defaults to `tanh`.
        It can also be a string matching the name of a Keras activation
        function.
      name: String, the name of the layer.

    Raises:
      RuntimeError: If the input kernel_weights is not a 2-tuple of Tensors
        that have the same shape.
    """
    super(LSTMCell, self).__init__(
        num_units=num_units,
        forget_bias=forget_bias,
        state_is_tuple=True,
        activation=activation,
        name=name)

    # Verify and save the weight matrices.
    rnn_checks.check_lstm_weight_shapes(kernel_weights, bias_weights, num_units)
    self._weight_parameters = kernel_weights
    self._bias = bias_weights

    self._beta = beta
    self._gamma = gamma
    self._zeta = zeta

    self._training = training
    self._eps = eps

  def build(self, _):
    """Initializes the weights for the LSTM."""
    with ops.init_scope():
      theta, log_alpha = self._weight_parameters
      if self._training:
        # At training time, sample a hard concrete gate for each weight.
        weight_noise = common.hard_concrete_sample(
            log_alpha,
            self._beta,
            self._gamma,
            self._zeta,
            self._eps)
      else:
        # At evaluation time, use the mean of each gate so the effective
        # weights are deterministic.
        weight_noise = common.hard_concrete_mean(
            log_alpha,
            self._gamma,
            self._zeta)
      self._weights = weight_noise * theta
    self.built = True

  def call(self, inputs, state):
    (c_prev, m_prev) = state
    # Compute all four gate pre-activations in one matmul.
    lstm_matrix = tf.matmul(
        tf.concat([inputs, m_prev], axis=1),
        self._weights)
    lstm_matrix = tf.nn.bias_add(lstm_matrix, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = tf.split(
        value=lstm_matrix,
        num_or_size_splits=4,
        axis=1)

    sigmoid = tf.sigmoid
    # New cell state: gated previous state plus gated candidate input.
    c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
         self._activation(j))
    # New hidden state: output gate applied to the activated cell state.
    m = sigmoid(o) * self._activation(c)

    new_state = tf.nn.rnn_cell.LSTMStateTuple(c, m)
    return m, new_state
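

# A matching sketch for `LSTMCell`, illustrative only. From `call` above, the
# kernel parameters span all four gates, so their shape is
# [input_size + num_units, 4 * num_units] and the bias is [4 * num_units].
def _lstm_cell_usage_sketch():
  input_size, num_units = 32, 128
  kernel_shape = [input_size + num_units, 4 * num_units]
  theta = tf.get_variable("lstm_theta", shape=kernel_shape)
  log_alpha = tf.get_variable(
      "lstm_log_alpha",
      shape=kernel_shape,
      initializer=tf.constant_initializer(2.0))
  bias = tf.get_variable(
      "lstm_bias",
      shape=[4 * num_units],
      initializer=tf.zeros_initializer())

  # `training=False` uses the deterministic mean of each hard concrete gate.
  cell = LSTMCell((theta, log_alpha), bias, num_units, training=False)
  inputs = tf.placeholder(tf.float32, shape=[8, 20, input_size])
  outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
  return outputs, final_state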