google-research
463 lines · 17.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16r"""Training script for the VIS model.
17
18See model.py for more details and usage.
19"""
20import os21import common # pylint: disable=unused-import22from deeplab import preprocess_utils23from deeplab import train_utils24import model25import model_input26import six27import tensorflow.compat.v1 as tf28from tensorflow.compat.v1.python import debug as tf_debug29from tensorflow.compat.v1.python.platform import app30
# Small epsilon added to denominators and sqrt arguments to avoid division
# by zero and the undefined gradient of sqrt at 0.
ZERO_DIV_OFFSET = 1e-20

flags = tf.app.flags
FLAGS = flags.FLAGS

# Settings for logging.

flags.DEFINE_string('train_logdir', None,
                    'Where the checkpoint and logs are stored.')

flags.DEFINE_integer('save_interval_secs', 60,
                     'How often, in seconds, we save the model to disk.')

flags.DEFINE_integer('save_summary_steps', 100, '')

# Settings for training strategy.

flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step', 'cosine'],
                  'Learning rate policy for training.')

flags.DEFINE_float('base_learning_rate', 3e-5,
                   'The base learning rate for model training.')

flags.DEFINE_float('learning_rate_decay_factor', 0.1,
                   'The rate to decay the base learning rate.')

flags.DEFINE_integer(
    'learning_rate_decay_step', 8000,
    'Decay the base learning rate at a fixed step.'
    'Not used if learning_policy == "poly"')

flags.DEFINE_float('learning_power', 0.9,
                   'The power value used in the poly learning policy.')

flags.DEFINE_integer('training_number_of_steps', 1000000,
                     'The number of steps used for training')

flags.DEFINE_float('momentum', 0.9, 'The momentum value to use')

flags.DEFINE_integer('batch_size', 8,
                     'The number of images in each batch during training.')

flags.DEFINE_float('weight_decay', 0.00004,
                   'The value of the weight decay for training.')

flags.DEFINE_integer('resize_factor', None,
                     'Resized dimensions are multiple of factor plus one.')

flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')

# Settings for fine-tuning the network.

flags.DEFINE_string('tf_initial_checkpoint', '',
                    'The initial checkpoint in tensorflow format.')

flags.DEFINE_boolean('initialize_last_layer', False,
                     'Initialize the last layer.')

flags.DEFINE_integer('slow_start_step', 0,
                     'Training model with small learning rate for few steps.')

flags.DEFINE_float('slow_start_learning_rate', 1e-4,
                   'Learning rate employed during slow start.')

flags.DEFINE_boolean('fine_tune_batch_norm', True,
                     'Fine tune the batch norm parameters or not.')

flags.DEFINE_string('split', 'train',
                    'Which split of the dataset to be used for training')

flags.DEFINE_bool('debug', False, 'Whether to use tf dbg.')

flags.DEFINE_boolean('profile', False, '')

flags.DEFINE_boolean('use_sigmoid', True,
                     'Use the custom sigmoid cross entropy function')

flags.DEFINE_float('sigmoid_recall_weight', 5,
                   'If <1 value precision, if >1 recall')

flags.DEFINE_enum(
    'distance_metric', 'euclidean_iter',
    ['mse', 'euclidean', 'euclidean_sqrt', 'euclidean_iter'],
    'the cost metric for the Click Regression'
    '"mse" for mean squared error'
    '"euclidean" for euclidean distance'
    '"euclidean_sqrt" for square root of euclidean distance')

flags.DEFINE_bool('ratio_box_distance', False,
                  'normalize the distance loss by the size of the box')

flags.DEFINE_integer('euclidean_step', 300000,
                     'decrease exponent of distance loss every euclidean_step')
126
def logits_summary(logits):
  """Writes image summaries of the raw and upsampled logits."""
  # tf.summary.image needs 1, 3 or 4 channels; for the binary case the
  # two-channel logits get a zero third channel appended.
  two_class = (
      model_input.dataset_descriptors[FLAGS.dataset].num_classes == 2)
  if two_class:
    padding = tf.zeros_like(logits[:, :, :, 0:1])
    summary_logits = tf.concat([logits, padding], 3)
  else:
    summary_logits = logits

  tf.summary.image('logits', summary_logits, 4)
  upsampled = tf.image.resize_bilinear(
      summary_logits, [FLAGS.image_size, FLAGS.image_size],
      align_corners=True)
  tf.summary.image('resized_logits', upsampled, 4)
