google-research
259 lines · 7.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for l0 layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import absl.testing.parameterized as parameterized
import numpy as np
import tensorflow.compat.v1 as tf

import state_of_sparsity.layers.l0_regularization as l0
28
29
# Parameters to test the matmul primitive on. First dimension of the
# first matrix, second dimension of the first matrix/first dimension
# of the second matrix, second dimension of the second matrix.
MATMUL_TEST_PARAMETERS = [(32, 200, 100)]


@parameterized.parameters(MATMUL_TEST_PARAMETERS)
class MatmulTest(l0.test_base.TestCase):
  """Checks the l0 matmul primitives against the dense tf.matmul."""

  def testMatmulTrain(self, m, n, k):
    # Training-mode matmul must match the dense reference result.
    self.assertSameResult(l0.nn.matmul_train, tf.matmul, [m, n], [n, k])

  def testMatmulTrain_NonDeterministic(self, m, n, k):
    # Training mode samples stochastic gates, so repeated runs differ.
    self.assertNonDeterministic(l0.nn.matmul_train, [m, n], [n, k])

  def testMatmulEval(self, m, n, k):
    # Eval-mode matmul must match the dense reference result.
    self.assertSameResult(l0.nn.matmul_eval, tf.matmul, [m, n], [n, k])

  def testMatmulEval_Deterministic(self, m, n, k):
    # Eval mode uses the deterministic gate expectation.
    self.assertDeterministic(l0.nn.matmul_eval, [m, n], [n, k])
64
65
# Parameters to test the batched matmul primitive on. First dimension
# of the first matrix, second dimension of the first matrix, third
# dimension of the first matrix/first dimension of the second matrix,
# second dimension of the second matrix.
BROADCAST_MATMUL_TEST_PARAMETERS = [(32, 20, 200, 100),
                                    (1, 10, 100, 50)]


@parameterized.parameters(BROADCAST_MATMUL_TEST_PARAMETERS)
class BroadcastMatmulTest(l0.test_base.TestCase):
  """Checks the l0 broadcast matmul primitives against tf.tensordot."""

  def set_axes(self, ref_op):
    # Contract the last axis of the 3D input with the first axis of the
    # 2D weights so tensordot matches the broadcast matmul semantics.
    return functools.partial(ref_op, axes=[[2], [0]])

  def testBroadcastMatmulTrain(self, m, t, n, k):
    reference_op = self.set_axes(tf.tensordot)
    self.assertSameResult(
        l0.nn.broadcast_matmul_train, reference_op, [m, t, n], [n, k])

  def testBroadcastMatmulTrain_NonDeterministic(self, m, t, n, k):
    # Training mode samples stochastic gates, so repeated runs differ.
    self.assertNonDeterministic(
        l0.nn.broadcast_matmul_train, [m, t, n], [n, k])

  def testBroadcastMatmulEval(self, m, t, n, k):
    reference_op = self.set_axes(tf.tensordot)
    self.assertSameResult(
        l0.nn.broadcast_matmul_eval, reference_op, [m, t, n], [n, k])

  def testBroadcastMatmulEval_Deterministic(self, m, t, n, k):
    # Eval mode uses the deterministic gate expectation.
    self.assertDeterministic(
        l0.nn.broadcast_matmul_eval, [m, t, n], [n, k])
105
# Parameters to test the conv2d primitive with. Input tensor batch size,
# input channels, input height, input width, size of the convolutional
# filters, number of output channels.
CONV2D_TEST_PARAMETERS = [(32, 3, 224, 224, 3, 64)]


