google-research
221 lines · 7.3 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Defines recurrent network layers that train using l0 regularization."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import tensorflow.compat.v1 as tf
22
23from state_of_sparsity.layers.l0_regularization import common
24from state_of_sparsity.layers.utils import rnn_checks
25from tensorflow.python.framework import ops # pylint: disable=g-direct-tensorflow-import
26
27
class RNNCell(tf.nn.rnn_cell.BasicRNNCell):
  """Basic RNN cell whose kernel is sparsified with l0 regularization.

  Implements the hard-concrete reparameterization from
  https://arxiv.org/abs/1712.01312: the dense kernel is multiplied
  elementwise by a hard-concrete gate (sampled during training, mean at
  evaluation) before the recurrent matmul.
  """

  def __init__(
      self,
      kernel_weights,
      bias_weights,
      num_units,
      beta=common.BETA,
      gamma=common.GAMMA,
      zeta=common.ZETA,
      training=True,
      eps=common.EPSILON,
      activation=None,
      name=None):
    R"""Initialize the RNN cell.

    Args:
      kernel_weights: 2-tuple of Tensors `(theta, log_alpha)`: the unscaled
        kernel values and the log-alpha parameters of the hard-concrete
        distribution. Both tensors must have the same shape.
      bias_weights: Tensor of biases added after the recurrent matmul.
      num_units: int, the number of units in the RNN cell.
      beta: "Temperature" of the hard-concrete distribution. Defaults to
        2/3 from the above paper.
      gamma: Lower bound of the stretched distribution. Defaults to -0.1
        from the above paper.
      zeta: Upper bound of the stretched distribution. Defaults to 1.1
        from the above paper.
      training: boolean; sample the gates if True, use their mean otherwise.
      eps: Small constant added inside the square-root operation to
        avoid NaNs.
      activation: Activation function of the inner states. Defaults to
        `tanh`.
      name: String, the name of the layer.

    Raises:
      RuntimeError: If kernel_weights is not a 2-tuple of Tensors with
        identical shapes.
    """
    super(RNNCell, self).__init__(
        num_units=num_units,
        activation=activation,
        reuse=None,
        name=name,
        dtype=None)

    # Validate the supplied variables against the shapes a basic RNN
    # cell with `num_units` units requires, then stash them.
    rnn_checks.check_rnn_weight_shapes(kernel_weights, bias_weights, num_units)
    self._weight_parameters = kernel_weights
    self._bias = bias_weights

    # Hard-concrete distribution hyperparameters.
    self._beta = beta
    self._gamma = gamma
    self._zeta = zeta

    self._training = training
    self._eps = eps

  def build(self, _):
    """Materializes the gated kernel; the input shape argument is ignored."""
    with ops.init_scope():
      theta, log_alpha = self._weight_parameters
      if self._training:
        # Training: draw a stochastic hard-concrete gate per weight.
        gate = common.hard_concrete_sample(
            log_alpha,
            self._beta,
            self._gamma,
            self._zeta,
            self._eps)
      else:
        # Evaluation: use the deterministic mean of the gate distribution.
        gate = common.hard_concrete_mean(
            log_alpha,
            self._gamma,
            self._zeta)
      self._weights = gate * theta
    self.built = True

  def call(self, inputs, state):
    """One step: output = activation([inputs, state] @ W + b)."""
    concatenated = tf.concat([inputs, state], axis=1)
    pre_activation = tf.nn.bias_add(
        tf.matmul(concatenated, self._weights),
        self._bias)
    output = self._activation(pre_activation)
    # Basic RNN: the output doubles as the next state.
    return output, output
116
117
class LSTMCell(tf.nn.rnn_cell.LSTMCell):
  """LSTM cell whose kernel is sparsified with l0 regularization.

  Implements the hard-concrete reparameterization from
  https://arxiv.org/abs/1712.01312: the dense kernel is multiplied
  elementwise by a hard-concrete gate (sampled during training, mean at
  evaluation) before the recurrent matmul.
  """

  def __init__(
      self,
      kernel_weights,
      bias_weights,
      num_units,
      beta=common.BETA,
      gamma=common.GAMMA,
      zeta=common.ZETA,
      training=True,
      eps=common.EPSILON,
      forget_bias=1.0,
      activation=None,
      name="lstm_cell"):
    R"""Initialize the LSTM cell.

    Args:
      kernel_weights: 2-tuple of Tensors `(theta, log_alpha)`: the unscaled
        kernel values and the log-alpha parameters of the hard-concrete
        distribution. Both tensors must have the same shape.
      bias_weights: Tensor of biases added after the recurrent matmul.
      num_units: int, the number of units in the LSTM cell.
      beta: "Temperature" of the hard-concrete distribution. Defaults to
        2/3 from the above paper.
      gamma: Lower bound of the stretched distribution. Defaults to -0.1
        from the above paper.
      zeta: Upper bound of the stretched distribution. Defaults to 1.1
        from the above paper.
      training: boolean; sample the gates if True, use their mean otherwise.
      eps: Small constant added inside the square-root operation to
        avoid NaNs.
      forget_bias: float, the bias added to forget gates.
      activation: Activation function of the inner states. Defaults to
        `tanh`. It could also be a string within Keras activation
        function names.
      name: String, the name of the layer.

    Raises:
      RuntimeError: If kernel_weights is not a 2-tuple of Tensors with
        identical shapes.
    """
    super(LSTMCell, self).__init__(
        num_units=num_units,
        forget_bias=forget_bias,
        state_is_tuple=True,
        activation=activation,
        name=name)

    # Validate the supplied variables against the shapes an LSTM cell
    # with `num_units` units requires, then stash them.
    rnn_checks.check_lstm_weight_shapes(kernel_weights, bias_weights, num_units)
    self._weight_parameters = kernel_weights
    self._bias = bias_weights

    # Hard-concrete distribution hyperparameters.
    self._beta = beta
    self._gamma = gamma
    self._zeta = zeta

    self._training = training
    self._eps = eps

  def build(self, _):
    """Materializes the gated kernel; the input shape argument is ignored."""
    with ops.init_scope():
      theta, log_alpha = self._weight_parameters
      if self._training:
        # Training: draw a stochastic hard-concrete gate per weight.
        gate = common.hard_concrete_sample(
            log_alpha,
            self._beta,
            self._gamma,
            self._zeta,
            self._eps)
      else:
        # Evaluation: use the deterministic mean of the gate distribution.
        gate = common.hard_concrete_mean(
            log_alpha,
            self._gamma,
            self._zeta)
      self._weights = gate * theta
    self.built = True

  def call(self, inputs, state):
    """One LSTM step with the gated kernel; returns (output, new_state)."""
    c_prev, m_prev = state

    # All four gate pre-activations in a single matmul, then split.
    lstm_matrix = tf.nn.bias_add(
        tf.matmul(tf.concat([inputs, m_prev], axis=1), self._weights),
        self._bias)
    # i: input gate, j: candidate input, f: forget gate, o: output gate.
    i, j, f, o = tf.split(
        value=lstm_matrix,
        num_or_size_splits=4,
        axis=1)

    # Standard LSTM update; `forget_bias` nudges the forget gate open.
    c = (tf.sigmoid(f + self._forget_bias) * c_prev
         + tf.sigmoid(i) * self._activation(j))
    m = tf.sigmoid(o) * self._activation(c)

    return m, tf.nn.rnn_cell.LSTMStateTuple(c, m)
222