google-research

Форк
0
259 строк · 7.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for l0 layers."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import functools
22

23
import absl.testing.parameterized as parameterized
24
import numpy as np
25
import tensorflow.compat.v1 as tf
26

27
import state_of_sparsity.layers.l0_regularization as l0
28

29

30
# Parameters to test the matmul primitive on, as (m, n, k) tuples: the
# first operand has shape [m, n] and the second operand has shape [n, k].
MATMUL_TEST_PARAMETERS = [
    (32, 200, 100),
]
34

35

36
@parameterized.parameters(MATMUL_TEST_PARAMETERS)
class MatmulTest(l0.test_base.TestCase):
  """Tests for the l0 matmul primitives.

  Every test is parameterized by `(m, n, k)`: the first operand has shape
  `[m, n]` and the second operand has shape `[n, k]`.
  """

  def testMatmulTrain(self, m, n, k):
    """Train-mode matmul should produce the same result as `tf.matmul`."""
    lhs_shape, rhs_shape = [m, n], [n, k]
    self.assertSameResult(l0.nn.matmul_train, tf.matmul, lhs_shape, rhs_shape)

  def testMatmulTrain_NonDeterministic(self, m, n, k):
    """Train-mode matmul should vary between evaluations."""
    lhs_shape, rhs_shape = [m, n], [n, k]
    self.assertNonDeterministic(l0.nn.matmul_train, lhs_shape, rhs_shape)

  def testMatmulEval(self, m, n, k):
    """Eval-mode matmul should produce the same result as `tf.matmul`."""
    lhs_shape, rhs_shape = [m, n], [n, k]
    self.assertSameResult(l0.nn.matmul_eval, tf.matmul, lhs_shape, rhs_shape)

  def testMatmulEval_Deterministic(self, m, n, k):
    """Eval-mode matmul should be deterministic across evaluations."""
    lhs_shape, rhs_shape = [m, n], [n, k]
    self.assertDeterministic(l0.nn.matmul_eval, lhs_shape, rhs_shape)
64

65

66
# Parameters to test the batched (broadcast) matmul primitive on, as
# (m, t, n, k) tuples: the first operand has shape [m, t, n] and the
# second operand has shape [n, k].
BROADCAST_MATMUL_TEST_PARAMETERS = [
    (32, 20, 200, 100),
    (1, 10, 100, 50),
]
72

73

74
@parameterized.parameters(BROADCAST_MATMUL_TEST_PARAMETERS)
class BroadcastMatmulTest(l0.test_base.TestCase):
  """Tests for the l0 broadcast (batched) matmul primitives.

  Every test is parameterized by `(m, t, n, k)`: the first operand has shape
  `[m, t, n]` and the second operand has shape `[n, k]`. The reference op is
  `tf.tensordot` contracting axis 2 of the first operand with axis 0 of the
  second.
  """

  def set_axes(self, ref_op):
    """Returns `ref_op` with `axes=[[2], [0]]` bound as a keyword argument."""
    contraction_axes = [[2], [0]]
    return functools.partial(ref_op, axes=contraction_axes)

  def testBroadcastMatmulTrain(self, m, t, n, k):
    """Train-mode broadcast matmul should match `tf.tensordot`."""
    reference_op = self.set_axes(tf.tensordot)
    self.assertSameResult(
        l0.nn.broadcast_matmul_train, reference_op, [m, t, n], [n, k])

  def testBroadcastMatmulTrain_NonDeterministic(self, m, t, n, k):
    """Train-mode broadcast matmul should vary between evaluations."""
    self.assertNonDeterministic(
        l0.nn.broadcast_matmul_train, [m, t, n], [n, k])

  def testBroadcastMatmulEval(self, m, t, n, k):
    """Eval-mode broadcast matmul should match `tf.tensordot`."""
    reference_op = self.set_axes(tf.tensordot)
    self.assertSameResult(
        l0.nn.broadcast_matmul_eval, reference_op, [m, t, n], [n, k])

  def testBroadcastMatmulEval_Deterministic(self, m, t, n, k):
    """Eval-mode broadcast matmul should be deterministic."""
    self.assertDeterministic(
        l0.nn.broadcast_matmul_eval, [m, t, n], [n, k])