138
139def label_summary(labels):140labels = tf.clip_by_value(labels, 0, 3) * int(255 / 3)141tf.summary.image('label', tf.cast(labels, tf.uint8), 4)142
143
def add_cross_entropy_loss(labels, logits, add_loss):
  """Adds accuracy summary. Adds the loss if add_loss is true."""
  if add_loss:
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    mean_loss = tf.reduce_mean(per_example_loss, name='selected_loss')
    tf.losses.add_loss(mean_loss)

  # Accuracy is always reported, whether or not the loss is added.
  predictions = tf.argmax(logits, 1)
  is_correct = tf.cast(tf.equal(predictions, labels), tf.float32)
  accuracy = tf.reduce_mean(is_correct)
  tf.summary.scalar('selected_accuracy', accuracy)
157
def add_sigmoid_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  scope=None):
  """Adds sigmoid cross entropy loss for logits of each scale.

  Implemented based on deeplab's add_softmax_cross_entropy_loss_for_each_scale
  in deeplab/utils/train_utils.py.

  Args:
    scales_to_logits: A map from logits names for different scales to logits.
      The logits have shape [batch, logits_height, logits_width, num_classes].
    labels: Groundtruth labels with shape [batch, image_height, image_width,
      1].
    ignore_label: Integer, label to ignore.
    loss_weight: Float, loss weight.
    upsample_logits: Boolean, upsample logits or not.
    scope: String, the scope for the loss.

  Raises:
    ValueError: Label or logits is None.
  """
  if labels is None:
    # Fixed: message previously said 'softmax' although this is sigmoid loss.
    raise ValueError('No label for sigmoid cross entropy loss.')

  for scale, logits in six.iteritems(scales_to_logits):
    loss_scope = None
    if scope:
      loss_scope = '%s_%s' % (scope, scale)

    if upsample_logits:
      # Label is not downsampled, and instead we upsample logits.
      logits = tf.image.resize_bilinear(
          logits,
          preprocess_utils.resolve_shape(labels, 4)[1:3],
          align_corners=True)
      scaled_labels = labels
    else:
      # Label is downsampled to the same size as logits.
      scaled_labels = tf.image.resize_nearest_neighbor(
          labels,
          preprocess_utils.resolve_shape(logits, 4)[1:3],
          align_corners=True)

    # Binary task: keep only the foreground (class 1) logit channel.
    logits = logits[:, :, :, 1]
    scaled_labels = tf.to_float(scaled_labels)
    # Squeeze ONLY the channel axis. A bare tf.squeeze() would also remove a
    # batch dimension of size 1 and misalign labels with logits.
    scaled_labels = tf.squeeze(scaled_labels, axis=3)
    # Pixels equal to ignore_label get weight 0 and contribute no loss.
    not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                               ignore_label)) * loss_weight
    # sigmoid_recall_weight > 1 favors recall, < 1 favors precision.
    losses = tf.nn.weighted_cross_entropy_with_logits(
        scaled_labels, logits, FLAGS.sigmoid_recall_weight)

    # Loss added later in model_fn by tf.losses.get_total_loss()
    tf.losses.compute_weighted_loss(
        losses, weights=not_ignore_mask, scope=loss_scope)
