google-research

Форк
0
222 строки · 5.5 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
"""Base class for variational dropout tests."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import functools
22

23
import tensorflow.compat.v1 as tf
24

25
from state_of_sparsity.layers.utils import test_utils
26

27

28
class TestCase(test_utils.TestCase):
  """Base class for all variational dropout tests.

  Variational dropout weights are represented as a `(theta, log_sigma2)`
  pair. `assertSameResult` forwards its `log_sigma2_value` argument to the
  parent implementation as `variance_value`; the determinism assertions pass
  it positionally to the inherited `_determinism_helper`.
  """

  def set_no_epsilon(self, test_op):
    """Returns `test_op` with its `eps` keyword argument fixed to 0.0.

    Args:
      test_op: callable that accepts an `eps` keyword argument.

    Returns:
      A `functools.partial` of `test_op` with `eps=0.0` bound.
    """
    op_without_eps = functools.partial(test_op, eps=0.0)
    return op_without_eps

  def get_data_and_weights(
      self,
      data_shape,
      weights_shape,
      data_dtype,
      weights_dtype,
      variance_value=0.0):
    """Builds constant input data and a `(theta, log_sigma2)` weight pair.

    Args:
      data_shape: shape of the input tensor.
      weights_shape: shape shared by `theta` and `log_sigma2`.
      data_dtype: dtype of the input tensor.
      weights_dtype: dtype of `theta` and `log_sigma2`.
      variance_value: scalar that every element of `log_sigma2` is set to.

    Returns:
      `(inputs, (theta, log_sigma2))`: `inputs` and `theta` are all-ones
      tensors, and `log_sigma2` is filled with `variance_value`.
    """
    # Creation order is kept (inputs, theta, log_sigma2) so graph op names
    # come out the same as before.
    inputs = tf.ones(data_shape, dtype=data_dtype)
    theta = tf.ones(weights_shape, dtype=weights_dtype)
    log_sigma2 = tf.constant(
        variance_value, dtype=weights_dtype, shape=weights_shape)
    weight_pair = (theta, log_sigma2)
    return inputs, weight_pair

  def assertSameResult(
      self,
      test_op,
      ref_op,
      data_shape,
      weights_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=-10e6):
    """Delegates to the parent `assertSameResult` for `test_op` vs `ref_op`.

    The default `log_sigma2_value` is -10e6 (i.e. -1e7), a hugely negative
    log-variance. Presumably it is chosen so that exp(log_sigma2) is
    effectively zero and the variational op matches the reference op;
    confirm against `test_utils`.

    Args:
      test_op: op under test.
      ref_op: reference op to compare against.
      data_shape: shape of the input tensor.
      weights_shape: shape of the weight tensors.
      data_dtype: dtype of the input tensor.
      weights_dtype: dtype of the weight tensors.
      log_sigma2_value: forwarded to the parent as `variance_value`.
    """
    parent = super(TestCase, self)
    parent.assertSameResult(
        test_op, ref_op, data_shape, weights_shape, data_dtype, weights_dtype,
        variance_value=log_sigma2_value)

  def assertDeterministic(
      self,
      test_op,
      data_shape,
      weights_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0.0):
    """Runs the inherited `_determinism_helper` with `check_same=True`.

    Presumably this asserts that repeated evaluations of `test_op` agree;
    the actual check lives in `test_utils`.
    """
    self._determinism_helper(
        test_op, data_shape, weights_shape, data_dtype, weights_dtype,
        log_sigma2_value, check_same=True)

  def assertNonDeterministic(
      self,
      test_op,
      data_shape,
      weights_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0.0):
    """Runs the inherited `_determinism_helper` with `check_same=False`.

    Presumably this asserts that repeated evaluations of `test_op` differ;
    the actual check lives in `test_utils`.
    """
    self._determinism_helper(
        test_op, data_shape, weights_shape, data_dtype, weights_dtype,
        log_sigma2_value, check_same=False)
100

101

102
class RNNTestCase(test_utils.RNNTestCase):
  """Base class for all variational dropout recurrent cell tests.

  Each assertion helper takes a `log_sigma2_value` argument and forwards it
  to the same-named method of `test_utils.RNNTestCase` as `variance_value`.
  """

  def set_no_epsilon(self, test_op):
    """Returns `test_op` with its `eps` keyword argument fixed to 0.0.

    Args:
      test_op: callable that accepts an `eps` keyword argument.

    Returns:
      A `functools.partial` of `test_op` with `eps=0.0` bound.
    """
    return functools.partial(test_op, eps=0.0)

  def get_data_and_weights_and_biases(
      self,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype,
      weights_dtype,
      variance_value=0.0):
    """Builds constant inputs, a `(theta, log_sigma2)` pair, and biases.

    Args:
      data_shape: shape of the input tensor.
      weights_shape: shape shared by `theta` and `log_sigma2`.
      biases_shape: shape of the bias tensor.
      data_dtype: dtype of the input tensor.
      weights_dtype: dtype of `theta`, `log_sigma2` and the biases.
      variance_value: scalar that every element of `log_sigma2` is set to.

    Returns:
      `(x, (theta, log_sigma2), biases)`: `x` is filled with 0.1, `theta`
      and `biases` with 1.0, and `log_sigma2` with `variance_value`.
    """
    x = tf.constant(0.1, data_dtype, data_shape)
    theta = tf.constant(1.0, weights_dtype, weights_shape)
    log_sigma2 = tf.constant(variance_value, weights_dtype, weights_shape)
    # The biases intentionally share the weights' dtype.
    biases = tf.constant(1.0, weights_dtype, biases_shape)
    return x, (theta, log_sigma2), biases

  def assertSameNoiseForAllTimesteps(
      self,
      test_cell,
      num_units,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0.0):
    """Delegates to the parent check; `log_sigma2_value` -> `variance_value`.

    The default is a float (0.0) to match `get_data_and_weights_and_biases`
    and the non-recurrent `TestCase` helpers.
    """
    super(RNNTestCase, self).assertSameNoiseForAllTimesteps(
        test_cell,
        num_units,
        data_shape,
        weights_shape,
        biases_shape,
        data_dtype,
        weights_dtype,
        variance_value=log_sigma2_value)

  def assertDifferentNoiseAcrossBatches(
      self,
      test_cell,
      num_units,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0.0):
    """Delegates to the parent check; `log_sigma2_value` -> `variance_value`."""
    super(RNNTestCase, self).assertDifferentNoiseAcrossBatches(
        test_cell,
        num_units,
        data_shape,
        weights_shape,
        biases_shape,
        data_dtype,
        weights_dtype,
        variance_value=log_sigma2_value)

  def assertDeterministic(
      self,
      test_cell,
      num_units,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0.0):
    """Delegates to the parent check; `log_sigma2_value` -> `variance_value`."""
    super(RNNTestCase, self).assertDeterministic(
        test_cell,
        num_units,
        data_shape,
        weights_shape,
        biases_shape,
        data_dtype,
        weights_dtype,
        variance_value=log_sigma2_value)

  def assertNonDeterministic(
      self,
      test_cell,
      num_units,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0.0):
    """Delegates to the parent check; `log_sigma2_value` -> `variance_value`."""
    super(RNNTestCase, self).assertNonDeterministic(
        test_cell,
        num_units,
        data_shape,
        weights_shape,
        biases_shape,
        data_dtype,
        weights_dtype,
        variance_value=log_sigma2_value)

  def assertSameResult(
      self,
      test_cell,
      ref_cell,
      num_units,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=-10e6):
    """Delegates to the parent `assertSameResult` for two cells.

    The default `log_sigma2_value` is -10e6 (i.e. -1e7), a hugely negative
    log-variance. Presumably it is chosen so the injected noise is negligible
    and `test_cell` can match `ref_cell`; confirm against `test_utils`.
    """
    super(RNNTestCase, self).assertSameResult(
        test_cell,
        ref_cell,
        num_units,
        data_shape,
        weights_shape,
        biases_shape,
        data_dtype,
        weights_dtype,
        variance_value=log_sigma2_value)
223

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.