@parameterized.parameters(CONV2D_TEST_PARAMETERS)
class Conv2dTest(l0.test_base.TestCase):
  """Checks the l0 conv2d primitives against the dense tf.nn.conv2d."""

  def _shapes(self, batch_size, in_channels, height, width, filter_size,
              out_channels):
    # Builds the NHWC input shape and HWIO filter shape for one test case.
    input_shape = [batch_size, height, width, in_channels]
    filter_shape = [filter_size, filter_size, in_channels, out_channels]
    return input_shape, filter_shape

  def testConv2dTrain(self, batch_size, in_channels, height, width,
                      filter_size, out_channels):
    input_shape, filter_shape = self._shapes(
        batch_size, in_channels, height, width, filter_size, out_channels)
    self.assertSameResult(
        self.fix_padding_and_strides(l0.nn.conv2d_train),
        self.fix_padding_and_strides(tf.nn.conv2d),
        input_shape,
        filter_shape)

  def testConv2dTrain_NonDeterministic(self, batch_size, in_channels, height,
                                       width, filter_size, out_channels):
    # Training mode samples stochastic gates, so repeated runs differ.
    input_shape, filter_shape = self._shapes(
        batch_size, in_channels, height, width, filter_size, out_channels)
    self.assertNonDeterministic(
        self.fix_padding_and_strides(l0.nn.conv2d_train),
        input_shape,
        filter_shape)

  def testConv2dEval(self, batch_size, in_channels, height, width,
                     filter_size, out_channels):
    input_shape, filter_shape = self._shapes(
        batch_size, in_channels, height, width, filter_size, out_channels)
    self.assertSameResult(
        self.fix_padding_and_strides(l0.nn.conv2d_eval),
        self.fix_padding_and_strides(tf.nn.conv2d),
        input_shape,
        filter_shape)

  def testConv2dEval_Deterministic(self, batch_size, in_channels, height,
                                   width, filter_size, out_channels):
    # Eval mode uses the deterministic gate expectation.
    input_shape, filter_shape = self._shapes(
        batch_size, in_channels, height, width, filter_size, out_channels)
    self.assertDeterministic(
        self.fix_padding_and_strides(l0.nn.conv2d_eval),
        input_shape,
        filter_shape)
168
169
# Parameters for the embedding lookup tests. Batch size, sequence length,
# vocabulary size, embedding vector size.
EMBEDDING_TEST_PARAMETERS = [(32, 25, 10000, 512)]


@parameterized.parameters(EMBEDDING_TEST_PARAMETERS)
class EmbeddingLookupTest(l0.test_base.TestCase):
  """Checks the l0 embedding lookups against tf.nn.embedding_lookup."""

  def testEmbeddingLookupTrain(self, batch_size, seq_length, vocab_size,
                               embedding_size):
    self.assertSameResult(
        self.flip_input_wrapper(l0.nn.embedding_lookup_train),
        self.flip_input_wrapper(tf.nn.embedding_lookup),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)

  def testEmbeddingLookupTrain_NonDeterministic(self, batch_size, seq_length,
                                                vocab_size, embedding_size):
    # Training mode samples stochastic gates, so repeated runs differ.
    self.assertNonDeterministic(
        self.flip_input_wrapper(l0.nn.embedding_lookup_train),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)

  def testEmbeddingLookupEval(self, batch_size, seq_length, vocab_size,
                              embedding_size):
    self.assertSameResult(
        self.flip_input_wrapper(l0.nn.embedding_lookup_eval),
        self.flip_input_wrapper(tf.nn.embedding_lookup),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)

  def testEmbeddingLookupEval_Deterministic(self, batch_size, seq_length,
                                            vocab_size, embedding_size):
    # Eval mode uses the deterministic gate expectation.
    self.assertDeterministic(
        self.flip_input_wrapper(l0.nn.embedding_lookup_eval),
        [batch_size, seq_length, 1],
        [vocab_size, embedding_size],
        data_dtype=tf.int32)
227
228
# Dimensions to calculate the regularization contribution over, and
# the beta, gamma, and zeta parameters.
L0_NORM_TEST_PARAMETERS = [(256, 128, 2.0 / 3.0, -0.1, 1.1)]


@parameterized.parameters(L0_NORM_TEST_PARAMETERS)
class TestL0Norm(l0.test_base.TestCase):
  """Checks the l0_norm regularizer against a NumPy reference."""

  def testL0Norm(self, d, k, beta, gamma, zeta):
    self.fix_random_seeds()

    log_alpha = tf.random_normal([d, k], dtype=tf.float32)
    output = l0.nn.l0_norm(log_alpha, beta, gamma, zeta)
    result, log_alpha = self.evaluate([output, log_alpha])

    # The regularizer reduces over all weights to a single scalar.
    self.assertEqual(result.shape, ())

    # NumPy reference: sum over sigmoid(log_alpha - beta * log(-gamma / zeta)),
    # the expected number of non-zero gates under the hard-concrete
    # distribution.
    shifted = log_alpha - beta * np.log(-gamma / zeta)
    expected_result = np.sum(1.0 / (1.0 + np.exp(-shifted)))
    self.assertAllClose(result, expected_result)
256
257
# Runs all test classes in this module under the TensorFlow test runner.
if __name__ == "__main__":
  tf.test.main()
260