# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines variational dropout recurrent layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
from state_of_sparsity.layers.utils import rnn_checks
from state_of_sparsity.layers.variational_dropout import common
from tensorflow.python.framework import ops  # pylint: disable=g-direct-tensorflow-import


# TODO(tgale): This RNN cell and the following LSTM cell share a large
# amount of common code. It would be best if we could extract a
# common base class, and then implement the recurrent functionality
# from scratch (as opposed to deriving from the non-variational
# recurrent cells).
class RNNCell(tf.nn.rnn_cell.BasicRNNCell):
  """RNN cell trained with variational dropout.

  This class implements an RNN cell trained with variational dropout following
  the technique from https://arxiv.org/abs/1708.00077.
  """

  def __init__(self,
               kernel_weights,
               bias_weights,
               num_units,
               training=True,
               threshold=3.0,
               eps=common.EPSILON,
               activation=None,
               name=None):
    R"""Initialize the variational RNN cell.

    Args:
      kernel_weights: 2-tuple of Tensors, where the first tensor is the \theta
        values and the second contains the log of the \sigma^2 values.
      bias_weights: The weight matrix to use for the biases.
      num_units: int, The number of units in the RNN cell.
      training: boolean, Whether the model is training or being evaluated.
      threshold: Weights with a log \alpha_{ij} value greater than this will
        be set to zero.
      eps: Small constant value to add to the term inside the square-root
        operation to avoid NaNs.
      activation: Activation function of the inner states. Defaults to `tanh`.
      name: String, the name of the layer.

    Raises:
      RuntimeError: If kernel_weights is not a 2-tuple of Tensors that have
        the same shape.
    """
    super(RNNCell, self).__init__(
        num_units=num_units,
        activation=activation,
        reuse=None,
        name=name,
        dtype=None)

    # Verify and save the weight matrices
    rnn_checks.check_rnn_weight_shapes(kernel_weights, bias_weights, num_units)
    self._variational_params = kernel_weights
    self._bias = bias_weights

    self._training = training
    self._threshold = threshold
    self._eps = eps

  def build(self, inputs_shape):
    """Initializes noise samples for the RNN.

    Args:
      inputs_shape: The shape of the input batch.

    Raises:
      RuntimeError: If the first and last dimensions of the input shape are
        not defined.
    """
    inputs_shape = inputs_shape.as_list()
    if inputs_shape[-1] is None:
      raise RuntimeError("Expected inputs.shape[-1] to be known, saw shape {}"
                         .format(inputs_shape))
    if inputs_shape[0] is None:
      raise RuntimeError("Expected inputs.shape[0] to be known, saw shape {}"
                         .format(inputs_shape))
    self._batch_size = inputs_shape[0]
    self._data_size = inputs_shape[-1]

    with ops.init_scope():
      if self._training:
        # Set up the random noise, which should be sampled once per iteration
        self._input_noise = tf.random_normal(
            [self._batch_size, self._num_units])
        self._hidden_noise = tf.random_normal(
            [self._num_units, self._num_units])
      else:
        # Mask the weights ahead of time for efficiency
        theta, log_sigma2 = self._variational_params
        log_alpha = common.compute_log_alpha(
            log_sigma2, theta, self._eps, value_limit=None)

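        # log_alpha is the per-weight log dropout rate (in the standard
        # variational dropout parameterization, log(sigma^2 / theta^2); see
        # common.compute_log_alpha). Weights whose log_alpha exceeds the
        # threshold are treated as pruned and zeroed out here.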
        weight_mask = tf.cast(tf.less(log_alpha, self._threshold), tf.float32)
        self._masked_weights = weight_mask * theta
    self.built = True

  def _compute_gate_inputs(
      self,
      inputs,
      state,
      input_parameters,
      hidden_parameters,
      input_noise,
      hidden_noise):
    """Compute a gate pre-activation with variational dropout.

    Args:
      inputs: The batch of input features for this timestep.
      state: The input hidden state from the last timestep.
      input_parameters: The posterior parameters for the input-to-hidden
        connections.
      hidden_parameters: The posterior parameters for the hidden-to-hidden
        connections.
      input_noise: Normally distributed random noise used for the
        sampling of pre-activations from the input-to-hidden weight
        posterior.
      hidden_noise: Normally distributed random noise used for the
        sampling of pre-activations from the hidden-to-hidden weight
        posterior.

    Returns:
      A tf.Tensor containing the computed pre-activations.
    """
    input_theta, input_log_sigma2 = input_parameters
    hidden_theta, hidden_log_sigma2 = hidden_parameters

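    # The input-to-hidden pre-activations are sampled with the local
    # reparameterization trick (mean x * theta, variance x^2 * sigma^2),
    # while the hidden-to-hidden weights are sampled directly and reused
    # across timesteps, following https://arxiv.org/abs/1708.00077.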
    # Compute the input-to-hidden connections
    input_mu = tf.matmul(inputs, input_theta)
    input_sigma = tf.sqrt(tf.matmul(
        tf.square(inputs),
        tf.exp(input_log_sigma2)) + self._eps)

    input_to_hidden = input_mu + input_sigma * input_noise

    # Compute the hidden-to-hidden connections
    hidden_sigma = tf.sqrt(tf.exp(hidden_log_sigma2) + self._eps)
    hidden_weights = hidden_theta + hidden_sigma * hidden_noise
    hidden_to_hidden = tf.matmul(state, hidden_weights)

    # Sum the results
    return tf.add(input_to_hidden, hidden_to_hidden)

  def _forward_train(self, inputs, state):
    # Split the input-to-hidden and hidden-to-hidden weights
    theta, log_sigma2 = self._variational_params
    input_theta, hidden_theta = tf.split(
        theta, [self._data_size, self._num_units])
    input_log_sigma2, hidden_log_sigma2 = tf.split(
        log_sigma2, [self._data_size, self._num_units])
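    # The split is along the rows: the first data_size rows are the
    # input-to-hidden block and the remaining num_units rows are the
    # hidden-to-hidden block, each with num_units columns.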

    gate_inputs = self._compute_gate_inputs(
        inputs,
        state,
        (input_theta, input_log_sigma2),
        (hidden_theta, hidden_log_sigma2),
        self._input_noise,
        self._hidden_noise)

    # Add bias, and apply the activation
    gate_inputs = tf.nn.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    return output, output

  def _forward_eval(self, inputs, state):
    # At eval time, we use the masked mean values for the input-to-hidden
    # and hidden-to-hidden weights.
    gate_inputs = tf.matmul(
        tf.concat([inputs, state], axis=1),
        self._masked_weights)

    gate_inputs = tf.nn.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    return output, output

  def call(self, inputs, state):
    if self._training:
      return self._forward_train(inputs, state)
    return self._forward_eval(inputs, state)


class LSTMCell(tf.nn.rnn_cell.LSTMCell):
  """LSTM cell trained with variational dropout.

  This class implements an LSTM cell trained with variational dropout following
  the technique from https://arxiv.org/abs/1708.00077.
  """

  def __init__(self,
               kernel_weights,
               bias_weights,
               num_units,
               training=True,
               threshold=3.0,
               eps=common.EPSILON,
               forget_bias=1.0,
               activation=None,
               name="lstm_cell"):
    R"""Initialize the LSTM cell.

    Args:
      kernel_weights: 2-tuple of Tensors, where the first tensor is the \theta
        values and the second contains the log of the \sigma^2 values.
      bias_weights: The weight matrix to use for the biases.
      num_units: int, The number of units in the LSTM cell.
      training: boolean, Whether the model is training or being evaluated.
      threshold: Weights with a log \alpha_{ij} value greater than this will
        be set to zero.
      eps: Small constant value to add to the term inside the square-root
        operation to avoid NaNs.
      forget_bias: float, The bias added to the forget gate pre-activations.
      activation: Activation function of the inner states. Defaults to `tanh`.
        It can also be a string naming a Keras activation function.
      name: String, the name of the layer.

    Raises:
      RuntimeError: If kernel_weights is not a 2-tuple of Tensors that have
        the same shape.
    """
    super(LSTMCell, self).__init__(
        num_units=num_units,
        forget_bias=forget_bias,
        state_is_tuple=True,
        activation=activation,
        name=name)

    # Verify and save the weight matrices
    rnn_checks.check_lstm_weight_shapes(kernel_weights, bias_weights, num_units)
    self._variational_params = kernel_weights
    self._bias = bias_weights

    self._training = training
    self._threshold = threshold
    self._eps = eps

  def build(self, inputs_shape):
    """Initializes noise samples for the LSTM.

    Args:
      inputs_shape: The shape of the input batch.

    Raises:
      RuntimeError: If the first and last dimensions of the input shape are
        not defined.
    """
    inputs_shape = inputs_shape.as_list()
    if inputs_shape[-1] is None:
      raise RuntimeError("Expected inputs.shape[-1] to be known, saw shape {}"
                         .format(inputs_shape))
    if inputs_shape[0] is None:
      raise RuntimeError("Expected inputs.shape[0] to be known, saw shape {}"
                         .format(inputs_shape))
    self._batch_size = inputs_shape[0]
    self._data_size = inputs_shape[-1]

    with ops.init_scope():
      if self._training:
        # Set up the random noise, which should be sampled once per iteration
        self._input_noise = tf.random_normal(
            [self._batch_size, 4 * self._num_units])
        self._hidden_noise = tf.random_normal(
            [self._num_units, 4 * self._num_units])
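        # The factor of 4 covers the i, j, f, and o gate pre-activations,
        # which are computed below with a single fused kernel.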
      else:
        # Mask the weights ahead of time for efficiency
        theta, log_sigma2 = self._variational_params
        log_alpha = common.compute_log_alpha(
            log_sigma2, theta, self._eps, value_limit=None)

        weight_mask = tf.cast(tf.less(log_alpha, self._threshold), tf.float32)
        self._masked_weights = weight_mask * theta
    self.built = True

  def _compute_gate_inputs(
      self,
      inputs,
      state,
      input_parameters,
      hidden_parameters,
      input_noise,
      hidden_noise):
    """Compute a gate pre-activation with variational dropout.

    Args:
      inputs: The batch of input features for this timestep.
      state: The input hidden state from the last timestep.
      input_parameters: The posterior parameters for the input-to-hidden
        connections.
      hidden_parameters: The posterior parameters for the hidden-to-hidden
        connections.
      input_noise: Normally distributed random noise used for the
        sampling of pre-activations from the input-to-hidden weight
        posterior.
      hidden_noise: Normally distributed random noise used for the
        sampling of pre-activations from the hidden-to-hidden weight
        posterior.

    Returns:
      A tf.Tensor containing the computed pre-activations.
    """
    input_theta, input_log_sigma2 = input_parameters
    hidden_theta, hidden_log_sigma2 = hidden_parameters

    # Compute the input-to-hidden connections
    input_mu = tf.matmul(inputs, input_theta)
    input_sigma = tf.sqrt(tf.matmul(
        tf.square(inputs),
        tf.exp(input_log_sigma2)) + self._eps)

    input_to_hidden = input_mu + input_sigma * input_noise

    # Compute the hidden-to-hidden connections
    hidden_sigma = tf.sqrt(tf.exp(hidden_log_sigma2) + self._eps)
    hidden_weights = hidden_theta + hidden_sigma * hidden_noise
    hidden_to_hidden = tf.matmul(state, hidden_weights)

    # Sum the results
    return tf.add(input_to_hidden, hidden_to_hidden)

  def _finish_lstm_computation(self, lstm_matrix, c_prev):
    i, j, f, o = tf.split(
        value=lstm_matrix,
        num_or_size_splits=4,
        axis=1)
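    # Gate ordering matches the parent tf.nn.rnn_cell.LSTMCell: i is the
    # input gate, j the new cell candidate, f the forget gate, and o the
    # output gate.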

    sigmoid = tf.sigmoid
    c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
         self._activation(j))
    m = sigmoid(o) * self._activation(c)

    new_state = tf.nn.rnn_cell.LSTMStateTuple(c, m)
    return m, new_state

  def _forward_train(self, inputs, state):
    # Split the input-to-hidden and hidden-to-hidden weights
    theta, log_sigma2 = self._variational_params
    input_theta, hidden_theta = tf.split(
        theta, [self._data_size, self._num_units])
    input_log_sigma2, hidden_log_sigma2 = tf.split(
        log_sigma2, [self._data_size, self._num_units])

    (c_prev, m_prev) = state
    lstm_matrix = self._compute_gate_inputs(
        inputs,
        m_prev,
        (input_theta, input_log_sigma2),
        (hidden_theta, hidden_log_sigma2),
        self._input_noise,
        self._hidden_noise)
    lstm_matrix = tf.nn.bias_add(lstm_matrix, self._bias)

    return self._finish_lstm_computation(lstm_matrix, c_prev)

  def _forward_eval(self, inputs, state):
    (c_prev, m_prev) = state

    lstm_matrix = tf.matmul(
        tf.concat([inputs, m_prev], axis=1),
        self._masked_weights)
    lstm_matrix = tf.nn.bias_add(lstm_matrix, self._bias)

    return self._finish_lstm_computation(lstm_matrix, c_prev)

  def call(self, inputs, state):
    if self._training:
      return self._forward_train(inputs, state)
    return self._forward_eval(inputs, state)
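

# ------------------------------------------------------------------------
# Minimal usage sketch (not part of the original library). The kernel and
# bias shapes below are inferred from the splits/matmuls above: theta and
# log_sigma2 of shape [input_size + num_units, 4 * num_units] and a bias of
# shape [4 * num_units] for LSTMCell. All names and sizes are illustrative.
# ------------------------------------------------------------------------
if __name__ == "__main__":
  tf.disable_v2_behavior()

  batch_size, time_steps, input_size, num_units = 8, 20, 32, 64

  # Posterior means (theta) and log-variances (log sigma^2) for the kernel.
  theta = tf.get_variable(
      "theta", shape=[input_size + num_units, 4 * num_units])
  log_sigma2 = tf.get_variable(
      "log_sigma2",
      shape=[input_size + num_units, 4 * num_units],
      initializer=tf.constant_initializer(-10.0))
  bias = tf.get_variable(
      "bias", shape=[4 * num_units], initializer=tf.zeros_initializer())

  cell = LSTMCell((theta, log_sigma2), bias, num_units, training=True)
  inputs = tf.random_normal([batch_size, time_steps, input_size])
  outputs, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs).shape)  # Expected: (8, 20, 64)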