google-research
222 строки · 5.5 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Base class for variational dropout tests."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23import tensorflow.compat.v1 as tf
24
25from state_of_sparsity.layers.utils import test_utils
26
27
class TestCase(test_utils.TestCase):
  """Base class for all variational dropout tests.

  The assertion wrappers below delegate to helpers that this class does not
  define (presumably inherited from `test_utils.TestCase` — confirm). They
  supply `log_sigma2_value` as the value that, via `get_data_and_weights`, is
  presumably used to fill the `log_sigma2` tensor.
  """

  def set_no_epsilon(self, test_op):
    """Returns `test_op` with its `eps` keyword argument pinned to 0.0.

    Args:
      test_op: Callable accepting an `eps` keyword argument.

    Returns:
      A `functools.partial` wrapping `test_op` with `eps=0.0`.
    """
    op_without_eps = functools.partial(test_op, eps=0.0)
    return op_without_eps

  def get_data_and_weights(
      self,
      data_shape,
      weights_shape,
      data_dtype,
      weights_dtype,
      variance_value=0.0):
    """Creates constant inputs and variational dropout weight tensors.

    Args:
      data_shape: Shape of the all-ones input tensor.
      weights_shape: Shape shared by `theta` and `log_sigma2`.
      data_dtype: dtype of the input tensor.
      weights_dtype: dtype of `theta` and `log_sigma2`.
      variance_value: Scalar used to fill every entry of `log_sigma2`.

    Returns:
      A pair `(x, (theta, log_sigma2))`, where `x` and `theta` are all ones
      and `log_sigma2` is filled with `variance_value`.
    """
    # Creation order (inputs, theta, log_sigma2) matches the original so the
    # graph ops are created in the same sequence.
    inputs = tf.ones(data_shape, dtype=data_dtype)
    weight_means = tf.ones(weights_shape, dtype=weights_dtype)
    weight_log_variances = tf.constant(
        variance_value, dtype=weights_dtype, shape=weights_shape)
    weights = (weight_means, weight_log_variances)
    return inputs, weights

  def assertSameResult(
      self,
      test_op,
      ref_op,
      data_shape,
      weights_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=-10e6):
    """Forwards to the parent `assertSameResult` as `variance_value`.

    NOTE(review): the strongly negative default for `log_sigma2_value`
    presumably makes the injected noise negligible, so the op can match a
    deterministic reference — confirm against the parent helper.

    Args:
      test_op: Op under test.
      ref_op: Reference op to compare against.
      data_shape: Shape of the input data.
      weights_shape: Shape of the weight tensors.
      data_dtype: dtype of the input data.
      weights_dtype: dtype of the weight tensors.
      log_sigma2_value: Passed to the parent as `variance_value`.
    """
    parent = super(TestCase, self)
    shapes_and_dtypes = (data_shape, weights_shape, data_dtype, weights_dtype)
    parent.assertSameResult(
        test_op, ref_op, *shapes_and_dtypes,
        variance_value=log_sigma2_value)

  def assertDeterministic(
      self,
      test_op,
      data_shape,
      weights_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0.0):
    """Calls `self._determinism_helper` with `check_same=True`.

    `_determinism_helper` is not defined in this class; presumably it is
    inherited from `test_utils.TestCase` — confirm its exact semantics there.
    """
    helper_args = (test_op, data_shape, weights_shape, data_dtype,
                   weights_dtype, log_sigma2_value)
    self._determinism_helper(*helper_args, check_same=True)

  def assertNonDeterministic(
      self,
      test_op,
      data_shape,
      weights_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0.0):
    """Calls `self._determinism_helper` with `check_same=False`.

    `_determinism_helper` is not defined in this class; presumably it is
    inherited from `test_utils.TestCase` — confirm its exact semantics there.
    """
    helper_args = (test_op, data_shape, weights_shape, data_dtype,
                   weights_dtype, log_sigma2_value)
    self._determinism_helper(*helper_args, check_same=False)