105

106
# Parameters to test the conv2d primitive with, as tuples of
# (batch_size, in_channels, height, width, filter_size, out_channels).
# Filters are square: filter_size x filter_size.
CONV2D_TEST_PARAMETERS = [
    (32, 3, 224, 224, 3, 64),
]
110

111

112
@parameterized.parameters(CONV2D_TEST_PARAMETERS)
class Conv2dTest(l0.test_base.TestCase):
  """Tests for the l0 conv2d primitives.

  Every test is parameterized by `(batch_size, in_channels, height, width,
  filter_size, out_channels)`. The input has shape
  `[batch_size, height, width, in_channels]` and the filter has shape
  `[filter_size, filter_size, in_channels, out_channels]`. Ops under test are
  wrapped with the inherited `fix_padding_and_strides` helper before use.
  """

  def testConv2dTrain(
      self, batch_size, in_channels, height, width, filter_size,
      out_channels):
    """Train-mode conv2d should produce the same result as `tf.nn.conv2d`."""
    input_shape = [batch_size, height, width, in_channels]
    filter_shape = [filter_size, filter_size, in_channels, out_channels]
    candidate_op = self.fix_padding_and_strides(l0.nn.conv2d_train)
    reference_op = self.fix_padding_and_strides(tf.nn.conv2d)
    self.assertSameResult(
        candidate_op, reference_op, input_shape, filter_shape)

  def testConv2dTrain_NonDeterministic(
      self, batch_size, in_channels, height, width, filter_size,
      out_channels):
    """Train-mode conv2d should vary between evaluations."""
    input_shape = [batch_size, height, width, in_channels]
    filter_shape = [filter_size, filter_size, in_channels, out_channels]
    candidate_op = self.fix_padding_and_strides(l0.nn.conv2d_train)
    self.assertNonDeterministic(candidate_op, input_shape, filter_shape)

  def testConv2dEval(
      self, batch_size, in_channels, height, width, filter_size,
      out_channels):
    """Eval-mode conv2d should produce the same result as `tf.nn.conv2d`."""
    input_shape = [batch_size, height, width, in_channels]
    filter_shape = [filter_size, filter_size, in_channels, out_channels]
    candidate_op = self.fix_padding_and_strides(l0.nn.conv2d_eval)
    reference_op = self.fix_padding_and_strides(tf.nn.conv2d)
    self.assertSameResult(
        candidate_op, reference_op, input_shape, filter_shape)

  def testConv2dEval_Deterministic(
      self, batch_size, in_channels, height, width, filter_size,
      out_channels):
    """Eval-mode conv2d should be deterministic across evaluations."""
    input_shape = [batch_size, height, width, in_channels]
    filter_shape = [filter_size, filter_size, in_channels, out_channels]
    candidate_op = self.fix_padding_and_strides(l0.nn.conv2d_eval)
    self.assertDeterministic(candidate_op, input_shape, filter_shape)
168

169

170
# Parameters for the embedding lookup tests, as tuples of
# (batch_size, seq_length, vocab_size, embedding_size).
EMBEDDING_TEST_PARAMETERS = [
    (32, 25, 10000, 512),
]
173

174