215
def add_distance_loss_to_center(labels, logits, groundtruth_coords):
  """Add distance loss function for ClickRegression."""
  # Zero out pixels marked with the ignore label so they do not contribute
  # to the center-of-mass computation below.
  weights = tf.to_int32(
      tf.not_equal(labels,
                   model_input.dataset_descriptors[FLAGS.dataset].ignore_label))
  labels *= weights

  # Use GT box to get center if it exists. Less computation required.
  # Otherwise, calculate from label mask.
  if FLAGS.use_groundtruth_box:
    # NOTE(review): this branch does not divide by image_size, so the box
    # coordinates are presumably already normalized to [0, 1] — confirm
    # against the input pipeline.
    center_x = (groundtruth_coords['xmin'] + groundtruth_coords['xmax']) / 2.0
    center_y = (groundtruth_coords['ymin'] + groundtruth_coords['ymax']) / 2.0
    center = tf.stack([center_y, center_x], axis=1)
  else:
    # Make array of coordinates (each row contains three coordinates)
    ii, jj = tf.meshgrid(
        tf.range(FLAGS.image_size), tf.range(FLAGS.image_size), indexing='ij')
    coords = tf.stack([tf.reshape(ii, (-1,)), tf.reshape(jj, (-1,))], axis=-1)
    coords = tf.cast(coords, tf.int32)

    # Rearrange input into one vector per volume
    volumes_flat = tf.reshape(labels,
                              [-1, FLAGS.image_size * FLAGS.image_size * 1, 1])
    # Compute total mass for each volume. Add 0.00001 to prevent division by 0
    total_mass = tf.cast(tf.reduce_sum(volumes_flat, axis=1),
                         tf.float32) + ZERO_DIV_OFFSET
    # Compute centre of mass
    center = tf.cast(tf.reduce_sum(volumes_flat * coords, axis=1),
                     tf.float32) / total_mass
    # Normalize to [0, 1] to match the normalized logits below.
    center = center / FLAGS.image_size

  # Normalize coordinates by size of image
  logits = logits / FLAGS.image_size

  # Calculate loss based on the distance metric specified
  # Loss added later in model_fn by tf.losses.get_total_loss()
  if FLAGS.distance_metric == 'mse':
    tf.losses.mean_squared_error(center, logits)
  elif FLAGS.distance_metric in [
      'euclidean', 'euclidean_sqrt', 'euclidean_iter'
  ]:
    # ZERO_DIV_OFFSET keeps sqrt differentiable when prediction == center.
    distance_to_center = tf.sqrt(
        tf.reduce_sum(tf.square(logits - center), axis=-1) + ZERO_DIV_OFFSET)
    if FLAGS.ratio_box_distance:
      # Express the distance relative to the box size so large and small
      # boxes are penalized comparably.
      distance_to_box = calc_distance_to_edge(groundtruth_coords, logits)
      box_distance_to_center = (
          tf.to_float(distance_to_center) - distance_to_box)
      loss = distance_to_center / (box_distance_to_center + ZERO_DIV_OFFSET)
    else:
      loss = distance_to_center

    if FLAGS.distance_metric == 'euclidean_sqrt':
      loss = tf.sqrt(loss)
    if FLAGS.distance_metric == 'euclidean_iter':
      # Progressively soften the loss: the exponent 1/step shrinks every
      # FLAGS.euclidean_step global steps.
      iter_num = tf.to_float(tf.train.get_or_create_global_step())
      step = (iter_num // FLAGS.euclidean_step) + 1.0
      loss = tf.pow(loss, tf.to_float(1.0 / step))
    tf.losses.compute_weighted_loss(loss)
275
def calc_distance_to_edge(groundtruth_coords, logits):
  """Calculate distance between predicted point to box of ground truth."""
  pred_y = logits[:, 0]
  pred_x = logits[:, 1]

  # Returns 0 if predicted point is inside the groundtruth box
  overshoot_x = tf.maximum(groundtruth_coords['xmin'] - pred_x,
                           pred_x - groundtruth_coords['xmax'])
  overshoot_y = tf.maximum(groundtruth_coords['ymin'] - pred_y,
                           pred_y - groundtruth_coords['ymax'])
  dx = tf.maximum(overshoot_x, 0)
  dy = tf.maximum(overshoot_y, 0)

  return tf.sqrt(tf.square(dx) + tf.square(dy))
290
291def add_distance_loss_to_edge(groundtruth_coords, logits):292distance = calc_distance_to_edge(groundtruth_coords, logits)293tf.losses.compute_weighted_loss(distance)294
295
def _build_deeplab(samples, outputs_to_num_classes, ignore_label):
  """Builds a clone of DeepLab.

  Args:
    samples: Feature map from input pipeline.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.

  Returns:
    A map of maps from output_type (e.g., semantic prediction) to a
    dictionary of multi-scale logits names to logits. For each output_type,
    the dictionary has keys which correspond to the scales and values which
    correspond to the logits. For example, if `scales` equals [1.0, 1.5],
    then the keys would include 'merged_logits', 'logits_1.00' and
    'logits_1.50'.
  """

  tf.summary.image('image', samples['image'], 4)
  if 'label' in samples:
    label_summary(samples['label'])
  if FLAGS.use_ref_exp:
    tf.summary.text('ref', samples[model_input.REF_EXP_ID])

  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples['image'],
      samples,
      FLAGS,
      outputs_to_num_classes=outputs_to_num_classes,
      image_pyramid=FLAGS.image_pyramid,
      merge_method=FLAGS.merge_method,
      atrous_rates=FLAGS.atrous_rates,
      add_image_level_feature=FLAGS.add_image_level_feature,
      aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
      aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
      multi_grid=FLAGS.multi_grid,
      depth_multiplier=FLAGS.depth_multiplier,
      output_stride=FLAGS.output_stride,
      decoder_output_stride=FLAGS.decoder_output_stride,
      decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
      logits_kernel_size=FLAGS.logits_kernel_size,
      crop_size=[FLAGS.image_size, FLAGS.image_size],
      model_variant=FLAGS.model_variant,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)

  # Fixed: dict.iteritems() is Python-2-only and raises AttributeError on
  # Python 3; use six.iteritems as the rest of this file already does.
  for output, num_classes in six.iteritems(outputs_to_num_classes):
    if output == 'segment':
      logits_summary(outputs_to_scales_to_logits[output]['merged_logits'])
      if FLAGS.use_sigmoid:
        add_sigmoid_cross_entropy_loss_for_each_scale(
            outputs_to_scales_to_logits[output], samples['label'],
            ignore_label, 1.0, FLAGS.upsample_logits, output)
      else:
        train_utils.add_softmax_cross_entropy_loss_for_each_scale(
            outputs_to_scales_to_logits[output],
            samples['label'],
            num_classes,
            ignore_label,
            loss_weight=1.0,
            upsample_logits=FLAGS.upsample_logits,
            scope=output)

    elif output == 'regression':
      for _, logits in six.iteritems(outputs_to_scales_to_logits[output]):
        groundtruth_box = {
            'xmin': samples[model_input.GROUNDTRUTH_XMIN_ID],
            'xmax': samples[model_input.GROUNDTRUTH_XMAX_ID],
            'ymin': samples[model_input.GROUNDTRUTH_YMIN_ID],
            'ymax': samples[model_input.GROUNDTRUTH_YMAX_ID]
        }
        add_distance_loss_to_center(samples['label'], logits, groundtruth_box)
  return outputs_to_scales_to_logits