100
101
class RNNTestCase(test_utils.RNNTestCase):
  """Base class for all variational dropout recurrent cell tests.

  Every assertion wrapper forwards its arguments positionally to the method
  of the same name on `test_utils.RNNTestCase`. It passes
  `log_sigma2_value` as the parent's `variance_value` keyword.
  """

  def set_no_epsilon(self, test_op):
    """Returns `test_op` with its `eps` keyword argument pinned to 0.0.

    Args:
      test_op: Callable accepting an `eps` keyword argument.

    Returns:
      A `functools.partial` wrapping `test_op` with `eps=0.0`.
    """
    op_without_eps = functools.partial(test_op, eps=0.0)
    return op_without_eps

  def get_data_and_weights_and_biases(
      self,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype,
      weights_dtype,
      variance_value=0.0):
    """Creates constant inputs, weight tensors and biases for a cell test.

    Args:
      data_shape: Shape of the input tensor (filled with 0.1).
      weights_shape: Shape shared by `theta` and `log_sigma2`.
      biases_shape: Shape of the bias tensor (filled with 1.0).
      data_dtype: dtype of the input tensor.
      weights_dtype: dtype of `theta`, `log_sigma2` and the biases.
      variance_value: Scalar used to fill every entry of `log_sigma2`.

    Returns:
      A triple `(x, (theta, log_sigma2), biases)`.
    """
    # Creation order matches the original so graph ops appear in the same
    # sequence: inputs, theta, log_sigma2, biases.
    inputs = tf.constant(0.1, dtype=data_dtype, shape=data_shape)
    weight_means = tf.constant(1.0, dtype=weights_dtype, shape=weights_shape)
    weight_log_variances = tf.constant(
        variance_value, dtype=weights_dtype, shape=weights_shape)
    # Biases are built with `weights_dtype`; there is no separate bias dtype.
    biases = tf.constant(1.0, dtype=weights_dtype, shape=biases_shape)
    return inputs, (weight_means, weight_log_variances), biases

  def assertSameNoiseForAllTimesteps(
      self,
      test_cell,
      num_units,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0):
    """Forwards to `test_utils.RNNTestCase.assertSameNoiseForAllTimesteps`."""
    cell_args = (test_cell, num_units, data_shape, weights_shape,
                 biases_shape, data_dtype, weights_dtype)
    super(RNNTestCase, self).assertSameNoiseForAllTimesteps(
        *cell_args, variance_value=log_sigma2_value)

  def assertDifferentNoiseAcrossBatches(
      self,
      test_cell,
      num_units,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0):
    """Forwards to `test_utils.RNNTestCase.assertDifferentNoiseAcrossBatches`."""
    cell_args = (test_cell, num_units, data_shape, weights_shape,
                 biases_shape, data_dtype, weights_dtype)
    super(RNNTestCase, self).assertDifferentNoiseAcrossBatches(
        *cell_args, variance_value=log_sigma2_value)

  def assertDeterministic(
      self,
      test_cell,
      num_units,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0):
    """Forwards to `test_utils.RNNTestCase.assertDeterministic`."""
    cell_args = (test_cell, num_units, data_shape, weights_shape,
                 biases_shape, data_dtype, weights_dtype)
    super(RNNTestCase, self).assertDeterministic(
        *cell_args, variance_value=log_sigma2_value)

  def assertNonDeterministic(
      self,
      test_cell,
      num_units,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=0):
    """Forwards to `test_utils.RNNTestCase.assertNonDeterministic`."""
    cell_args = (test_cell, num_units, data_shape, weights_shape,
                 biases_shape, data_dtype, weights_dtype)
    super(RNNTestCase, self).assertNonDeterministic(
        *cell_args, variance_value=log_sigma2_value)

  def assertSameResult(
      self,
      test_cell,
      ref_cell,
      num_units,
      data_shape,
      weights_shape,
      biases_shape,
      data_dtype=tf.float32,
      weights_dtype=tf.float32,
      log_sigma2_value=-10e6):
    """Forwards to `test_utils.RNNTestCase.assertSameResult`.

    NOTE(review): the strongly negative default for `log_sigma2_value`
    presumably makes the injected noise negligible, so `test_cell` can match
    `ref_cell` — confirm against the parent helper.
    """
    cell_args = (test_cell, ref_cell, num_units, data_shape, weights_shape,
                 biases_shape, data_dtype, weights_dtype)
    super(RNNTestCase, self).assertSameResult(
        *cell_args, variance_value=log_sigma2_value)
223