google-research

Форк
0
/
blocks_test.py 
253 строки · 9.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for supcon.blocks."""
17

18
import tensorflow.compat.v1 as tf
19

20
from supcon import blocks
21

22

23
def custom_float16_getter(getter, *args, **kwargs):
  """Variable getter that keeps float16 variables stored as float32.

  This is the equivalent of tf.tpu.bfloat16_scope but it can run on CPU where
  bfloat16 isn't supported.

  Args:
    getter: The underlying variable getter to delegate to.
    *args: Positional arguments forwarded unchanged to `getter`.
    **kwargs: Keyword arguments forwarded to `getter`. When `dtype` is
      `tf.float16` it is replaced by `tf.float32` before forwarding.

  Returns:
    Whatever `getter` returns, cast to `tf.float16` if (and only if) float16
    was the requested dtype.
  """
  # A missing `dtype` is treated like any non-float16 request: the call is
  # forwarded untouched.
  cast_to_float16 = kwargs.get('dtype') == tf.float16
  if cast_to_float16:
    # Only change the variable dtype if doing so does not decrease variable
    # precision.
    kwargs['dtype'] = tf.float32
  var = getter(*args, **kwargs)
  # This if statement is needed to guard the cast, because batch norm
  # assigns directly to the return value of this custom getter. The cast
  # makes the return value not a variable so it cannot be assigned. Batch
  # norm variables are always in fp32 so this if statement is never
  # triggered for them.
  if cast_to_float16:
    var = tf.cast(var, tf.float16)
  return var
42

43

44
class BlocksTest(tf.test.TestCase):
  """Smoke tests for the layers in `supcon.blocks`.

  Each test builds one block, calls it on random inputs with `training=True`,
  and checks that trainable variables exist, that the outputs have a gradient
  with respect to the inputs, and that the static output shape is the expected
  one. The residual-style blocks additionally require a non-empty `UPDATE_OPS`
  collection.
  """

  def _assert_block_outputs(self, outputs, grads, expected_shape,
                            expect_update_ops=False):
    """Runs the assertions shared by every test in this class.

    Args:
      outputs: Tensor returned by the block under test.
      grads: Result of `tf.gradients(outputs, inputs)`.
      expected_shape: List of ints, the expected static shape of `outputs`.
      expect_update_ops: If True, also require that the `UPDATE_OPS`
        collection is non-empty.
    """
    self.assertTrue(tf.trainable_variables())
    # `tf.gradients` returns one entry per input and uses None for an input
    # the outputs do not depend on, so a truthiness check on the list alone
    # would pass even when no gradient exists.
    self.assertTrue(grads)
    for grad in grads:
      self.assertIsNotNone(grad)
    if expect_update_ops:
      self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    self.assertListEqual(expected_shape, outputs.shape.as_list())

  def test_padded_conv_can_be_called(self):
    """Conv2DFixedPadding on float32 channels_last inputs."""
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float32)
    block = blocks.Conv2DFixedPadding(
        filters=4, kernel_size=3, strides=2, data_format='channels_last')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(outputs, grads, [2, 8, 8, 4])

  def test_padded_conv_can_be_called_float16(self):
    """Conv2DFixedPadding on float16 inputs under the float16 getter scope."""
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.Conv2DFixedPadding(
          filters=4, kernel_size=3, strides=2, data_format='channels_last')
      outputs = block(inputs, training=True)
      grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(outputs, grads, [2, 8, 8, 4])

  def test_padded_conv_can_be_called_channels_first(self):
    """Conv2DFixedPadding on float32 channels_first inputs."""
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    block = blocks.Conv2DFixedPadding(
        filters=4, kernel_size=3, strides=2, data_format='channels_first')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(outputs, grads, [2, 4, 8, 8])

  def test_group_conv2d_can_be_called(self):
    """GroupConv2D (groups=2) on float32 channels_last inputs."""
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float32)
    # NOTE(review): the float16 getter scope is kept here although the inputs
    # are float32; the getter only rewrites requests whose dtype is float16.
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.GroupConv2D(
          filters=4,
          kernel_size=3,
          strides=2,
          data_format='channels_last',
          groups=2)
      outputs = block(inputs, training=True)
      grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(outputs, grads, [2, 8, 8, 4])

  def test_group_conv2d_can_be_called_float16(self):
    """GroupConv2D (groups=2) on float16 inputs under the getter scope."""
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.GroupConv2D(
          filters=4,
          kernel_size=3,
          strides=2,
          data_format='channels_last',
          groups=2)
      outputs = block(inputs, training=True)
      grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(outputs, grads, [2, 8, 8, 4])

  def test_group_conv2d_can_be_called_channels_first(self):
    """GroupConv2D (groups=2) on float32 channels_first inputs."""
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    # NOTE(review): same float16 getter scope as the channels_last float32
    # test above; inputs here are float32.
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.GroupConv2D(
          filters=4,
          kernel_size=3,
          strides=2,
          data_format='channels_first',
          groups=2)
      outputs = block(inputs, training=True)
      grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(outputs, grads, [2, 4, 8, 8])

  def test_bottleneck_block_can_be_called(self):
    """BottleneckResidualBlock with projection, float32 channels_last."""
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float32)
    block = blocks.BottleneckResidualBlock(
        filters=3,
        strides=2,
        use_projection=True,
        data_format='channels_last')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(
        outputs, grads, [2, 8, 8, 12], expect_update_ops=True)

  def test_bottleneck_block_can_be_called_float16(self):
    """BottleneckResidualBlock with projection, float16 under getter scope."""
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.BottleneckResidualBlock(
          filters=3,
          strides=2,
          use_projection=True,
          data_format='channels_last')
      outputs = block(inputs, training=True)
      grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(
        outputs, grads, [2, 8, 8, 12], expect_update_ops=True)

  def test_bottleneck_block_can_be_called_channels_first(self):
    """BottleneckResidualBlock with projection, float32 channels_first."""
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    block = blocks.BottleneckResidualBlock(
        filters=3,
        strides=2,
        use_projection=True,
        data_format='channels_first')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(
        outputs, grads, [2, 12, 8, 8], expect_update_ops=True)

  def test_residual_block_can_be_called(self):
    """ResidualBlock with projection, float32 channels_last."""
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float32)
    block = blocks.ResidualBlock(
        filters=3,
        strides=2,
        use_projection=True,
        data_format='channels_last')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(
        outputs, grads, [2, 8, 8, 3], expect_update_ops=True)

  def test_residual_block_can_be_called_float16(self):
    """ResidualBlock with projection, float16 under the getter scope."""
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.ResidualBlock(
          filters=3,
          strides=2,
          use_projection=True,
          data_format='channels_last')
      outputs = block(inputs, training=True)
      grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(
        outputs, grads, [2, 8, 8, 3], expect_update_ops=True)

  def test_residual_block_can_be_called_channels_first(self):
    """ResidualBlock with projection, float32 channels_first."""
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    block = blocks.ResidualBlock(
        filters=3,
        strides=2,
        use_projection=True,
        data_format='channels_first')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(
        outputs, grads, [2, 3, 8, 8], expect_update_ops=True)

  def test_resnext_block_can_be_called(self):
    """ResNextBlock with projection, float32 channels_last."""
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float32)
    block = blocks.ResNextBlock(
        filters=64,
        strides=2,
        use_projection=True,
        data_format='channels_last')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(
        outputs, grads, [2, 8, 8, 256], expect_update_ops=True)

  def test_resnext_block_can_be_called_float16(self):
    """ResNextBlock with projection, float16 under the getter scope."""
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.ResNextBlock(
          filters=64,
          strides=2,
          use_projection=True,
          data_format='channels_last')
      outputs = block(inputs, training=True)
      grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(
        outputs, grads, [2, 8, 8, 256], expect_update_ops=True)

  def test_resnext_block_can_be_called_channels_first(self):
    """ResNextBlock with projection, float32 channels_first."""
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    block = blocks.ResNextBlock(
        filters=64,
        strides=2,
        use_projection=True,
        data_format='channels_first')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self._assert_block_outputs(
        outputs, grads, [2, 256, 8, 8], expect_update_ops=True)
250

251

252
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()
254

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.