# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16"""Tests for variational dropout layers."""
17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20
21import functools22
23import absl.testing.parameterized as parameterized24import numpy as np25import tensorflow.compat.v1 as tf26
27import state_of_sparsity.layers.variational_dropout as vd28
29


# Parameters to test the matmul primitive on. First dimension of the
# first matrix, second dimension of the first matrix/first dimension
# of the second matrix, second dimension of the second matrix.
MATMUL_TEST_PARAMETERS = [(32, 200, 100)]


@parameterized.parameters(MATMUL_TEST_PARAMETERS)
class MatmulTest(vd.test_base.TestCase):

  def testMatmulTrain(self, m, n, k):
    self.assertSameResult(
        self.set_no_epsilon(vd.nn.matmul_train),
        tf.matmul,
        [m, n],
        [n, k])

  def testMatmulTrain_NonDeterministic(self, m, n, k):
    self.assertNonDeterministic(
        vd.nn.matmul_train,
        [m, n],
        [n, k])

  def testMatmulEval(self, m, n, k):
    self.assertSameResult(
        self.set_no_epsilon(vd.nn.matmul_eval),
        tf.matmul,
        [m, n],
        [n, k])

  def testMatmulEval_Deterministic(self, m, n, k):
    self.assertDeterministic(
        vd.nn.matmul_eval,
        [m, n],
        [n, k])
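

# As context for the train/eval split exercised above: under variational
# dropout, training-time matmuls are commonly implemented with the local
# reparameterization trick (Kingma et al., 2015), sampling activations
# rather than weights. The sketch below is an illustrative NumPy reference
# under that assumption, not vd.nn's actual implementation.
def _matmul_train_sketch(x, theta, log_sigma2, eps=1e-8):
  """Hypothetical sketch: variational dropout matmul, training mode."""
  # Pre-activation mean and variance induced by the factorized Gaussian
  # weight posterior q(w) = N(theta, sigma**2).
  mu = np.dot(x, theta)
  var = np.dot(np.square(x), np.exp(log_sigma2))
  # Sampling fresh Gaussian noise per call is why the *_NonDeterministic
  # tests expect different outputs across evaluations, while eval mode
  # uses the posterior means and stays deterministic.
  return mu + np.sqrt(var + eps) * np.random.standard_normal(mu.shape)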


# Parameters to test the batched matmul primitive on. First dimension
# of the first matrix, second dimension of the first matrix, third
# dimension of the first matrix/first dimension of the second matrix,
# second dimension of the second matrix.
BROADCAST_MATMUL_TEST_PARAMETERS = [(32, 20, 200, 100),
                                    (1, 10, 100, 50)]


@parameterized.parameters(BROADCAST_MATMUL_TEST_PARAMETERS)
class BroadcastMatmulTest(vd.test_base.TestCase):

  def set_axes(self, ref_op):
    return functools.partial(ref_op, axes=[[2], [0]])

  def testBroadcastMatmulTrain(self, m, t, n, k):
    self.assertSameResult(
        self.set_no_epsilon(vd.nn.broadcast_matmul_train),
        self.set_axes(tf.tensordot),
        [m, t, n],
        [n, k])

  def testBroadcastMatmulTrain_NonDeterministic(self, m, t, n, k):
    self.assertNonDeterministic(
        vd.nn.broadcast_matmul_train,
        [m, t, n],
        [n, k])

  def testBroadcastMatmulEval(self, m, t, n, k):
    self.assertSameResult(
        self.set_no_epsilon(vd.nn.broadcast_matmul_eval),
        self.set_axes(tf.tensordot),
        [m, t, n],
        [n, k])

  def testBroadcastMatmulEval_Deterministic(self, m, t, n, k):
    self.assertDeterministic(
        vd.nn.broadcast_matmul_eval,
        [m, t, n],
        [n, k])
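

# Note on the reference op: tf.tensordot with axes=[[2], [0]] contracts the
# trailing feature axis of the [m, t, n] input against the leading axis of
# the [n, k] weights, producing an [m, t, k] result. That is the shape
# contract the broadcast matmul primitives are compared against, e.g.:
#
#   x = tf.zeros([32, 20, 200])
#   w = tf.zeros([200, 100])
#   y = tf.tensordot(x, w, axes=[[2], [0]])  # y.shape == (32, 20, 100)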


# Parameters to test the conv2d primitive with. Input tensor batch size,
# input channels, input height, input width, size of the convolutional
# filters, number of output channels.
CONV2D_TEST_PARAMETERS = [(32, 3, 224, 224, 3, 64)]


@parameterized.parameters(CONV2D_TEST_PARAMETERS)
class Conv2dTest(vd.test_base.TestCase):

  def testConv2dTrain(
      self,
      batch_size,
      in_channels,
      height,
      width,
      filter_size,
      out_channels):
    conv2d_train = self.set_no_epsilon(vd.nn.conv2d_train)
    self.assertSameResult(
        self.fix_padding_and_strides(conv2d_train),
        self.fix_padding_and_strides(tf.nn.conv2d),
        [batch_size, height, width, in_channels],
        [filter_size, filter_size, in_channels, out_channels])

  def testConv2dTrain_NonDeterministic(
      self,
      batch_size,
      in_channels,
      height,
      width,
      filter_size,
      out_channels):
    self.assertNonDeterministic(
        self.fix_padding_and_strides(vd.nn.conv2d_train),
        [batch_size, height, width, in_channels],
        [filter_size, filter_size, in_channels, out_channels])

  def testConv2dEval(
      self,
      batch_size,
      in_channels,
      height,
      width,
      filter_size,
      out_channels):
    conv2d_eval = self.set_no_epsilon(vd.nn.conv2d_eval)
    self.assertSameResult(
        self.fix_padding_and_strides(conv2d_eval),
        self.fix_padding_and_strides(tf.nn.conv2d),
        [batch_size, height, width, in_channels],
        [filter_size, filter_size, in_channels, out_channels])

  def testConv2dEval_Deterministic(
      self,
      batch_size,
      in_channels,
      height,
      width,
      filter_size,
      out_channels):
    self.assertDeterministic(
        self.fix_padding_and_strides(vd.nn.conv2d_eval),
        [batch_size, height, width, in_channels],
        [filter_size, filter_size, in_channels, out_channels])
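

# fix_padding_and_strides is a test_base helper; from its use here it
# presumably binds fixed `strides` and `padding` arguments so vd.nn.conv2d_*
# and tf.nn.conv2d share the same (input, filter) calling convention. Both
# consume NHWC inputs, [batch_size, height, width, in_channels], and HWIO
# filters, [filter_size, filter_size, in_channels, out_channels].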


# Parameters for the embedding lookup tests. Batch size, sequence length,
# vocabulary size, embedding vector size.
EMBEDDING_TEST_PARAMETERS = [(32, 25, 10000, 512)]


@parameterized.parameters(EMBEDDING_TEST_PARAMETERS)
class TestEmbeddingLookup(vd.test_base.TestCase):

  def testEmbeddingLookupTrain(
      self,
      batch_size,
      seq_length,
      vocab_size,
      embedding_size):
    embedding_lookup_train = self.set_no_epsilon(vd.nn.embedding_lookup_train)
    self.assertSameResult(
        self.flip_input_wrapper(embedding_lookup_train),
        self.flip_input_wrapper(tf.nn.embedding_lookup),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)

  def testEmbeddingLookupTrain_NonDeterministic(
      self,
      batch_size,
      seq_length,
      vocab_size,
      embedding_size):
    self.assertNonDeterministic(
        self.flip_input_wrapper(vd.nn.embedding_lookup_train),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)

  def testEmbeddingLookupEval(
      self,
      batch_size,
      seq_length,
      vocab_size,
      embedding_size):
    embedding_lookup_eval = self.set_no_epsilon(vd.nn.embedding_lookup_eval)
    self.assertSameResult(
        self.flip_input_wrapper(embedding_lookup_eval),
        self.flip_input_wrapper(tf.nn.embedding_lookup),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)

  def testEmbeddingLookupEval_Deterministic(
      self,
      batch_size,
      seq_length,
      vocab_size,
      embedding_size):
    self.assertDeterministic(
        self.flip_input_wrapper(vd.nn.embedding_lookup_eval),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)
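

# tf.nn.embedding_lookup expects (params, ids) while the harness passes the
# id tensor first; flip_input_wrapper presumably reorders the arguments (and
# handles the trailing singleton axis of the [batch_size, seq_length, 1] ids)
# so both ops can be driven with identical inputs.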


# Dimensions of the parameters to calculate the KL divergence over.
DKL_TEST_PARAMETERS = [(256, 128)]
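

# The expected value below follows the approximation of the negative KL
# divergence between the log-uniform prior and the Gaussian posterior from
# Molchanov et al. (2017), "Variational Dropout Sparsifies Deep Neural
# Networks":
#
#   -DKL ~= k1 * sigmoid(k2 + k3 * log_alpha) - 0.5 * log(1 + 1/alpha) - k1
#
# with log_alpha = log_sigma2 - log(theta**2) and fitted constants
# k1, k2, k3.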


@parameterized.parameters(DKL_TEST_PARAMETERS)
class TestNegativeDKL(vd.test_base.TestCase):

  def testNegativeDKL(self, d, k):
    self.fix_random_seeds()

    theta = tf.random_normal([d, k], dtype=tf.float32)
    log_sigma2 = tf.random_normal([d, k], dtype=tf.float32)
    weights = (theta, log_sigma2)

    output = vd.nn.negative_dkl(weights)

    result, theta, log_sigma2 = self.evaluate([output, theta, log_sigma2])

    # Verify the output shape.
    self.assertEqual(result.shape, ())

    # Compute the expected results.
    k1, k2, k3 = 0.63576, 1.8732, 1.48695
    c = -k1

    # Compute the log alpha values.
    log_alpha = log_sigma2 - np.log(np.power(theta, 2) + 1e-8)

    def sigmoid(x):
      return 1.0 / (1.0 + np.exp(-x))

    term_1 = k1 * sigmoid(k2 + k3 * log_alpha)
    term_2 = -0.5 * np.log1p(np.exp(-log_alpha))
    expected_result = -np.sum(term_1 + term_2 + c)

    self.assertAllClose(result, expected_result)


if __name__ == "__main__":
  tf.test.main()