google-research
184 lines · 4.9 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Tests for variational dropout recurrent cells."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import absl.testing.parameterized as parameterized
22import tensorflow.compat.v1 as tf
23
24import state_of_sparsity.layers.variational_dropout as vd
25
26
# Shape configurations fed to every test method of the parameterized test
# classes in this module. Each tuple is
# (batch_size, seq_length, num_units, data_size).
RNN_TEST_PARAMETERS = [
    (32, 25, 100, 33),
]
29
30
@parameterized.parameters(RNN_TEST_PARAMETERS)
class RNNCellTest(vd.test_base.RNNTestCase):
  """Tests for vd.rnn.RNNCell, with tf.nn.rnn_cell.BasicRNNCell as reference.

  The class-level decorator feeds every test method the tuple
  (batch_size, seq_length, num_units, data_size) from RNN_TEST_PARAMETERS.
  The assert* / set_* / fix_random_seeds helpers come from
  vd.test_base.RNNTestCase, which is not visible in this file.
  """

  @staticmethod
  def _rnn_shape_args(batch_size, seq_length, num_units, data_size):
    """Builds the trailing positional arguments shared by every assertion.

    Args:
      batch_size: Number of sequences per batch.
      seq_length: Number of time steps per sequence.
      num_units: Hidden size of the cell.
      data_size: Feature size of each input step.

    Returns:
      The tuple (num_units, input_shape, weight_shape, bias_shape), in the
      positional order the RNNTestCase assertions are called with. The three
      shapes are presumably the input tensor shape and the expected kernel
      and bias shapes; how they are consumed is defined in test_base (TODO
      confirm).
    """
    input_shape = [batch_size, seq_length, data_size]
    # Kernel rows cover the input features plus the recurrent state.
    weight_shape = [data_size + num_units, num_units]
    bias_shape = [num_units]
    return num_units, input_shape, weight_shape, bias_shape

  def testRNNCell_Train(self, batch_size, seq_length, num_units, data_size):
    # NOTE(review): set_no_epsilon presumably disables the sampled noise so
    # the result can be compared against the dense reference cell — confirm
    # in test_base.
    noiseless_cell = self.set_no_epsilon(vd.rnn.RNNCell)
    self.assertSameResult(
        self.set_training(noiseless_cell),
        tf.nn.rnn_cell.BasicRNNCell,
        *self._rnn_shape_args(batch_size, seq_length, num_units, data_size))

  def testRNNCell_Eval(self, batch_size, seq_length, num_units, data_size):
    noiseless_cell = self.set_no_epsilon(vd.rnn.RNNCell)
    self.assertSameResult(
        self.set_evaluation(noiseless_cell),
        tf.nn.rnn_cell.BasicRNNCell,
        *self._rnn_shape_args(batch_size, seq_length, num_units, data_size))

  def testRNNCell_SameNoiseForAllTimeSteps(
      self, batch_size, seq_length, num_units, data_size):
    # Seeds are fixed before the cell is configured so the check is
    # reproducible.
    self.fix_random_seeds()
    training_cell = self.set_training(vd.rnn.RNNCell)
    self.assertSameNoiseForAllTimesteps(
        training_cell,
        *self._rnn_shape_args(batch_size, seq_length, num_units, data_size))

  def testRNNCell_DifferentNoiseAcrossBatches(
      self, batch_size, seq_length, num_units, data_size):
    self.fix_random_seeds()
    training_cell = self.set_training(vd.rnn.RNNCell)
    self.assertDifferentNoiseAcrossBatches(
        training_cell,
        *self._rnn_shape_args(batch_size, seq_length, num_units, data_size))

  def testRNNCell_DeterministicEval(
      self, batch_size, seq_length, num_units, data_size):
    self.fix_random_seeds()
    eval_cell = self.set_evaluation(vd.rnn.RNNCell)
    self.assertDeterministic(
        eval_cell,
        *self._rnn_shape_args(batch_size, seq_length, num_units, data_size))
105
106
@parameterized.parameters(RNN_TEST_PARAMETERS)
class LSTMCellTest(vd.test_base.RNNTestCase):
  """Tests for vd.rnn.LSTMCell, with tf.nn.rnn_cell.LSTMCell as reference.

  The class-level decorator feeds every test method the tuple
  (batch_size, seq_length, num_units, data_size) from RNN_TEST_PARAMETERS.
  The assert* / set_* / fix_random_seeds helpers come from
  vd.test_base.RNNTestCase, which is not visible in this file.
  """

  @staticmethod
  def _lstm_shape_args(batch_size, seq_length, num_units, data_size):
    """Builds the trailing positional arguments shared by every assertion.

    Args:
      batch_size: Number of sequences per batch.
      seq_length: Number of time steps per sequence.
      num_units: Hidden size of the cell.
      data_size: Feature size of each input step.

    Returns:
      The tuple (num_units, input_shape, weight_shape, bias_shape), in the
      positional order the RNNTestCase assertions are called with. The three
      shapes are presumably the input tensor shape and the expected kernel
      and bias shapes; how they are consumed is defined in test_base (TODO
      confirm).
    """
    input_shape = [batch_size, seq_length, data_size]
    # The 4 * num_units width is presumably one num_units slice per LSTM
    # gate; the rows cover the input features plus the recurrent state.
    gate_width = 4 * num_units
    weight_shape = [data_size + num_units, gate_width]
    bias_shape = [gate_width]
    return num_units, input_shape, weight_shape, bias_shape

  def testLSTMCell_Train(self, batch_size, seq_length, num_units, data_size):
    # NOTE(review): set_no_epsilon presumably disables the sampled noise so
    # the result can be compared against the dense reference cell — confirm
    # in test_base.
    noiseless_cell = self.set_no_epsilon(vd.rnn.LSTMCell)
    self.assertSameResult(
        self.set_training(noiseless_cell),
        tf.nn.rnn_cell.LSTMCell,
        *self._lstm_shape_args(batch_size, seq_length, num_units, data_size))

  def testLSTMCell_Eval(self, batch_size, seq_length, num_units, data_size):
    noiseless_cell = self.set_no_epsilon(vd.rnn.LSTMCell)
    self.assertSameResult(
        self.set_evaluation(noiseless_cell),
        tf.nn.rnn_cell.LSTMCell,
        *self._lstm_shape_args(batch_size, seq_length, num_units, data_size))

  def testLSTMCell_SameNoiseForAllTimeSteps(
      self, batch_size, seq_length, num_units, data_size):
    # Seeds are fixed before the cell is configured so the check is
    # reproducible.
    self.fix_random_seeds()
    training_cell = self.set_training(vd.rnn.LSTMCell)
    self.assertSameNoiseForAllTimesteps(
        training_cell,
        *self._lstm_shape_args(batch_size, seq_length, num_units, data_size))

  def testLSTMCell_DifferentNoiseAcrossBatches(
      self, batch_size, seq_length, num_units, data_size):
    self.fix_random_seeds()
    training_cell = self.set_training(vd.rnn.LSTMCell)
    self.assertDifferentNoiseAcrossBatches(
        training_cell,
        *self._lstm_shape_args(batch_size, seq_length, num_units, data_size))

  def testLSTMCell_DeterministicEval(
      self, batch_size, seq_length, num_units, data_size):
    self.fix_random_seeds()
    eval_cell = self.set_evaluation(vd.rnn.LSTMCell)
    self.assertDeterministic(
        eval_cell,
        *self._lstm_shape_args(batch_size, seq_length, num_units, data_size))
181
182
if __name__ == "__main__":
  # When executed as a script, hand control to TensorFlow's test runner.
  tf.test.main()
185