google-research

Форк
0
184 строки · 4.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for variational dropout recurrent cells."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import absl.testing.parameterized as parameterized
22
import tensorflow.compat.v1 as tf
23

24
import state_of_sparsity.layers.variational_dropout as vd
25

26

27
# batch_size, seq_length, num_units, data_size
28
RNN_TEST_PARAMETERS = [(32, 25, 100, 33)]
29

30

31
@parameterized.parameters(RNN_TEST_PARAMETERS)
class RNNCellTest(vd.test_base.RNNTestCase):
  """Tests for vd.rnn.RNNCell.

  The class-level decorator runs every `test*` method once per tuple in
  RNN_TEST_PARAMETERS, passing (batch_size, seq_length, num_units, data_size)
  positionally.
  """

  def _shapes(self, batch_size, seq_length, num_units, data_size):
    """Returns the (input, weight, bias) shape lists shared by these tests.

    Args:
      batch_size: leading dimension of the input shape.
      seq_length: second (time) dimension of the input shape.
      num_units: number of cell units; also the bias length and the weight's
        second dimension.
      data_size: last dimension of the input shape.

    Returns:
      A tuple of three fresh lists:
        input_shape  = [batch_size, seq_length, data_size]
        weight_shape = [data_size + num_units, num_units]
        bias_shape   = [num_units]
    """
    return ([batch_size, seq_length, data_size],
            [data_size + num_units, num_units],
            [num_units])

  def testRNNCell_Train(self, batch_size, seq_length, num_units, data_size):
    """Compares a training-mode, no-epsilon RNNCell with BasicRNNCell."""
    # set_no_epsilon is applied first, then set_training, matching the
    # original call order.
    cell = self.set_training(self.set_no_epsilon(vd.rnn.RNNCell))
    self.assertSameResult(
        cell,
        tf.nn.rnn_cell.BasicRNNCell,
        num_units,
        *self._shapes(batch_size, seq_length, num_units, data_size))

  def testRNNCell_Eval(self, batch_size, seq_length, num_units, data_size):
    """Compares an evaluation-mode, no-epsilon RNNCell with BasicRNNCell."""
    cell = self.set_evaluation(self.set_no_epsilon(vd.rnn.RNNCell))
    self.assertSameResult(
        cell,
        tf.nn.rnn_cell.BasicRNNCell,
        num_units,
        *self._shapes(batch_size, seq_length, num_units, data_size))

  def testRNNCell_SameNoiseForAllTimeSteps(
      self, batch_size, seq_length, num_units, data_size):
    """Runs assertSameNoiseForAllTimesteps on a training-mode RNNCell."""
    # Seeds are fixed before the cell class is configured, as in the
    # original ordering.
    self.fix_random_seeds()
    self.assertSameNoiseForAllTimesteps(
        self.set_training(vd.rnn.RNNCell),
        num_units,
        *self._shapes(batch_size, seq_length, num_units, data_size))

  def testRNNCell_DifferentNoiseAcrossBatches(
      self, batch_size, seq_length, num_units, data_size):
    """Runs assertDifferentNoiseAcrossBatches on a training-mode RNNCell."""
    self.fix_random_seeds()
    self.assertDifferentNoiseAcrossBatches(
        self.set_training(vd.rnn.RNNCell),
        num_units,
        *self._shapes(batch_size, seq_length, num_units, data_size))

  def testRNNCell_DeterministicEval(
      self, batch_size, seq_length, num_units, data_size):
    """Runs assertDeterministic on an evaluation-mode RNNCell."""
    self.fix_random_seeds()
    self.assertDeterministic(
        self.set_evaluation(vd.rnn.RNNCell),
        num_units,
        *self._shapes(batch_size, seq_length, num_units, data_size))
105

106

107
@parameterized.parameters(RNN_TEST_PARAMETERS)
class LSTMCellTest(vd.test_base.RNNTestCase):
  """Tests for vd.rnn.LSTMCell.

  The class-level decorator runs every `test*` method once per tuple in
  RNN_TEST_PARAMETERS, passing (batch_size, seq_length, num_units, data_size)
  positionally.
  """

  def _shapes(self, batch_size, seq_length, num_units, data_size):
    """Returns the (input, weight, bias) shape lists shared by these tests.

    Args:
      batch_size: leading dimension of the input shape.
      seq_length: second (time) dimension of the input shape.
      num_units: number of cell units.
      data_size: last dimension of the input shape.

    Returns:
      A tuple of three fresh lists:
        input_shape  = [batch_size, seq_length, data_size]
        weight_shape = [data_size + num_units, 4 * num_units]
        bias_shape   = [4 * num_units]
      The factor of 4 presumably matches the LSTM's four stacked gate blocks
      (as in tf.nn.rnn_cell.LSTMCell's kernel) -- confirm against vd.rnn.
    """
    width = 4 * num_units
    return ([batch_size, seq_length, data_size],
            [data_size + num_units, width],
            [width])

  def testLSTMCell_Train(self, batch_size, seq_length, num_units, data_size):
    """Compares a training-mode, no-epsilon LSTMCell with tf's LSTMCell."""
    # set_no_epsilon is applied first, then set_training, matching the
    # original call order.
    cell = self.set_training(self.set_no_epsilon(vd.rnn.LSTMCell))
    self.assertSameResult(
        cell,
        tf.nn.rnn_cell.LSTMCell,
        num_units,
        *self._shapes(batch_size, seq_length, num_units, data_size))

  def testLSTMCell_Eval(self, batch_size, seq_length, num_units, data_size):
    """Compares an evaluation-mode, no-epsilon LSTMCell with tf's LSTMCell."""
    cell = self.set_evaluation(self.set_no_epsilon(vd.rnn.LSTMCell))
    self.assertSameResult(
        cell,
        tf.nn.rnn_cell.LSTMCell,
        num_units,
        *self._shapes(batch_size, seq_length, num_units, data_size))

  def testLSTMCell_SameNoiseForAllTimeSteps(
      self, batch_size, seq_length, num_units, data_size):
    """Runs assertSameNoiseForAllTimesteps on a training-mode LSTMCell."""
    # Seeds are fixed before the cell class is configured, as in the
    # original ordering.
    self.fix_random_seeds()
    self.assertSameNoiseForAllTimesteps(
        self.set_training(vd.rnn.LSTMCell),
        num_units,
        *self._shapes(batch_size, seq_length, num_units, data_size))

  def testLSTMCell_DifferentNoiseAcrossBatches(
      self, batch_size, seq_length, num_units, data_size):
    """Runs assertDifferentNoiseAcrossBatches on a training-mode LSTMCell."""
    self.fix_random_seeds()
    self.assertDifferentNoiseAcrossBatches(
        self.set_training(vd.rnn.LSTMCell),
        num_units,
        *self._shapes(batch_size, seq_length, num_units, data_size))

  def testLSTMCell_DeterministicEval(
      self, batch_size, seq_length, num_units, data_size):
    """Runs assertDeterministic on an evaluation-mode LSTMCell."""
    self.fix_random_seeds()
    self.assertDeterministic(
        self.set_evaluation(vd.rnn.LSTMCell),
        num_units,
        *self._shapes(batch_size, seq_length, num_units, data_size))
181

182

183
if __name__ == "__main__":
  # Hand test discovery and execution to TensorFlow's test runner.
  tf.test.main()
185

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.