google-research

Форк
0
145 строк · 4.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for common variational dropout utilities."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import absl.testing.parameterized as parameterized
22
import numpy as np
23
import tensorflow.compat.v1 as tf
24

25
import state_of_sparsity.layers.variational_dropout as vd
26

27

28
# Parameter sets for HelperTest. Each entry is a tuple
#   (d, k, min_val, max_val)
# where d and k are the sizes of the first and second dimensions of the
# generated parameters, and min_val / max_val bound the values those
# parameters are generated with.
HELPER_TEST = [
    (32, 80, -10, 10),
]
32

33

34
@parameterized.parameters(HELPER_TEST)
class HelperTest(parameterized.TestCase):
  """Checks the log alpha / log sigma^2 helpers in `vd.common`.

  Every test method is parameterized by one entry of HELPER_TEST,
  unpacked as (d, k, min_val, max_val).
  """

  # Graph-level random seed applied in setUp so that every test (not only
  # some of them) draws the same weights on every run.
  _SEED = 15

  # Epsilon added to theta^2 in the numpy reference formulas below. This is
  # the value the original assertions used; presumably it mirrors the one
  # inside vd.common -- confirm if the library changes.
  _EPS = 1e-8

  # Tolerances for comparing float32 results. The absolute term absorbs
  # float32 rounding (on the order of 1e-6 for the magnitudes produced by
  # the HELPER_TEST ranges) when an expected value lands near zero, where a
  # purely relative tolerance would make the test flaky.
  _RTOL = 1e-3
  _ATOL = 1e-5

  def setUp(self):
    super(HelperTest, self).setUp()
    # Build each test's ops in a fresh, deterministically seeded graph.
    tf.reset_default_graph()
    tf.set_random_seed(self._SEED)

  def _get_weights(self, d, k, min_val, max_val):
    """Returns two independent [d, k] float32 uniform random tensors.

    Args:
      d: Size of the first dimension.
      k: Size of the second dimension.
      min_val: Lower bound of the uniform distribution.
      max_val: Upper bound of the uniform distribution.

    Returns:
      A (theta, other) tuple. The caller decides what `other` represents
      (log sigma^2 or log alpha).
    """
    # Creation order is kept (theta first) so seeded draws stay stable.
    theta = tf.random_uniform([d, k], min_val, max_val, dtype=tf.float32)
    other = tf.random_uniform([d, k], min_val, max_val, dtype=tf.float32)
    return (theta, other)

  def _evaluate(self, fetches):
    """Runs `fetches` in a session that is always closed afterwards."""
    with tf.Session() as sess:
      return sess.run(fetches)

  def _assert_close(self, actual, desired):
    """Element-wise closeness check that reports mismatching entries."""
    # equal_nan=False keeps the original np.isclose behaviour: NaNs fail.
    np.testing.assert_allclose(
        actual, desired, rtol=self._RTOL, atol=self._ATOL, equal_nan=False)

  def _assert_shapes(self, expected_shape, *arrays):
    """Asserts that every array in `arrays` has `expected_shape`."""
    for array in arrays:
      self.assertEqual(array.shape, expected_shape)

  def testHelper_ComputeLogAlpha(self, d, k, min_val, max_val):
    """compute_log_alpha matches log sigma^2 - log(theta^2 + eps)."""
    theta, log_sigma2 = self._get_weights(d, k, min_val, max_val)
    log_alpha = vd.common.compute_log_alpha(
        log_sigma2, theta, value_limit=None)

    theta_np, log_sigma2_np, log_alpha_np = self._evaluate(
        [theta, log_sigma2, log_alpha])

    self._assert_shapes((d, k), theta_np, log_sigma2_np, log_alpha_np)

    expected = log_sigma2_np - np.log(np.square(theta_np) + self._EPS)
    self._assert_close(log_alpha_np, expected)

  def testHelper_ComputeLogSigma2(self, d, k, min_val, max_val):
    """compute_log_sigma2 matches log alpha + log(theta^2 + eps)."""
    theta, log_alpha = self._get_weights(d, k, min_val, max_val)
    log_sigma2 = vd.common.compute_log_sigma2(log_alpha, theta)

    theta_np, log_alpha_np, log_sigma2_np = self._evaluate(
        [theta, log_alpha, log_sigma2])

    self._assert_shapes((d, k), theta_np, log_alpha_np, log_sigma2_np)

    expected = log_alpha_np + np.log(np.square(theta_np) + self._EPS)
    self._assert_close(log_sigma2_np, expected)

  def testHelper_ComputeLogAlphaAndBack(self, d, k, min_val, max_val):
    """log sigma^2 -> log alpha -> log sigma^2 recovers the start values."""
    theta, true_log_sigma2 = self._get_weights(d, k, min_val, max_val)
    log_alpha = vd.common.compute_log_alpha(
        true_log_sigma2, theta, value_limit=None)
    log_sigma2 = vd.common.compute_log_sigma2(log_alpha, theta)

    true_np, log_alpha_np, log_sigma2_np = self._evaluate(
        [true_log_sigma2, log_alpha, log_sigma2])

    self._assert_shapes((d, k), true_np, log_alpha_np, log_sigma2_np)
    self._assert_close(log_sigma2_np, true_np)

  def testHelper_ThresholdLogAlphas(self, d, k, min_val, max_val):
    """compute_log_alpha with value_limit keeps outputs in range."""
    theta, log_sigma2 = self._get_weights(d, k, min_val, max_val)
    value_limit = 8.
    log_alpha = vd.common.compute_log_alpha(
        log_sigma2, theta, value_limit=value_limit)

    log_alpha_np = self._evaluate(log_alpha)

    self.assertEqual(log_alpha_np.shape, (d, k))
    # Checking the extremes is equivalent to checking every element; a NaN
    # propagates through max()/min() and still fails the comparison.
    self.assertLessEqual(log_alpha_np.max(), value_limit)
    self.assertGreaterEqual(log_alpha_np.min(), -value_limit)
142

143

144
if __name__ == "__main__":
  # Run every test case in this module through TensorFlow's test runner.
  tf.test.main()
146

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.