372
def model_fn(features, labels, mode, params):
  """Defines the model compatible with tf.estimator."""
  # Labels come through `features` (built by model_input); the estimator's
  # `labels` and `params` arguments are unused.
  del labels, params
  # NOTE(review): only TRAIN mode is handled; EVAL/PREDICT would not define
  # `loss`/`train_op` — confirm this script is only ever used for training.
  if mode == tf.estimator.ModeKeys.TRAIN:
    _build_deeplab(features, model.get_output_to_num_classes(FLAGS),
                   model_input.dataset_descriptors[FLAGS.dataset].ignore_label)

    # Print out the objective loss and regularization loss independently to
    # track NaN loss issue
    objective_losses = tf.losses.get_losses()
    objective_losses = tf.Print(
        objective_losses, [objective_losses],
        message='Objective Losses: ',
        summarize=100)
    objective_loss = tf.reduce_sum(objective_losses)
    tf.summary.scalar('objective_loss', objective_loss)

    reg_losses = tf.losses.get_regularization_losses()
    reg_losses = tf.Print(
        reg_losses, [reg_losses], message='Reg Losses: ', summarize=100)
    reg_loss = tf.reduce_sum(reg_losses)
    tf.summary.scalar('regularization_loss', reg_loss)

    loss = objective_loss + reg_loss

    learning_rate = train_utils.get_model_learning_rate(
        FLAGS.learning_policy, FLAGS.base_learning_rate,
        FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
        FLAGS.training_number_of_steps, FLAGS.learning_power,
        FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    tf.summary.scalar('learning_rate', learning_rate)

    grads_and_vars = optimizer.compute_gradients(loss)
    grad_updates = optimizer.apply_gradients(grads_and_vars,
                                             tf.train.get_global_step())
    # Run UPDATE_OPS (e.g. batch-norm moving averages) together with the
    # gradient step, then expose the loss itself as the train op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_op = tf.identity(loss, name='train_op')

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
    )
421
422def tf_dbg_sess_wrapper(sess):423if FLAGS.debug:424print 'DEBUG'425sess = tf_debug.LocalCLIDebugWrapperSession(426sess,427thread_name_filter='MainThread$',428dump_root=os.path.join(FLAGS.train_logdir, 'tfdbg2'))429sess.add_tensor_filter('has_inf_or_nan', tf_debug.has_inf_or_nan)430return sess431
432
def main(unused_argv):
  """Builds the Estimator and runs the training loop."""
  config = tf.estimator.RunConfig(
      model_dir=FLAGS.train_logdir,
      save_summary_steps=FLAGS.save_summary_steps,
      save_checkpoints_secs=FLAGS.save_interval_secs,
  )

  ws = None
  if FLAGS.tf_initial_checkpoint:
    checkpoint_vars = tf.train.list_variables(FLAGS.tf_initial_checkpoint)
    # Add a ':' so we will only match the specific variable and not others.
    checkpoint_vars = [var[0] + ':' for var in checkpoint_vars]
    # Never warm-start the global step or training would resume from the
    # checkpoint's step count. Guarded: a checkpoint without a global_step
    # variable previously crashed with ValueError from list.remove().
    if 'global_step:' in checkpoint_vars:
      checkpoint_vars.remove('global_step:')

    ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=FLAGS.tf_initial_checkpoint,
        vars_to_warm_start=checkpoint_vars)

  estimator = tf.estimator.Estimator(
      model_fn, FLAGS.train_logdir, config, warm_start_from=ws)

  # Profiling is a no-op unless --profile is set.
  with tf.contrib.tfprof.ProfileContext(
      FLAGS.train_logdir, enabled=FLAGS.profile):
    estimator.train(
        model_input.get_input_fn(FLAGS),
        max_steps=FLAGS.training_number_of_steps)
460
461if __name__ == '__main__':462flags.mark_flag_as_required('train_logdir')463app.run()464