# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for variational dropout layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import absl.testing.parameterized as parameterized
import numpy as np
import tensorflow.compat.v1 as tf

import state_of_sparsity.layers.variational_dropout as vd


# Parameters to test the matmul primitive on. First dimension of the
# first matrix, second dimension of the first matrix/first dimension
# of the second matrix, second dimension of the second matrix.
MATMUL_TEST_PARAMETERS = [(32, 200, 100)]
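
# A minimal usage sketch of the primitives under test. The weight format
# is the (theta, log_sigma2) tuple built in TestNegativeDKL below; the
# exact call signatures are assumed here for illustration only:
#
#   theta = tf.get_variable("theta", [200, 100])
#   log_sigma2 = tf.get_variable("log_sigma2", [200, 100])
#   y_train = vd.nn.matmul_train(x, (theta, log_sigma2))  # stochastic
#   y_eval = vd.nn.matmul_eval(x, (theta, log_sigma2))    # deterministic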


@parameterized.parameters(MATMUL_TEST_PARAMETERS)
class MatmulTest(vd.test_base.TestCase):

  def testMatmulTrain(self, m, n, k):
    self.assertSameResult(
        self.set_no_epsilon(vd.nn.matmul_train),
        tf.matmul,
        [m, n],
        [n, k])

  def testMatmulTrain_NonDeterministic(self, m, n, k):
    self.assertNonDeterministic(
        vd.nn.matmul_train,
        [m, n],
        [n, k])

  def testMatmulEval(self, m, n, k):
    self.assertSameResult(
        self.set_no_epsilon(vd.nn.matmul_eval),
        tf.matmul,
        [m, n],
        [n, k])

  def testMatmulEval_Deterministic(self, m, n, k):
    self.assertDeterministic(
        vd.nn.matmul_eval,
        [m, n],
        [n, k])


# Parameters to test the batched matmul primitive on. First dimension
# of the first matrix, second dimension of the first matrix, third
# dimension of the first matrix/first dimension of the second matrix,
# second dimension of the second matrix.
BROADCAST_MATMUL_TEST_PARAMETERS = [(32, 20, 200, 100),
                                    (1, 10, 100, 50)]
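
# The broadcast matmul applies one [n, k] weight matrix across the two
# leading dimensions of a rank-3 input. The dense reference the tests
# compare against (shapes illustrative):
#
#   x = tf.ones([32, 20, 200])
#   w = tf.ones([200, 100])
#   y = tf.tensordot(x, w, axes=[[2], [0]])  # -> [32, 20, 100]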


@parameterized.parameters(BROADCAST_MATMUL_TEST_PARAMETERS)
class BroadcastMatmulTest(vd.test_base.TestCase):

  def set_axes(self, ref_op):
    return functools.partial(ref_op, axes=[[2], [0]])

  def testBroadcastMatmulTrain(self, m, t, n, k):
    self.assertSameResult(
        self.set_no_epsilon(vd.nn.broadcast_matmul_train),
        self.set_axes(tf.tensordot),
        [m, t, n],
        [n, k])

  def testBroadcastMatmulTrain_NonDeterministic(self, m, t, n, k):
    self.assertNonDeterministic(
        vd.nn.broadcast_matmul_train,
        [m, t, n],
        [n, k])

  def testBroadcastMatmulEval(self, m, t, n, k):
    self.assertSameResult(
        self.set_no_epsilon(vd.nn.broadcast_matmul_eval),
        self.set_axes(tf.tensordot),
        [m, t, n],
        [n, k])

  def testBroadcastMatmulEval_Deterministic(self, m, t, n, k):
    self.assertDeterministic(
        vd.nn.broadcast_matmul_eval,
        [m, t, n],
        [n, k])


# Parameters to test the conv2d primitive with. Input tensor batch size,
# input channels, input height, input width, size of the convolutional
# filters, number of output channels.
CONV2D_TEST_PARAMETERS = [(32, 3, 224, 224, 3, 64)]
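
# The conv2d tests use NHWC inputs and HWIO filters, the tf.nn.conv2d
# defaults. A sketch of the dense reference call (the actual strides and
# padding come from fix_padding_and_strides; these values are assumed):
#
#   x = tf.ones([32, 224, 224, 3])
#   w = tf.ones([3, 3, 3, 64])
#   y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")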


@parameterized.parameters(CONV2D_TEST_PARAMETERS)
class Conv2dTest(vd.test_base.TestCase):

  def testConv2dTrain(
      self,
      batch_size,
      in_channels,
      height,
      width,
      filter_size,
      out_channels):
    conv2d_train = self.set_no_epsilon(vd.nn.conv2d_train)
    self.assertSameResult(
        self.fix_padding_and_strides(conv2d_train),
        self.fix_padding_and_strides(tf.nn.conv2d),
        [batch_size, height, width, in_channels],
        [filter_size, filter_size, in_channels, out_channels])

  def testConv2dTrain_NonDeterministic(
      self,
      batch_size,
      in_channels,
      height,
      width,
      filter_size,
      out_channels):
    self.assertNonDeterministic(
        self.fix_padding_and_strides(vd.nn.conv2d_train),
        [batch_size, height, width, in_channels],
        [filter_size, filter_size, in_channels, out_channels])

  def testConv2dEval(
      self,
      batch_size,
      in_channels,
      height,
      width,
      filter_size,
      out_channels):
    conv2d_eval = self.set_no_epsilon(vd.nn.conv2d_eval)
    self.assertSameResult(
        self.fix_padding_and_strides(conv2d_eval),
        self.fix_padding_and_strides(tf.nn.conv2d),
        [batch_size, height, width, in_channels],
        [filter_size, filter_size, in_channels, out_channels])

  def testConv2dEval_Deterministic(
      self,
      batch_size,
      in_channels,
      height,
      width,
      filter_size,
      out_channels):
    self.assertDeterministic(
        self.fix_padding_and_strides(vd.nn.conv2d_eval),
        [batch_size, height, width, in_channels],
        [filter_size, filter_size, in_channels, out_channels])


# Parameters for the embedding lookup tests. Batch size, sequence length,
# vocabulary size, embedding vector size.
EMBEDDING_TEST_PARAMETERS = [(32, 25, 10000, 512)]
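
# The embedding tests feed ids of shape [batch, seq_length, 1];
# flip_input_wrapper swaps the (ids, params) argument order so one
# harness drives both ops (an assumption based on the name). Dense
# reference (shapes illustrative):
#
#   params = tf.ones([10000, 512])
#   ids = tf.zeros([32, 25], dtype=tf.int32)
#   out = tf.nn.embedding_lookup(params, ids)  # -> [32, 25, 512]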


@parameterized.parameters(EMBEDDING_TEST_PARAMETERS)
class TestEmbeddingLookup(vd.test_base.TestCase):

  def testEmbeddingLookupTrain(
      self,
      batch_size,
      seq_length,
      vocab_size,
      embedding_size):
    embedding_lookup_train = self.set_no_epsilon(vd.nn.embedding_lookup_train)
    self.assertSameResult(
        self.flip_input_wrapper(embedding_lookup_train),
        self.flip_input_wrapper(tf.nn.embedding_lookup),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)

  def testEmbeddingLookupTrain_NonDeterministic(
      self,
      batch_size,
      seq_length,
      vocab_size,
      embedding_size):
    self.assertNonDeterministic(
        self.flip_input_wrapper(vd.nn.embedding_lookup_train),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)

  def testEmbeddingLookupEval(
      self,
      batch_size,
      seq_length,
      vocab_size,
      embedding_size):
    embedding_lookup_eval = self.set_no_epsilon(vd.nn.embedding_lookup_eval)
    self.assertSameResult(
        self.flip_input_wrapper(embedding_lookup_eval),
        self.flip_input_wrapper(tf.nn.embedding_lookup),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)

  def testEmbeddingLookupEval_Deterministic(
      self,
      batch_size,
      seq_length,
      vocab_size,
      embedding_size):
    self.assertDeterministic(
        self.flip_input_wrapper(vd.nn.embedding_lookup_eval),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)


# Dimensions of the parameters to calculate the KL divergence over.
DKL_TEST_PARAMETERS = [(256, 128)]
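
# The constants in the test below match the approximation to the negative
# KL divergence between the log-uniform prior and the Gaussian posterior
# from Molchanov et al., "Variational Dropout Sparsifies Deep Neural
# Networks" (2017):
#
#   -DKL ~= k1 * sigmoid(k2 + k3 * log_alpha) - 0.5 * log(1 + 1/alpha) + C
#
# with alpha = sigma^2 / theta^2 and C = -k1; the test sums this over all
# weights and negates the result.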


@parameterized.parameters(DKL_TEST_PARAMETERS)
class TestNegativeDKL(vd.test_base.TestCase):

  def testNegativeDKL(self, d, k):
    self.fix_random_seeds()

    theta = tf.random_normal([d, k], dtype=tf.float32)
    log_sigma2 = tf.random_normal([d, k], dtype=tf.float32)
    weights = (theta, log_sigma2)

    output = vd.nn.negative_dkl(weights)

    result, theta, log_sigma2 = self.evaluate([output, theta, log_sigma2])

    # Verify the output shape.
    self.assertEqual(result.shape, ())

    # Compute the expected result.
    k1, k2, k3 = 0.63576, 1.8732, 1.48695
    c = -k1

    # Compute the log alpha values.
    log_alpha = log_sigma2 - np.log(np.power(theta, 2) + 1e-8)

    def sigmoid(x):
      return 1.0 / (1.0 + np.exp(-x))

    term_1 = k1 * sigmoid(k2 + k3 * log_alpha)
    term_2 = -0.5 * np.log1p(np.exp(-log_alpha))
    expected_result = -np.sum(term_1 + term_2 + c)

    self.assertAllClose(result, expected_result)


if __name__ == "__main__":
  tf.test.main()