google-research
145 строк · 4.7 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for common variational dropout utilities."""
17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20
21import absl.testing.parameterized as parameterized22import numpy as np23import tensorflow.compat.v1 as tf24
25import state_of_sparsity.layers.variational_dropout as vd26
27
# Parameter sets for the helper-function tests. Each tuple is unpacked into
# the test methods as (d, k, min_val, max_val): the first and second
# dimensions of the generated parameter matrices, and the lower and upper
# bounds of the uniform distribution their values are drawn from.
HELPER_TEST = [
    (32, 80, -10, 10),  # (d, k, min_val, max_val)
]
33
@parameterized.parameters(HELPER_TEST)
class HelperTest(parameterized.TestCase):
  """Tests for `vd.common.compute_log_alpha` and `vd.common.compute_log_sigma2`.

  Every test method is parameterized over HELPER_TEST and receives
  (d, k, min_val, max_val).
  """

  # Graph-level seed shared by every test. Fixing it makes the random inputs,
  # and therefore the outcome of the floating-point tolerance checks,
  # reproducible from run to run.
  _SEED = 15

  def setUp(self):
    super(HelperTest, self).setUp()
    # Each test builds its ops in a fresh default graph.
    tf.reset_default_graph()
    # The graph-level seed is stored on the default graph, so it must be set
    # after the reset above (before any random op is created).
    tf.set_random_seed(self._SEED)

  def _get_weights(self, d, k, min_val, max_val):
    """Builds two independent random [d, k] float32 tensors.

    Both are drawn from tf.random_uniform with bounds [min_val, max_val).

    Returns:
      A `(theta, log_sigma2)` tuple of tensors.
    """
    theta = tf.random_uniform([d, k], min_val, max_val, dtype=tf.float32)
    log_sigma2 = tf.random_uniform([d, k], min_val, max_val, dtype=tf.float32)
    return (theta, log_sigma2)

  def _evaluate(self, fetches):
    """Runs `fetches` in a short-lived session that is always closed."""
    with tf.Session() as sess:
      return sess.run(fetches)

  def testHelper_ComputeLogAlpha(self, d, k, min_val, max_val):
    theta, log_sigma2 = self._get_weights(d, k, min_val, max_val)

    # Compute the log alpha values.
    log_alpha = vd.common.compute_log_alpha(log_sigma2, theta, value_limit=None)

    log_sigma2_val, log_alpha_val, theta_val = self._evaluate(
        [log_sigma2, log_alpha, theta])

    # Verify the output shapes.
    self.assertEqual(log_sigma2_val.shape, (d, k))
    self.assertEqual(log_alpha_val.shape, (d, k))
    self.assertEqual(theta_val.shape, (d, k))

    # Verify the calculated values:
    # log alpha = log sigma^2 - log(theta^2 + 1e-8).
    expected_log_alpha = log_sigma2_val - np.log(np.power(theta_val, 2) + 1e-8)
    # Same criterion as np.isclose(expected, actual, rtol=1e-3), but reports
    # the mismatching elements on failure.
    np.testing.assert_allclose(
        expected_log_alpha, log_alpha_val, rtol=1e-3, atol=1e-8,
        equal_nan=False)

  def testHelper_ComputeLogSigma2(self, d, k, min_val, max_val):
    theta, log_alpha = self._get_weights(d, k, min_val, max_val)

    # Compute the log \sigma^2 values.
    log_sigma2 = vd.common.compute_log_sigma2(log_alpha, theta)

    log_sigma2_val, log_alpha_val, theta_val = self._evaluate(
        [log_sigma2, log_alpha, theta])

    # Verify the output shapes.
    self.assertEqual(log_sigma2_val.shape, (d, k))
    self.assertEqual(log_alpha_val.shape, (d, k))
    self.assertEqual(theta_val.shape, (d, k))

    # Verify the calculated values:
    # log sigma^2 = log alpha + log(theta^2 + 1e-8).
    expected_log_sigma2 = log_alpha_val + np.log(np.power(theta_val, 2) + 1e-8)
    np.testing.assert_allclose(
        expected_log_sigma2, log_sigma2_val, rtol=1e-3, atol=1e-8,
        equal_nan=False)

  def testHelper_ComputeLogAlphaAndBack(self, d, k, min_val, max_val):
    theta, true_log_sigma2 = self._get_weights(d, k, min_val, max_val)

    # Compute the log alpha values.
    log_alpha = vd.common.compute_log_alpha(
        true_log_sigma2, theta, value_limit=None)

    # Compute the log \sigma^2 values back from log alpha.
    log_sigma2 = vd.common.compute_log_sigma2(log_alpha, theta)

    true_log_sigma2_val, log_alpha_val, log_sigma2_val = self._evaluate(
        [true_log_sigma2, log_alpha, log_sigma2])

    # Verify the output shapes.
    self.assertEqual(true_log_sigma2_val.shape, (d, k))
    self.assertEqual(log_sigma2_val.shape, (d, k))
    self.assertEqual(log_alpha_val.shape, (d, k))

    # The round-tripped log \sigma^2 values should match the originals.
    # A purely relative tolerance cannot absorb float32 rounding when the
    # original value is close to zero. Assuming both helpers use a
    # log(theta^2 + 1e-8)-style term (as the expected values above do), the
    # intermediate log alpha can reach ~28 in magnitude, where one float32
    # ulp is ~2e-6. The small absolute tolerance covers that.
    np.testing.assert_allclose(
        true_log_sigma2_val, log_sigma2_val, rtol=1e-3, atol=1e-5,
        equal_nan=False)

  def testHelper_ThresholdLogAlphas(self, d, k, min_val, max_val):
    theta, log_sigma2 = self._get_weights(d, k, min_val, max_val)

    # Compute the log alpha values, clipped to [-value_limit, value_limit].
    value_limit = 8.
    log_alpha = vd.common.compute_log_alpha(
        log_sigma2, theta, value_limit=value_limit)

    log_alpha_val = self._evaluate(log_alpha)

    # Verify the output shapes.
    self.assertEqual(log_alpha_val.shape, (d, k))

    # Verify that all log alpha values are within the valid range. Checking
    # the extremes is equivalent to checking every element; a NaN anywhere
    # propagates into max()/min() and makes the comparison fail.
    self.assertLessEqual(log_alpha_val.max(), value_limit)
    self.assertGreaterEqual(log_alpha_val.min(), -value_limit)
143
if __name__ == "__main__":
  # Entry point when the module is executed directly: hand control to the
  # TensorFlow test runner, which discovers and runs the test cases above.
  tf.test.main()