# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Defines variational dropout recurrent layers."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import tensorflow.compat.v1 as tf
22from state_of_sparsity.layers.utils import rnn_checks
23from state_of_sparsity.layers.variational_dropout import common
24from tensorflow.python.framework import ops # pylint: disable=g-direct-tensorflow-import
25
26
# TODO(tgale): This RNN cell and the following LSTM cell share a large
# amount of common code. It would be best if we could extract a common
# base class and then implement the recurrent functionality from scratch
# (as opposed to deriving from the non-variational recurrent cells).
class RNNCell(tf.nn.rnn_cell.BasicRNNCell):
  """RNN cell trained with variational dropout.

  This class implements an RNN cell trained with variational dropout
  following the technique from https://arxiv.org/abs/1708.00077.
  """

  def __init__(self,
               kernel_weights,
               bias_weights,
               num_units,
               training=True,
               threshold=3.0,
               eps=common.EPSILON,
               activation=None,
               name=None):
    R"""Initialize the variational RNN cell.

    Args:
      kernel_weights: 2-tuple of Tensors, where the first tensor is the \theta
        values and the second contains the log of the \sigma^2 values.
      bias_weights: The weight matrix to use for the biases.
      num_units: int, The number of units in the RNN cell.
      training: boolean, Whether the model is training or being evaluated.
      threshold: Weights with a log \alpha_{ij} value greater than this will
        be set to zero.
      eps: Small constant value to add to the term inside the square-root
        operation to avoid NaNs.
      activation: Activation function of the inner states. Defaults to `tanh`.
      name: String, the name of the layer.

    Raises:
      RuntimeError: If the input variational_params is not a 2-tuple of
        Tensors that have the same shape.
    """
    super(RNNCell, self).__init__(
        num_units=num_units,
        activation=activation,
        reuse=None,
        name=name,
        dtype=None)

    # Verify and save the weight matrices.
    rnn_checks.check_rnn_weight_shapes(kernel_weights, bias_weights, num_units)
    self._variational_params = kernel_weights
    self._bias = bias_weights

    self._training = training
    self._threshold = threshold
    self._eps = eps

  def build(self, inputs_shape):
    """Initializes noise samples for the RNN.

    Args:
      inputs_shape: The shape of the input batch.

    Raises:
      RuntimeError: If the first and last dimensions of the input shape are
        not defined.
    """
    inputs_shape = inputs_shape.as_list()
    if inputs_shape[-1] is None:
      raise RuntimeError("Expected inputs.shape[-1] to be known, saw shape {}"
                         .format(inputs_shape))
    if inputs_shape[0] is None:
      raise RuntimeError("Expected inputs.shape[0] to be known, saw shape {}"
                         .format(inputs_shape))
    self._batch_size = inputs_shape[0]
    self._data_size = inputs_shape[-1]

    with ops.init_scope():
      if self._training:
        # Set up the random noise, which should be sampled once per iteration.
        self._input_noise = tf.random_normal(
            [self._batch_size, self._num_units])
        self._hidden_noise = tf.random_normal(
            [self._num_units, self._num_units])
      else:
        # Mask the weights ahead of time for efficiency.
        theta, log_sigma2 = self._variational_params
        log_alpha = common.compute_log_alpha(
            log_sigma2, theta, self._eps, value_limit=None)

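        # log \alpha_{ij} grows with the noise-to-signal ratio
        # \sigma^2_{ij} / \theta^2_{ij} (see common.compute_log_alpha), so
        # thresholding it zeros out the weights whose posterior is mostly
        # noise; this is what yields a sparse kernel at eval time.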
        weight_mask = tf.cast(tf.less(log_alpha, self._threshold), tf.float32)
        self._masked_weights = weight_mask * theta
    self.built = True

  def _compute_gate_inputs(
      self,
      inputs,
      state,
      input_parameters,
      hidden_parameters,
      input_noise,
      hidden_noise):
    """Compute a gate pre-activation with variational dropout.

    Args:
      inputs: The batch of input features for the current timestep.
      state: The input hidden state from the last timestep.
      input_parameters: The posterior parameters for the input-to-hidden
        connections.
      hidden_parameters: The posterior parameters for the hidden-to-hidden
        connections.
      input_noise: Normally distributed random noise used for sampling
        pre-activations from the input-to-hidden weight posterior.
      hidden_noise: Normally distributed random noise used for sampling
        pre-activations from the hidden-to-hidden weight posterior.

    Returns:
      A tf.Tensor containing the computed pre-activations.
    """
    input_theta, input_log_sigma2 = input_parameters
    hidden_theta, hidden_log_sigma2 = hidden_parameters

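    # The input-to-hidden pre-activations use the local reparameterization
    # trick: rather than sampling W ~ N(\theta, \sigma^2) and computing
    # inputs * W, the pre-activation is sampled directly from its induced
    # distribution N(inputs * \theta, inputs^2 * \sigma^2).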
    # Compute the input-to-hidden connections.
    input_mu = tf.matmul(inputs, input_theta)
    input_sigma = tf.sqrt(tf.matmul(
        tf.square(inputs),
        tf.exp(input_log_sigma2)) + self._eps)

    input_to_hidden = input_mu + input_sigma * input_noise

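    # The hidden-to-hidden weights are sampled directly, with noise drawn
    # once in build() so that the same weight sample is shared across all
    # timesteps of the sequence.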
    # Compute the hidden-to-hidden connections.
    hidden_sigma = tf.sqrt(tf.exp(hidden_log_sigma2) + self._eps)
    hidden_weights = hidden_theta + hidden_sigma * hidden_noise
    hidden_to_hidden = tf.matmul(state, hidden_weights)

    # Sum the results.
    return tf.add(input_to_hidden, hidden_to_hidden)

  def _forward_train(self, inputs, state):
    # Split the input-to-hidden and hidden-to-hidden weights.
    theta, log_sigma2 = self._variational_params
    input_theta, hidden_theta = tf.split(
        theta, [self._data_size, self._num_units])
    input_log_sigma2, hidden_log_sigma2 = tf.split(
        log_sigma2, [self._data_size, self._num_units])

    gate_inputs = self._compute_gate_inputs(
        inputs,
        state,
        (input_theta, input_log_sigma2),
        (hidden_theta, hidden_log_sigma2),
        self._input_noise,
        self._hidden_noise)

    # Add the bias, and apply the activation.
    gate_inputs = tf.nn.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    return output, output

  def _forward_eval(self, inputs, state):
    # At eval time, we use the masked mean values for the input-to-hidden
    # and hidden-to-hidden weights.
    gate_inputs = tf.matmul(
        tf.concat([inputs, state], axis=1),
        self._masked_weights)

    gate_inputs = tf.nn.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    return output, output

  def call(self, inputs, state):
    if self._training:
      return self._forward_train(inputs, state)
    return self._forward_eval(inputs, state)

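# A minimal usage sketch for RNNCell, assuming `theta` and `log_sigma2` are
# variables of shape [data_size + num_units, num_units] and `bias` has shape
# [num_units] (see rnn_checks for the exact requirements). Note that build()
# requires a statically known batch size, so `x` below must have a fixed
# batch dimension:
#
#   cell = RNNCell((theta, log_sigma2), bias, num_units=num_units)
#   outputs, state = tf.nn.dynamic_rnn(
#       cell, x, dtype=tf.float32)  # x: [batch_size, time, data_size]

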
class LSTMCell(tf.nn.rnn_cell.LSTMCell):
  """LSTM cell trained with variational dropout.

  This class implements an LSTM cell trained with variational dropout
  following the technique from https://arxiv.org/abs/1708.00077.
  """

  def __init__(self,
               kernel_weights,
               bias_weights,
               num_units,
               training=True,
               threshold=3.0,
               eps=common.EPSILON,
               forget_bias=1.0,
               activation=None,
               name="lstm_cell"):
    R"""Initialize the LSTM cell.

    Args:
      kernel_weights: 2-tuple of Tensors, where the first tensor is the \theta
        values and the second contains the log of the \sigma^2 values.
      bias_weights: The weight matrix to use for the biases.
      num_units: int, The number of units in the LSTM cell.
      training: boolean, Whether the model is training or being evaluated.
      threshold: Weights with a log \alpha_{ij} value greater than this will
        be set to zero.
      eps: Small constant value to add to the term inside the square-root
        operation to avoid NaNs.
      forget_bias: float, The bias added to forget gates (see above).
      activation: Activation function of the inner states. Defaults to `tanh`.
        It can also be a string naming a Keras activation function.
      name: String, the name of the layer.

    Raises:
      RuntimeError: If the input variational_params is not a 2-tuple of
        Tensors that have the same shape.
    """
    super(LSTMCell, self).__init__(
        num_units=num_units,
        forget_bias=forget_bias,
        state_is_tuple=True,
        activation=activation,
        name=name)

    # Verify and save the weight matrices.
    rnn_checks.check_lstm_weight_shapes(
        kernel_weights, bias_weights, num_units)
    self._variational_params = kernel_weights
    self._bias = bias_weights

    self._training = training
    self._threshold = threshold
    self._eps = eps

  def build(self, inputs_shape):
    """Initializes noise samples for the LSTM.

    Args:
      inputs_shape: The shape of the input batch.

    Raises:
      RuntimeError: If the first and last dimensions of the input shape are
        not defined.
    """
    inputs_shape = inputs_shape.as_list()
    if inputs_shape[-1] is None:
      raise RuntimeError("Expected inputs.shape[-1] to be known, saw shape {}"
                         .format(inputs_shape))
    if inputs_shape[0] is None:
      raise RuntimeError("Expected inputs.shape[0] to be known, saw shape {}"
                         .format(inputs_shape))
    self._batch_size = inputs_shape[0]
    self._data_size = inputs_shape[-1]

    with ops.init_scope():
      if self._training:
        # Set up the random noise, which should be sampled once per iteration.
        self._input_noise = tf.random_normal(
            [self._batch_size, 4 * self._num_units])
        self._hidden_noise = tf.random_normal(
            [self._num_units, 4 * self._num_units])
      else:
        # Mask the weights ahead of time for efficiency.
        theta, log_sigma2 = self._variational_params
        log_alpha = common.compute_log_alpha(
            log_sigma2, theta, self._eps, value_limit=None)

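        # As in RNNCell.build() above, weights whose log \alpha
        # (noise-to-signal ratio) exceeds the threshold are zeroed out,
        # sparsifying the kernel used at eval time.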
        weight_mask = tf.cast(tf.less(log_alpha, self._threshold), tf.float32)
        self._masked_weights = weight_mask * theta
    self.built = True

  def _compute_gate_inputs(
      self,
      inputs,
      state,
      input_parameters,
      hidden_parameters,
      input_noise,
      hidden_noise):
    """Compute a gate pre-activation with variational dropout.

    Args:
      inputs: The batch of input features for the current timestep.
      state: The input hidden state from the last timestep.
      input_parameters: The posterior parameters for the input-to-hidden
        connections.
      hidden_parameters: The posterior parameters for the hidden-to-hidden
        connections.
      input_noise: Normally distributed random noise used for sampling
        pre-activations from the input-to-hidden weight posterior.
      hidden_noise: Normally distributed random noise used for sampling
        pre-activations from the hidden-to-hidden weight posterior.

    Returns:
      A tf.Tensor containing the computed pre-activations.
    """
    input_theta, input_log_sigma2 = input_parameters
    hidden_theta, hidden_log_sigma2 = hidden_parameters

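    # As in RNNCell, the input-to-hidden pre-activations are sampled with the
    # local reparameterization trick: draw from N(inputs * \theta,
    # inputs^2 * \sigma^2) rather than sampling the weights themselves.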
    # Compute the input-to-hidden connections.
    input_mu = tf.matmul(inputs, input_theta)
    input_sigma = tf.sqrt(tf.matmul(
        tf.square(inputs),
        tf.exp(input_log_sigma2)) + self._eps)

    input_to_hidden = input_mu + input_sigma * input_noise

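    # The hidden-to-hidden weights are again sampled directly, with noise
    # drawn once in build() and shared across all timesteps.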
    # Compute the hidden-to-hidden connections.
    hidden_sigma = tf.sqrt(tf.exp(hidden_log_sigma2) + self._eps)
    hidden_weights = hidden_theta + hidden_sigma * hidden_noise
    hidden_to_hidden = tf.matmul(state, hidden_weights)

    # Sum the results.
    return tf.add(input_to_hidden, hidden_to_hidden)

  def _finish_lstm_computation(self, lstm_matrix, c_prev):
    i, j, f, o = tf.split(
        value=lstm_matrix,
        num_or_size_splits=4,
        axis=1)

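    # i: input gate, j: new cell candidate, f: forget gate, o: output gate;
    # this matches the kernel layout of the parent tf.nn.rnn_cell.LSTMCell.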
    sigmoid = tf.sigmoid
    c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
         self._activation(j))
    m = sigmoid(o) * self._activation(c)

    new_state = tf.nn.rnn_cell.LSTMStateTuple(c, m)
    return m, new_state

  def _forward_train(self, inputs, state):
    # Split the input-to-hidden and hidden-to-hidden weights.
    theta, log_sigma2 = self._variational_params
    input_theta, hidden_theta = tf.split(
        theta, [self._data_size, self._num_units])
    input_log_sigma2, hidden_log_sigma2 = tf.split(
        log_sigma2, [self._data_size, self._num_units])

    (c_prev, m_prev) = state
    lstm_matrix = self._compute_gate_inputs(
        inputs,
        m_prev,
        (input_theta, input_log_sigma2),
        (hidden_theta, hidden_log_sigma2),
        self._input_noise,
        self._hidden_noise)
    lstm_matrix = tf.nn.bias_add(lstm_matrix, self._bias)

    return self._finish_lstm_computation(lstm_matrix, c_prev)

  def _forward_eval(self, inputs, state):
    (c_prev, m_prev) = state

    lstm_matrix = tf.matmul(
        tf.concat([inputs, m_prev], axis=1),
        self._masked_weights)
    lstm_matrix = tf.nn.bias_add(lstm_matrix, self._bias)

    return self._finish_lstm_computation(lstm_matrix, c_prev)

  def call(self, inputs, state):
    if self._training:
      return self._forward_train(inputs, state)
    return self._forward_eval(inputs, state)
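

# A minimal usage sketch for LSTMCell, assuming `theta` and `log_sigma2` are
# variables of shape [data_size + num_units, 4 * num_units] and `bias` has
# shape [4 * num_units]. As with RNNCell, build() requires `x` to have a
# statically known batch dimension:
#
#   cell = LSTMCell((theta, log_sigma2), bias, num_units=num_units)
#   outputs, state = tf.nn.dynamic_rnn(
#       cell, x, dtype=tf.float32)  # x: [batch_size, time, data_size]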