# Source: google-research repository (scraper banner converted to comment).
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for supcon.blocks."""
17
import tensorflow.compat.v1 as tf

from supcon import blocks

# This is the equivalent of tf.tpu.bfloat16_scope but it can run on CPU where
# bfloat16 isn't supported.
def custom_float16_getter(getter, *args, **kwargs):
  """Variable-scope custom getter that stores fp16 variables as fp32.

  Variables requested with dtype tf.float16 are created as tf.float32 (so no
  precision is lost in the stored value) and the returned tensor is cast back
  to tf.float16 for use in the graph. Variables requested with any other dtype
  pass through unchanged.

  Args:
    getter: The underlying variable getter supplied by `tf.variable_scope`.
    *args: Positional arguments forwarded to `getter`.
    **kwargs: Keyword arguments forwarded to `getter`. Must contain 'dtype'.

  Returns:
    The created variable, cast to tf.float16 if that dtype was requested,
    otherwise the variable itself.
  """
  cast_to_float16 = False
  requested_dtype = kwargs['dtype']
  if requested_dtype == tf.float16:
    # Only change the variable dtype if doing so does not decrease variable
    # precision.
    kwargs['dtype'] = tf.float32
    cast_to_float16 = True
  var = getter(*args, **kwargs)
  # This if statement is needed to guard the cast, because batch norm
  # assigns directly to the return value of this custom getter. The cast
  # makes the return value not a variable so it cannot be assigned. Batch
  # norm variables are always in fp32 so this if statement is never
  # triggered for them.
  if cast_to_float16:
    var = tf.cast(var, tf.float16)
  return var

class BlocksTest(tf.test.TestCase):
  """Smoke tests for the building blocks in supcon.blocks.

  Each test builds a block, runs a forward pass, and checks that trainable
  variables were created, gradients flow back to the input, and the output has
  the expected shape. Variants cover float16 (via custom_float16_getter, since
  bfloat16 is unavailable on CPU) and 'channels_first' data format. Blocks
  that contain batch norm additionally check that update ops were collected.
  """

  def test_padded_conv_can_be_called(self):
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float32)
    block = blocks.Conv2DFixedPadding(
        filters=4, kernel_size=3, strides=2, data_format='channels_last')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertListEqual([2, 8, 8, 4], outputs.shape.as_list())

  def test_padded_conv_can_be_called_float16(self):
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.Conv2DFixedPadding(
          filters=4, kernel_size=3, strides=2, data_format='channels_last')
      outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertListEqual([2, 8, 8, 4], outputs.shape.as_list())

  def test_padded_conv_can_be_called_channels_first(self):
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    block = blocks.Conv2DFixedPadding(
        filters=4, kernel_size=3, strides=2, data_format='channels_first')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertListEqual([2, 4, 8, 8], outputs.shape.as_list())

  def test_group_conv2d_can_be_called(self):
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float32)
    # NOTE(review): the float16 scope appears unnecessary for fp32 inputs but
    # is harmless (the getter passes fp32 requests through unchanged); kept to
    # preserve the original test's behavior.
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.GroupConv2D(
          filters=4,
          kernel_size=3,
          strides=2,
          data_format='channels_last',
          groups=2)
      outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertListEqual([2, 8, 8, 4], outputs.shape.as_list())

  def test_group_conv2d_can_be_called_float16(self):
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.GroupConv2D(
          filters=4,
          kernel_size=3,
          strides=2,
          data_format='channels_last',
          groups=2)
      outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertListEqual([2, 8, 8, 4], outputs.shape.as_list())

  def test_group_conv2d_can_be_called_channels_first(self):
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.GroupConv2D(
          filters=4,
          kernel_size=3,
          strides=2,
          data_format='channels_first',
          groups=2)
      outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertListEqual([2, 4, 8, 8], outputs.shape.as_list())

  def test_bottleneck_block_can_be_called(self):
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float32)
    block = blocks.BottleneckResidualBlock(
        filters=3,
        strides=2,
        use_projection=True,
        data_format='channels_last')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    # Batch norm inside the block should register moving-average update ops.
    self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    # Bottleneck output channels = 4 * filters = 12.
    self.assertListEqual([2, 8, 8, 12], outputs.shape.as_list())

  def test_bottleneck_block_can_be_called_float16(self):
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.BottleneckResidualBlock(
          filters=3,
          strides=2,
          use_projection=True,
          data_format='channels_last')
      outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    self.assertListEqual([2, 8, 8, 12], outputs.shape.as_list())

  def test_bottleneck_block_can_be_called_channels_first(self):
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    block = blocks.BottleneckResidualBlock(
        filters=3,
        strides=2,
        use_projection=True,
        data_format='channels_first')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    self.assertListEqual([2, 12, 8, 8], outputs.shape.as_list())

  def test_residual_block_can_be_called(self):
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float32)
    block = blocks.ResidualBlock(
        filters=3,
        strides=2,
        use_projection=True,
        data_format='channels_last')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    self.assertListEqual([2, 8, 8, 3], outputs.shape.as_list())

  def test_residual_block_can_be_called_float16(self):
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.ResidualBlock(
          filters=3,
          strides=2,
          use_projection=True,
          data_format='channels_last')
      outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    self.assertListEqual([2, 8, 8, 3], outputs.shape.as_list())

  def test_residual_block_can_be_called_channels_first(self):
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    block = blocks.ResidualBlock(
        filters=3,
        strides=2,
        use_projection=True,
        data_format='channels_first')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    self.assertListEqual([2, 3, 8, 8], outputs.shape.as_list())

  def test_resnext_block_can_be_called(self):
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float32)
    block = blocks.ResNextBlock(
        filters=64,
        strides=2,
        use_projection=True,
        data_format='channels_last')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    # ResNext output channels = 4 * filters = 256.
    self.assertListEqual([2, 8, 8, 256], outputs.shape.as_list())

  def test_resnext_block_can_be_called_float16(self):
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    with tf.variable_scope('float16', custom_getter=custom_float16_getter):
      block = blocks.ResNextBlock(
          filters=64,
          strides=2,
          use_projection=True,
          data_format='channels_last')
      outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    self.assertListEqual([2, 8, 8, 256], outputs.shape.as_list())

  def test_resnext_block_can_be_called_channels_first(self):
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    block = blocks.ResNextBlock(
        filters=64,
        strides=2,
        use_projection=True,
        data_format='channels_first')
    outputs = block(inputs, training=True)
    grads = tf.gradients(outputs, inputs)
    self.assertTrue(tf.compat.v1.trainable_variables())
    self.assertTrue(grads)
    self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    self.assertListEqual([2, 256, 8, 8], outputs.shape.as_list())
if __name__ == '__main__':
  tf.test.main()