175
@parameterized.parameters(EMBEDDING_TEST_PARAMETERS)
class EmbeddingLookupTest(l0.test_base.TestCase):
  """Tests for the l0 embedding lookup primitives.

  Every test is parameterized by
  `(batch_size, seq_length, vocab_size, embedding_size)`. The first operand
  has shape `[batch_size, seq_length, 1]` and is requested with
  `data_dtype=tf.int32`; the second has shape `[vocab_size, embedding_size]`.
  Ops are wrapped with the inherited `flip_input_wrapper` helper —
  presumably to swap argument order for `embedding_lookup(params, ids)`;
  TODO confirm against test_base.
  """

  def testEmbeddingLookupTrain(
      self, batch_size, seq_length, vocab_size, embedding_size):
    """Train-mode lookup should match `tf.nn.embedding_lookup`."""
    data_shape = [batch_size, seq_length, 1]
    weight_shape = [vocab_size, embedding_size]
    candidate_op = self.flip_input_wrapper(l0.nn.embedding_lookup_train)
    reference_op = self.flip_input_wrapper(tf.nn.embedding_lookup)
    self.assertSameResult(
        candidate_op, reference_op, data_shape, weight_shape,
        data_dtype=tf.int32)

  def testEmbeddingLookupTrain_NonDeterministic(
      self, batch_size, seq_length, vocab_size, embedding_size):
    """Train-mode lookup should vary between evaluations."""
    data_shape = [batch_size, seq_length, 1]
    weight_shape = [vocab_size, embedding_size]
    candidate_op = self.flip_input_wrapper(l0.nn.embedding_lookup_train)
    self.assertNonDeterministic(
        candidate_op, data_shape, weight_shape, data_dtype=tf.int32)

  def testEmbeddingLookupEval(
      self, batch_size, seq_length, vocab_size, embedding_size):
    """Eval-mode lookup should match `tf.nn.embedding_lookup`."""
    data_shape = [batch_size, seq_length, 1]
    weight_shape = [vocab_size, embedding_size]
    candidate_op = self.flip_input_wrapper(l0.nn.embedding_lookup_eval)
    reference_op = self.flip_input_wrapper(tf.nn.embedding_lookup)
    self.assertSameResult(
        candidate_op, reference_op, data_shape, weight_shape,
        data_dtype=tf.int32)

  def testEmbeddingLookupEval_Deterministic(
      self, batch_size, seq_length, vocab_size, embedding_size):
    """Eval-mode lookup should be deterministic across evaluations."""
    data_shape = [batch_size, seq_length, 1]
    weight_shape = [vocab_size, embedding_size]
    candidate_op = self.flip_input_wrapper(l0.nn.embedding_lookup_eval)
    self.assertDeterministic(
        candidate_op, data_shape, weight_shape, data_dtype=tf.int32)
227

228

229
# Parameters for the l0 norm test, as (d, k, beta, gamma, zeta) tuples:
# the regularization contribution is computed over a [d, k] matrix, and
# beta, gamma and zeta are passed straight through to `l0_norm`.
L0_NORM_TEST_PARAMETERS = [
    (256, 128, 2.0 / 3.0, -0.1, 1.1),
]
232

233

234
@parameterized.parameters(L0_NORM_TEST_PARAMETERS)
class TestL0Norm(l0.test_base.TestCase):
  """Checks `l0.nn.l0_norm` against a NumPy reference computation."""

  def testL0Norm(self, d, k, beta, gamma, zeta):
    """`l0_norm` should equal the summed sigmoid of the shifted log_alpha.

    Args:
      d: number of rows of the random `log_alpha` matrix.
      k: number of columns of the random `log_alpha` matrix.
      beta: the beta parameter forwarded to `l0_norm`.
      gamma: the gamma parameter forwarded to `l0_norm`.
      zeta: the zeta parameter forwarded to `l0_norm`.
    """
    # Seed before building the random tensor so the test is reproducible.
    self.fix_random_seeds()

    log_alpha_tensor = tf.random_normal([d, k], dtype=tf.float32)
    norm_tensor = l0.nn.l0_norm(log_alpha_tensor, beta, gamma, zeta)
    norm_value, log_alpha_value = self.evaluate(
        [norm_tensor, log_alpha_tensor])

    # l0_norm must reduce to a scalar.
    self.assertEqual(norm_value.shape, ())

    # NumPy reference: sum of sigmoid(log_alpha - beta * log(-gamma / zeta)).
    shifted = log_alpha_value - beta * np.log(-gamma / zeta)
    expected_norm = np.sum(1.0 / (1.0 + np.exp(-shifted)))
    self.assertAllClose(norm_value, expected_norm)
256

257

258
if __name__ == "__main__":
  # Hand off to TensorFlow's test runner to execute this module's tests.
  tf.test.main()
260

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.