google-research
625 строк · 23.2 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Contains definitions for Residual Networks.

Residual networks ('v1' ResNets) were originally proposed in:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385

The full pre-activation 'v2' ResNet variant was introduced by:
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027

The key difference between the full pre-activation 'v2' variant and the 'v1'
variant in [1] is that batch normalization is applied before every weight
layer rather than after it.
"""
30
31from __future__ import print_function32import tensorflow.compat.v1 as tf33from tensorflow.contrib import layers as contrib_layers34
# Hyperparameters shared by every batch_norm() call in this module.
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5

# Defaults picked up by Model when the caller does not override them.
DEFAULT_VERSION = 1
DEFAULT_DTYPE = tf.float32

# Low-precision dtypes whose variables are created in fp32 and cast on read
# (see Model._custom_dtype_getter), plus the full set Model accepts.
CASTABLE_TYPES = (tf.float16,)
ALLOWED_TYPES = (DEFAULT_DTYPE,) + CASTABLE_TYPES
42
def batch_norm(inputs, training, data_format):
  """Applies batch normalization with this module's fixed hyperparameters.

  Args:
    inputs: The tensor to normalize; the channel axis is 1 or 3, so it is
      presumably 4-D (confirm against callers).
    training: Whether to normalize with batch statistics (training) or with
      the accumulated moving averages (inference).
    data_format: 'channels_first' normalizes over axis 1; any other value
      normalizes over axis 3.

  Returns:
    The normalized tensor.
  """
  channel_axis = 3
  if data_format == 'channels_first':
    channel_axis = 1
  # fused=True selects the fused kernel, which is considerably faster. See
  # https://www.tensorflow.org/performance/performance_guide#common_fused_ops
  return tf.compat.v1.layers.batch_normalization(
      inputs,
      axis=channel_axis,
      momentum=_BATCH_NORM_DECAY,
      epsilon=_BATCH_NORM_EPSILON,
      center=True,
      scale=True,
      training=training,
      fused=True)
57
def fixed_padding(inputs, kernel_size, data_format):
  """Zero-pads the spatial dimensions by an amount set only by `kernel_size`.

  Args:
    inputs: A tensor of shape [batch, channels, height_in, width_in] or
      [batch, height_in, width_in, channels], depending on `data_format`.
    kernel_size: Kernel size of the conv2d or max_pool2d op that will consume
      the result. Should be a positive integer.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    A tensor in the same layout as `inputs`. Its values are unchanged when
    kernel_size == 1. Otherwise each spatial dimension grows by
    kernel_size - 1, and when that total is odd the extra row/column goes
    after (bottom/right) rather than before.
  """
  total = kernel_size - 1
  before = total // 2
  after = total - before
  spatial = [before, after]
  untouched = [0, 0]
  if data_format == 'channels_first':
    paddings = [untouched, untouched, spatial, spatial]
  else:
    paddings = [untouched, spatial, spatial, untouched]
  return tf.pad(tensor=inputs, paddings=paddings)
86
def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format,
                         weight_decay=0.0002):
  """Strided 2-D convolution with explicit, input-size-independent padding.

  When strides == 1 the convolution uses 'SAME' padding. When strides > 1 the
  input is first padded by `fixed_padding` and the convolution uses 'VALID'
  padding. The amount of padding therefore depends only on `kernel_size`,
  not on the spatial dimensions of `inputs` (as it would with
  `tf.layers.conv2d` and 'SAME' alone).

  Args:
    inputs: The input tensor, laid out according to `data_format`.
    filters: Number of output channels.
    kernel_size: Kernel size. Must be an integer when `strides` > 1, because
      `fixed_padding` does arithmetic on it.
    strides: Convolution stride.
    data_format: 'channels_first' or 'channels_last'.
    weight_decay: Scale of the L2 penalty attached to the kernel via
      `kernel_regularizer`. Defaults to 0.0002, the value that was previously
      hard-coded, so existing callers are unaffected.

  Returns:
    The convolution output. No bias term is added (use_bias=False).
  """
  if strides > 1:
    inputs = fixed_padding(inputs, kernel_size, data_format)
  return tf.layers.conv2d(
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=('SAME' if strides == 1 else 'VALID'),
      use_bias=False,
      kernel_initializer=tf.compat.v1.variance_scaling_initializer(),
      kernel_regularizer=contrib_layers.l2_regularizer(scale=weight_decay),
      data_format=data_format)
105
def _building_block_v1(inputs, filters, training, projection_shortcut, strides,
                       data_format):
  """Residual block for ResNet v1 with two 3x3 convolutions (no bottleneck).

  Post-activation ordering: each convolution is followed by batch
  normalization. ReLU follows the first normalization and is applied again
  after the residual addition, as described in
  Deep Residual Learning for Image Recognition
  https://arxiv.org/pdf/1512.03385.pdf
  by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.

  Args:
    inputs: A tensor of shape [batch, channels, height_in, width_in] or
      [batch, height_in, width_in, channels], depending on `data_format`.
    filters: Number of filters in both convolutions.
    training: Whether batch normalization runs in training or inference mode.
    projection_shortcut: Optional callable that builds the shortcut branch
      from `inputs` (its result is batch-normalized). When None, the shortcut
      is the identity.
    strides: Stride of the first 3x3 convolution. Values above 1 downsample.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    The block output. Its spatial size shrinks when `strides` > 1, and its
    channel count is `filters`.
  """
  def conv_then_norm(tensor, stride):
    tensor = conv2d_fixed_padding(
        inputs=tensor,
        filters=filters,
        kernel_size=3,
        strides=stride,
        data_format=data_format)
    return batch_norm(tensor, training, data_format)

  if projection_shortcut is None:
    shortcut = inputs
  else:
    shortcut = batch_norm(
        inputs=projection_shortcut(inputs),
        training=training,
        data_format=data_format)

  residual = tf.nn.relu(conv_then_norm(inputs, strides))
  residual = conv_then_norm(residual, 1)
  return tf.nn.relu(residual + shortcut)
158
def _building_block_v2(inputs, filters, training, projection_shortcut, strides,
                       data_format):
  """Residual block for ResNet v2 with two 3x3 convolutions (no bottleneck).

  Pre-activation ordering (batch normalization, then ReLU, then convolution),
  as described in
  Identity Mappings in Deep Residual Networks
  https://arxiv.org/pdf/1603.05027.pdf
  by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
  Nothing is applied after the residual addition.

  Args:
    inputs: A tensor of shape [batch, channels, height_in, width_in] or
      [batch, height_in, width_in, channels], depending on `data_format`.
    filters: Number of filters in both convolutions.
    training: Whether batch normalization runs in training or inference mode.
    projection_shortcut: Optional callable that builds the shortcut branch.
      It receives the pre-activated tensor. When None, the raw `inputs` are
      the shortcut.
    strides: Stride of the first 3x3 convolution. Values above 1 downsample.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    The block output (residual + shortcut).
  """
  preactivated = tf.nn.relu(batch_norm(inputs, training, data_format))

  # The projection (typically a 1x1 convolution) is itself a weight layer, so
  # it consumes the normalized, activated tensor. The identity shortcut keeps
  # the raw input.
  if projection_shortcut is None:
    shortcut = inputs
  else:
    shortcut = projection_shortcut(preactivated)

  out = conv2d_fixed_padding(
      inputs=preactivated,
      filters=filters,
      kernel_size=3,
      strides=strides,
      data_format=data_format)
  out = tf.nn.relu(batch_norm(out, training, data_format))
  out = conv2d_fixed_padding(
      inputs=out,
      filters=filters,
      kernel_size=3,
      strides=1,
      data_format=data_format)
  return out + shortcut
210
def _bottleneck_block_v1(inputs, filters, training, projection_shortcut,
                         strides, data_format):
  """Bottleneck residual block for ResNet v1 (1x1 -> 3x3 -> 1x1).

  Post-activation ordering (convolution, then batch normalization, then ReLU)
  with the bottleneck design from
  Deep Residual Learning for Image Recognition
  https://arxiv.org/pdf/1512.03385.pdf
  by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
  The middle 3x3 convolution carries the stride. The final 1x1 convolution
  expands the output to 4 * filters channels. ReLU is applied once more
  after the residual addition.

  Args:
    inputs: A tensor of shape [batch, channels, height_in, width_in] or
      [batch, height_in, width_in, channels], depending on `data_format`.
    filters: Number of filters in the first two convolutions (the last one
      uses 4 * filters).
    training: Whether batch normalization runs in training or inference mode.
    projection_shortcut: Optional callable that builds the shortcut branch
      from `inputs` (its result is batch-normalized). When None, the shortcut
      is the identity.
    strides: Stride of the 3x3 convolution. Values above 1 downsample.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    The block output with 4 * filters channels.
  """
  if projection_shortcut is None:
    shortcut = inputs
  else:
    shortcut = batch_norm(
        inputs=projection_shortcut(inputs),
        training=training,
        data_format=data_format)

  # (output channels, kernel size, stride) for each conv + BN stage.
  stages = ((filters, 1, 1), (filters, 3, strides), (4 * filters, 1, 1))
  last_stage = len(stages) - 1
  net = inputs
  for index, (num_out, size, stride) in enumerate(stages):
    net = conv2d_fixed_padding(
        inputs=net,
        filters=num_out,
        kernel_size=size,
        strides=stride,
        data_format=data_format)
    net = batch_norm(net, training, data_format)
    if index != last_stage:
      net = tf.nn.relu(net)
  return tf.nn.relu(net + shortcut)
274
def _bottleneck_block_v2(inputs, filters, training, projection_shortcut,
                         strides, data_format):
  """Bottleneck residual block for ResNet v2 (1x1 -> 3x3 -> 1x1).

  Uses the bottleneck design from
  Deep Residual Learning for Image Recognition
  https://arxiv.org/pdf/1512.03385.pdf
  with the pre-activation ordering (batch normalization, then ReLU, then
  convolution) of
  Identity Mappings in Deep Residual Networks
  https://arxiv.org/pdf/1603.05027.pdf
  both by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.

  Args:
    inputs: A tensor of shape [batch, channels, height_in, width_in] or
      [batch, height_in, width_in, channels], depending on `data_format`.
    filters: Number of filters in the first two convolutions (the last one
      uses 4 * filters).
    training: Whether batch normalization runs in training or inference mode.
    projection_shortcut: Optional callable that builds the shortcut branch.
      It receives the pre-activated tensor. When None, the raw `inputs` are
      the shortcut.
    strides: Stride of the 3x3 convolution. Values above 1 downsample.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    The block output (residual + shortcut) with 4 * filters channels.
  """
  net = tf.nn.relu(batch_norm(inputs, training, data_format))

  # The projection is a weight layer, so it sees the pre-activated tensor.
  shortcut = inputs if projection_shortcut is None else projection_shortcut(
      net)

  # (output channels, kernel size, stride) for each BN/ReLU + conv stage. The
  # first stage reuses the pre-activation computed above.
  stages = ((filters, 1, 1), (filters, 3, strides), (4 * filters, 1, 1))
  for index, (num_out, size, stride) in enumerate(stages):
    if index > 0:
      net = tf.nn.relu(batch_norm(net, training, data_format))
    net = conv2d_fixed_padding(
        inputs=net,
        filters=num_out,
        kernel_size=size,
        strides=stride,
        data_format=data_format)
  return net + shortcut
343
def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides,
                training, name, data_format):
  """Stacks `blocks` residual blocks into one ResNet stage.

  The first block always runs. It uses a 1x1 projection shortcut and the
  stage's `strides`. Every subsequent block has stride 1 and an identity
  shortcut.

  Args:
    inputs: A tensor of shape [batch, channels, height_in, width_in] or
      [batch, height_in, width_in, channels], depending on `data_format`.
    filters: Filter count passed to every block.
    bottleneck: Whether `block_fn` is a bottleneck block (its output then has
      4 * filters channels, and the projection must match).
    block_fn: One of the building/bottleneck block functions in this module.
    blocks: Number of blocks in the stage.
    strides: Stride for the first block. Values above 1 downsample.
    training: Whether batch normalization runs in training or inference mode.
    name: Name given to the stage's output tensor via tf.identity.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    The stage output tensor, named `name`.
  """
  projected_channels = 4 * filters if bottleneck else filters

  def projection_shortcut(tensor):
    return conv2d_fixed_padding(
        inputs=tensor,
        filters=projected_channels,
        kernel_size=1,
        strides=strides,
        data_format=data_format)

  output = block_fn(inputs, filters, training, projection_shortcut, strides,
                    data_format)
  for _ in range(blocks - 1):
    output = block_fn(output, filters, training, None, 1, data_format)
  return tf.identity(output, name)
387
class Model(object):
  """Base class for building the ResNet model.

  `encoder` builds the convolutional trunk and returns globally pooled
  features. `confidence_model` is a small fully connected head applied to a
  batch of feature vectors.
  """

  def __init__(self,
               wd,
               resnet_size,
               bottleneck,
               num_classes,
               num_filters,
               kernel_size,
               conv_stride,
               first_pool_size,
               first_pool_stride,
               block_sizes,
               block_strides,
               feature_dim,
               resnet_version=DEFAULT_VERSION,
               data_format=None,
               dtype=DEFAULT_DTYPE,
               drop_rate=0.5):
    """Creates a model for classifying an image.

    Args:
      wd: L2 weight-decay coefficient for the dense layers of
        `confidence_model`. The convolutions built by `conv2d_fixed_padding`
        are regularized separately and do not use this value.
      resnet_size: A single integer for the size of the ResNet model. It is
        stored on the instance but not otherwise used by this class.
      bottleneck: Use regular blocks (False) or bottleneck blocks (True).
      num_classes: The number of classes used as labels. It is stored on the
        instance; `encoder` does not build a classification layer.
      num_filters: The number of filters for the first block layer. The count
        is doubled for each subsequent block layer.
      kernel_size: Kernel size of the initial convolution.
      conv_stride: Stride of the initial convolution.
      first_pool_size: Pool size for the initial max-pooling layer. If falsy
        (e.g. None), the pooling layer is skipped.
      first_pool_stride: Stride of the initial pooling layer. Unused when
        `first_pool_size` is falsy.
      block_sizes: A list of n values, where n is the number of block layers.
        Each value is the number of blocks in the i-th layer.
      block_strides: List of strides, one per block layer. Should have the
        same length as `block_sizes`.
      feature_dim: Width of both dense layers in `confidence_model`.
      resnet_version: Which ResNet variant to build. Valid values: [1, 2].
      data_format: 'channels_last', 'channels_first', or None. If None,
        'channels_first' is used when TensorFlow is built with CUDA,
        otherwise 'channels_last'.
      dtype: The TensorFlow dtype for calculations; one of ALLOWED_TYPES.
      drop_rate: Dropout rate applied before each dense layer of
        `confidence_model`. Defaults to 0.5, the previously hard-coded value.

    Raises:
      ValueError: if `resnet_version` is not 1 or 2, or if `dtype` is not in
        ALLOWED_TYPES.
    """
    self.resnet_size = resnet_size

    if not data_format:
      # NCHW is preferred when a CUDA build is available (faster on GPU).
      data_format = ('channels_first'
                     if tf.test.is_built_with_cuda() else 'channels_last')

    self.resnet_version = resnet_version
    if resnet_version not in (1, 2):
      raise ValueError(
          'Resnet version should be 1 or 2. See README for citations.')

    self.bottleneck = bottleneck
    if bottleneck:
      if resnet_version == 1:
        self.block_fn = _bottleneck_block_v1
      else:
        self.block_fn = _bottleneck_block_v2
    else:
      if resnet_version == 1:
        self.block_fn = _building_block_v1
      else:
        self.block_fn = _building_block_v2

    if dtype not in ALLOWED_TYPES:
      raise ValueError('dtype must be one of: {}'.format(ALLOWED_TYPES))

    self.data_format = data_format
    self.num_classes = num_classes
    self.num_filters = num_filters
    self.kernel_size = kernel_size
    self.conv_stride = conv_stride
    self.first_pool_size = first_pool_size
    self.first_pool_stride = first_pool_stride
    self.block_sizes = block_sizes
    self.block_strides = block_strides
    self.dtype = dtype
    self.pre_activation = resnet_version == 2
    # Shared by the dense layers of confidence_model.
    self.regularizer = contrib_layers.l2_regularizer(scale=wd)
    self.initializer = contrib_layers.xavier_initializer()
    self.drop_rate = drop_rate
    self.feature_dim = feature_dim

  def _custom_dtype_getter(self,
                           getter,
                           name,
                           shape=None,
                           dtype=DEFAULT_DTYPE,
                           *args,
                           **kwargs):
    """Creates variables in fp32, then casts to fp16 if necessary.

    This is a custom getter: a function with the same signature as
    tf.get_variable plus a leading `getter` parameter. It is passed as the
    `custom_getter` of tf.variable_scope, so tf.get_variable calls it instead
    of fetching the variable directly. `getter` is the underlying variable
    getter that would otherwise have been called.

    When a castable low-precision dtype (e.g. float16) is requested, the
    variable is created as float32 and a casted view is returned. Variables
    are not stored in low precision because applying small gradients to them
    may fail to change their value.

    Args:
      getter: The underlying variable getter, with the same signature as
        tf.get_variable.
      name: The name of the variable to get.
      shape: The shape of the variable to get.
      dtype: The requested dtype. If it is in CASTABLE_TYPES, the variable is
        created as tf.float32 and then cast to this dtype.
      *args: Additional arguments passed unmodified to getter.
      **kwargs: Additional keyword arguments passed unmodified to getter.

    Returns:
      The variable, cast to `dtype` if `dtype` is castable.
    """
    # pylint: disable=keyword-arg-before-vararg
    if dtype in CASTABLE_TYPES:
      var = getter(name, shape, tf.float32, *args, **kwargs)
      return tf.cast(var, dtype=dtype, name=name + '_cast')
    else:
      return getter(name, shape, dtype, *args, **kwargs)

  def _model_variable_scope(self):
    """Returns the variable scope the model is created under.

    If self.dtype is a castable type, model variables are created in fp32 and
    cast to self.dtype before use (see `_custom_dtype_getter`).

    Returns:
      A 'resnet_model' variable scope using the custom dtype getter.
    """
    return tf.compat.v1.variable_scope(
        'resnet_model', custom_getter=self._custom_dtype_getter)

  def confidence_model(self, mu, training):
    """Maps a batch of embeddings `mu` to a batch of "variance" outputs.

    Architecture: dropout -> dense(feature_dim) -> ReLU -> batch norm ->
    dropout -> dense(feature_dim). The final dense layer has no activation,
    so the output is unconstrained and may be negative.
    NOTE(review): despite the "variance" naming, nothing here enforces
    positivity; presumably callers interpret it as a log-variance or apply a
    transform. Confirm.

    Args:
      mu: The input embeddings; assumed to be [batch, dim] (TODO confirm).
      training: Whether dropout and batch normalization run in training mode.

    Returns:
      A tensor whose last dimension is `self.feature_dim`.
    """
    out = tf.layers.dropout(mu, rate=self.drop_rate, training=training)
    out = tf.layers.dense(
        out,
        units=self.feature_dim,
        kernel_initializer=self.initializer,
        kernel_regularizer=self.regularizer,
        name='fc_variance')
    out = tf.nn.relu(out)
    out = tf.layers.batch_normalization(
        out, training=training, name='fc_variance_bn')
    out = tf.layers.dropout(out, rate=self.drop_rate, training=training)
    out = tf.layers.dense(
        out,
        units=self.feature_dim,
        kernel_initializer=self.initializer,
        kernel_regularizer=self.regularizer,
        name='fc_variance2')
    return out

  def encoder(self, inputs, training):
    """Builds the ResNet trunk and returns globally average-pooled features.

    Args:
      inputs: A batch of input images in channels_last (NHWC) layout. They
        are transposed to NCHW internally when `self.data_format` is
        'channels_first'.
      training: Whether batch normalization runs in training mode.

    Returns:
      A feature tensor of shape [batch_size, channels], where channels is the
      width of the last block layer (num_filters * 2**(len(block_sizes) - 1),
      times 4 for bottleneck blocks). No classification layer is applied, so
      `self.num_classes` is not used here.
    """
    with self._model_variable_scope():
      if self.data_format == 'channels_first':
        # Convert the inputs from channels_last (NHWC) to channels_first
        # (NCHW). This provides a large performance boost on GPU. See
        # https://www.tensorflow.org/performance/performance_guide#data_formats
        inputs = tf.transpose(a=inputs, perm=[0, 3, 1, 2])

      inputs = conv2d_fixed_padding(
          inputs=inputs,
          filters=self.num_filters,
          kernel_size=self.kernel_size,
          strides=self.conv_stride,
          data_format=self.data_format)
      inputs = tf.identity(inputs, 'initial_conv')

      # v1 blocks expect a normalized, activated input; v2 blocks normalize
      # their own input (pre-activation).
      if self.resnet_version == 1:
        inputs = batch_norm(inputs, training, self.data_format)
        inputs = tf.nn.relu(inputs)

      if self.first_pool_size:
        inputs = tf.compat.v1.layers.max_pooling2d(
            inputs=inputs,
            pool_size=self.first_pool_size,
            strides=self.first_pool_stride,
            padding='SAME',
            data_format=self.data_format)
        inputs = tf.identity(inputs, 'initial_max_pool')

      for i, num_blocks in enumerate(self.block_sizes):
        num_filters = self.num_filters * (2**i)
        inputs = block_layer(
            inputs=inputs,
            filters=num_filters,
            bottleneck=self.bottleneck,
            block_fn=self.block_fn,
            blocks=num_blocks,
            strides=self.block_strides[i],
            training=training,
            name='block_layer{}'.format(i + 1),
            data_format=self.data_format)

      # NOTE(review): standard ResNet v2 applies a final batch norm + ReLU at
      # this point (when self.pre_activation is True), because its blocks end
      # with a convolution. That step is disabled in this model, so v2
      # features come straight from the last convolution. Confirm this is
      # intended.

      # Global average pooling over the spatial axes. reduce_mean is
      # equivalent to an average-pooling layer over the full spatial extent
      # and performs better than AveragePooling2D.
      axes = [2, 3] if self.data_format == 'channels_first' else [1, 2]
      inputs = tf.reduce_mean(input_tensor=inputs, axis=axes, keepdims=True)
      inputs = tf.identity(inputs, 'final_reduce_mean')

      inputs = tf.squeeze(inputs, axes)
      return inputs