# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Training script for the VIS model.

See model.py for more details and usage.
"""
import os
import common  # pylint: disable=unused-import
from deeplab import preprocess_utils
from deeplab import train_utils
import model
import model_input
import six
import tensorflow.compat.v1 as tf
from tensorflow.python import debug as tf_debug
from tensorflow.python.platform import app

ZERO_DIV_OFFSET = 1e-20

flags = tf.app.flags
FLAGS = flags.FLAGS

# Settings for logging.

flags.DEFINE_string('train_logdir', None,
                    'Where the checkpoint and logs are stored.')

flags.DEFINE_integer('save_interval_secs', 60,
                     'How often, in seconds, we save the model to disk.')

flags.DEFINE_integer('save_summary_steps', 100,
                     'How often, in steps, to save summaries.')

# Settings for training strategy.

flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step', 'cosine'],
                  'Learning rate policy for training.')

flags.DEFINE_float('base_learning_rate', 3e-5,
                   'The base learning rate for model training.')

flags.DEFINE_float('learning_rate_decay_factor', 0.1,
                   'The rate to decay the base learning rate.')

flags.DEFINE_integer(
    'learning_rate_decay_step', 8000,
    'Decay the base learning rate at a fixed step. '
    'Not used if learning_policy == "poly".')

flags.DEFINE_float('learning_power', 0.9,
                   'The power value used in the poly learning policy.')

flags.DEFINE_integer('training_number_of_steps', 1000000,
                     'The number of steps used for training.')

flags.DEFINE_float('momentum', 0.9, 'The momentum value to use.')

flags.DEFINE_integer('batch_size', 8,
                     'The number of images in each batch during training.')

flags.DEFINE_float('weight_decay', 0.00004,
                   'The value of the weight decay for training.')

flags.DEFINE_integer('resize_factor', None,
                     'Resized dimensions are multiple of factor plus one.')

flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')

# Settings for fine-tuning the network.

flags.DEFINE_string('tf_initial_checkpoint', '',
                    'The initial checkpoint in tensorflow format.')

flags.DEFINE_boolean('initialize_last_layer', False,
                     'Initialize the last layer.')

flags.DEFINE_integer('slow_start_step', 0,
                     'Train the model with a small learning rate for a few '
                     'steps.')

flags.DEFINE_float('slow_start_learning_rate', 1e-4,
                   'Learning rate employed during slow start.')

flags.DEFINE_boolean('fine_tune_batch_norm', True,
                     'Fine tune the batch norm parameters or not.')

flags.DEFINE_string('split', 'train',
                    'Which split of the dataset to use for training.')

flags.DEFINE_bool('debug', False, 'Whether to use tf dbg.')

flags.DEFINE_boolean('profile', False,
                     'Whether to profile model training with tfprof.')

flags.DEFINE_boolean('use_sigmoid', True,
                     'Use the custom sigmoid cross entropy function.')

flags.DEFINE_float('sigmoid_recall_weight', 5,
                   'Weight for positive targets in the sigmoid loss; values '
                   '<1 favor precision, values >1 favor recall.')

flags.DEFINE_enum(
    'distance_metric', 'euclidean_iter',
    ['mse', 'euclidean', 'euclidean_sqrt', 'euclidean_iter'],
    'The cost metric for the Click Regression: '
    '"mse" for mean squared error, '
    '"euclidean" for euclidean distance, '
    '"euclidean_sqrt" for square root of euclidean distance, '
    '"euclidean_iter" for euclidean distance raised to an exponent that '
    'decays every euclidean_step steps.')

flags.DEFINE_bool('ratio_box_distance', False,
                  'Normalize the distance loss by the size of the box.')

flags.DEFINE_integer('euclidean_step', 300000,
                     'Decrease the exponent of the distance loss every '
                     'euclidean_step steps.')


def logits_summary(logits):
  if model_input.dataset_descriptors[FLAGS.dataset].num_classes == 2:
    logits_for_sum = tf.concat([logits, tf.zeros_like(logits[:, :, :, 0:1])], 3)
  else:
    logits_for_sum = logits

  tf.summary.image('logits', logits_for_sum, 4)
  resized = tf.image.resize_bilinear(
      logits_for_sum, [FLAGS.image_size, FLAGS.image_size], align_corners=True)
  tf.summary.image('resized_logits', resized, 4)


def label_summary(labels):
  labels = tf.clip_by_value(labels, 0, 3) * int(255 / 3)
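  # Labels {0, 1, 2, 3} now map to pixel intensities {0, 85, 170, 255},
  # since int(255 / 3) = 85.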
  tf.summary.image('label', tf.cast(labels, tf.uint8), 4)


def add_cross_entropy_loss(labels, logits, add_loss):
  """Adds accuracy summary. Adds the loss if add_loss is true."""
  if add_loss:
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss = tf.reduce_mean(loss, name='selected_loss')
    tf.losses.add_loss(loss)

  pred = tf.argmax(logits, 1)
  correct = tf.equal(pred, labels)
  accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
  tf.summary.scalar('selected_accuracy', accuracy)


def add_sigmoid_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  scope=None):
  """Adds sigmoid cross entropy loss for logits of each scale.

  Implemented based on deeplab's add_softmax_cross_entropy_loss_for_each_scale
  in deeplab/utils/train_utils.py.

  Args:
    scales_to_logits: A map from logits names for different scales to logits.
      The logits have shape [batch, logits_height, logits_width, num_classes].
    labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
    ignore_label: Integer, label to ignore.
    loss_weight: Float, loss weight.
    upsample_logits: Boolean, upsample logits or not.
    scope: String, the scope for the loss.

  Raises:
    ValueError: Label or logits is None.
  """
  if labels is None:
    raise ValueError('No label for sigmoid cross entropy loss.')

  for scale, logits in six.iteritems(scales_to_logits):
    loss_scope = None
    if scope:
      loss_scope = '%s_%s' % (scope, scale)

    if upsample_logits:
      # Label is not downsampled, and instead we upsample logits.
      logits = tf.image.resize_bilinear(
          logits,
          preprocess_utils.resolve_shape(labels, 4)[1:3],
          align_corners=True)
      scaled_labels = labels
    else:
      # Label is downsampled to the same size as logits.
      scaled_labels = tf.image.resize_nearest_neighbor(
          labels,
          preprocess_utils.resolve_shape(logits, 4)[1:3],
          align_corners=True)

    logits = logits[:, :, :, 1]
    scaled_labels = tf.to_float(scaled_labels)
    scaled_labels = tf.squeeze(scaled_labels)
    not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                               ignore_label)) * loss_weight
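    # tf.nn.weighted_cross_entropy_with_logits computes, per pixel,
    #   labels * -log(sigmoid(logits)) * pos_weight
    #     + (1 - labels) * -log(1 - sigmoid(logits)),
    # so a sigmoid_recall_weight above 1 up-weights false negatives
    # (favoring recall) and below 1 favors precision.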
    losses = tf.nn.weighted_cross_entropy_with_logits(
        scaled_labels, logits, FLAGS.sigmoid_recall_weight)

    # Loss added later in model_fn by tf.losses.get_total_loss()
    tf.losses.compute_weighted_loss(
        losses, weights=not_ignore_mask, scope=loss_scope)


def add_distance_loss_to_center(labels, logits, groundtruth_coords):
  """Add distance loss function for ClickRegression."""
  weights = tf.to_int32(
      tf.not_equal(labels,
                   model_input.dataset_descriptors[FLAGS.dataset].ignore_label))
  labels *= weights

  # Use GT box to get center if it exists. Less computation required.
  # Otherwise, calculate from label mask.
  if FLAGS.use_groundtruth_box:
    center_x = (groundtruth_coords['xmin'] + groundtruth_coords['xmax']) / 2.0
    center_y = (groundtruth_coords['ymin'] + groundtruth_coords['ymax']) / 2.0
    center = tf.stack([center_y, center_x], axis=1)
  else:
    # Make array of coordinates (each row contains two coordinates).
    ii, jj = tf.meshgrid(
        tf.range(FLAGS.image_size), tf.range(FLAGS.image_size), indexing='ij')
    coords = tf.stack([tf.reshape(ii, (-1,)), tf.reshape(jj, (-1,))], axis=-1)
    coords = tf.cast(coords, tf.int32)

    # Rearrange input into one vector per volume.
    volumes_flat = tf.reshape(labels,
                              [-1, FLAGS.image_size * FLAGS.image_size * 1, 1])
    # Compute total mass for each volume. Add ZERO_DIV_OFFSET to prevent
    # division by 0.
    total_mass = tf.cast(tf.reduce_sum(volumes_flat, axis=1),
                         tf.float32) + ZERO_DIV_OFFSET
    # Compute center of mass.
    center = tf.cast(tf.reduce_sum(volumes_flat * coords, axis=1),
                     tf.float32) / total_mass
    center = center / FLAGS.image_size
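    # center is now a (y, x) pair in normalized [0, 1] coordinates; this
    # assumes the groundtruth box coordinates in the branch above are
    # normalized the same way.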

  # Normalize coordinates by size of image.
  logits = logits / FLAGS.image_size

  # Calculate loss based on the distance metric specified.
  # Loss added later in model_fn by tf.losses.get_total_loss()
  if FLAGS.distance_metric == 'mse':
    tf.losses.mean_squared_error(center, logits)
  elif FLAGS.distance_metric in [
      'euclidean', 'euclidean_sqrt', 'euclidean_iter'
  ]:
    distance_to_center = tf.sqrt(
        tf.reduce_sum(tf.square(logits - center), axis=-1) + ZERO_DIV_OFFSET)
    if FLAGS.ratio_box_distance:
      distance_to_box = calc_distance_to_edge(groundtruth_coords, logits)
      box_distance_to_center = (
          tf.to_float(distance_to_center) - distance_to_box)
      loss = distance_to_center / (box_distance_to_center + ZERO_DIV_OFFSET)
    else:
      loss = distance_to_center

    if FLAGS.distance_metric == 'euclidean_sqrt':
      loss = tf.sqrt(loss)
    if FLAGS.distance_metric == 'euclidean_iter':
      iter_num = tf.to_float(tf.train.get_or_create_global_step())
      step = (iter_num // FLAGS.euclidean_step) + 1.0
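      # With the default euclidean_step of 300000, global steps [0, 300000)
      # use exponent 1/1 (plain euclidean distance), [300000, 600000) use
      # exponent 1/2, [600000, 900000) use 1/3, and so on, progressively
      # flattening the loss.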
      loss = tf.pow(loss, tf.to_float(1.0 / step))
    tf.losses.compute_weighted_loss(loss)


def calc_distance_to_edge(groundtruth_coords, logits):
  """Calculates distance from the predicted point to the groundtruth box."""

  # Returns 0 if the predicted point is inside the groundtruth box.
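  # Example: for a box with x in [0.2, 0.6], a predicted x of 0.4 gives
  # dx = max(max(0.2 - 0.4, 0.4 - 0.6), 0) = 0, while a predicted x of 0.8
  # gives dx = max(max(0.2 - 0.8, 0.8 - 0.6), 0) = 0.2.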
  dx = tf.maximum(
      tf.maximum(groundtruth_coords['xmin'] - logits[:, 1],
                 logits[:, 1] - groundtruth_coords['xmax']), 0)
  dy = tf.maximum(
      tf.maximum(groundtruth_coords['ymin'] - logits[:, 0],
                 logits[:, 0] - groundtruth_coords['ymax']), 0)

  distance = tf.sqrt(tf.square(dx) + tf.square(dy))
  return distance


def add_distance_loss_to_edge(groundtruth_coords, logits):
  distance = calc_distance_to_edge(groundtruth_coords, logits)
  tf.losses.compute_weighted_loss(distance)


def _build_deeplab(samples, outputs_to_num_classes, ignore_label):
  """Builds a clone of DeepLab.

  Args:
    samples: Feature map from input pipeline.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.

  Returns:
    A map of maps from output_type (e.g., semantic prediction) to a
      dictionary of multi-scale logits names to logits. For each output_type,
      the dictionary has keys which correspond to the scales and values which
      correspond to the logits. For example, if `scales` equals [1.0, 1.5],
      then the keys would include 'merged_logits', 'logits_1.00' and
      'logits_1.50'.
  """

  tf.summary.image('image', samples['image'], 4)
  if 'label' in samples:
    label_summary(samples['label'])
  if FLAGS.use_ref_exp:
    tf.summary.text('ref', samples[model_input.REF_EXP_ID])

  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples['image'],
      samples,
      FLAGS,
      outputs_to_num_classes=outputs_to_num_classes,
      image_pyramid=FLAGS.image_pyramid,
      merge_method=FLAGS.merge_method,
      atrous_rates=FLAGS.atrous_rates,
      add_image_level_feature=FLAGS.add_image_level_feature,
      aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
      aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
      multi_grid=FLAGS.multi_grid,
      depth_multiplier=FLAGS.depth_multiplier,
      output_stride=FLAGS.output_stride,
      decoder_output_stride=FLAGS.decoder_output_stride,
      decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
      logits_kernel_size=FLAGS.logits_kernel_size,
      crop_size=[FLAGS.image_size, FLAGS.image_size],
      model_variant=FLAGS.model_variant,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)

  for output, num_classes in six.iteritems(outputs_to_num_classes):
    if output == 'segment':
      logits_summary(outputs_to_scales_to_logits[output]['merged_logits'])
      if FLAGS.use_sigmoid:
        add_sigmoid_cross_entropy_loss_for_each_scale(
            outputs_to_scales_to_logits[output], samples['label'], ignore_label,
            1.0, FLAGS.upsample_logits, output)
      else:
        train_utils.add_softmax_cross_entropy_loss_for_each_scale(
            outputs_to_scales_to_logits[output],
            samples['label'],
            num_classes,
            ignore_label,
            loss_weight=1.0,
            upsample_logits=FLAGS.upsample_logits,
            scope=output)

    elif output == 'regression':
      for _, logits in six.iteritems(outputs_to_scales_to_logits[output]):
        groundtruth_box = {
            'xmin': samples[model_input.GROUNDTRUTH_XMIN_ID],
            'xmax': samples[model_input.GROUNDTRUTH_XMAX_ID],
            'ymin': samples[model_input.GROUNDTRUTH_YMIN_ID],
            'ymax': samples[model_input.GROUNDTRUTH_YMAX_ID]
        }
        add_distance_loss_to_center(samples['label'], logits, groundtruth_box)
  return outputs_to_scales_to_logits


def model_fn(features, labels, mode, params):
  """Defines the model compatible with tf.estimator."""
  del labels, params
  if mode == tf.estimator.ModeKeys.TRAIN:
    _build_deeplab(features, model.get_output_to_num_classes(FLAGS),
                   model_input.dataset_descriptors[FLAGS.dataset].ignore_label)

    # Print the objective loss and regularization loss independently to
    # track down NaN loss issues.
    objective_losses = tf.losses.get_losses()
    objective_losses = tf.Print(
        objective_losses, [objective_losses],
        message='Objective Losses: ',
        summarize=100)
    objective_loss = tf.reduce_sum(objective_losses)
    tf.summary.scalar('objective_loss', objective_loss)

    reg_losses = tf.losses.get_regularization_losses()
    reg_losses = tf.Print(
        reg_losses, [reg_losses], message='Reg Losses: ', summarize=100)
    reg_loss = tf.reduce_sum(reg_losses)
    tf.summary.scalar('regularization_loss', reg_loss)

    loss = objective_loss + reg_loss

    learning_rate = train_utils.get_model_learning_rate(
        FLAGS.learning_policy, FLAGS.base_learning_rate,
        FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
        FLAGS.training_number_of_steps, FLAGS.learning_power,
        FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    tf.summary.scalar('learning_rate', learning_rate)

    grads_and_vars = optimizer.compute_gradients(loss)
    grad_updates = optimizer.apply_gradients(grads_and_vars,
                                             tf.train.get_global_step())
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_op = tf.identity(loss, name='train_op')

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
    )


def tf_dbg_sess_wrapper(sess):
  if FLAGS.debug:
    print('DEBUG')
    sess = tf_debug.LocalCLIDebugWrapperSession(
        sess,
        thread_name_filter='MainThread$',
        dump_root=os.path.join(FLAGS.train_logdir, 'tfdbg2'))
    sess.add_tensor_filter('has_inf_or_nan', tf_debug.has_inf_or_nan)
  return sess


def main(unused_argv):
  config = tf.estimator.RunConfig(
      model_dir=FLAGS.train_logdir,
      save_summary_steps=FLAGS.save_summary_steps,
      save_checkpoints_secs=FLAGS.save_interval_secs,
  )

  ws = None
  if FLAGS.tf_initial_checkpoint:
    checkpoint_vars = tf.train.list_variables(FLAGS.tf_initial_checkpoint)
    # Add a ':' so we will only match the specific variable and not others.
    checkpoint_vars = [var[0] + ':' for var in checkpoint_vars]
    checkpoint_vars.remove('global_step:')
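    # E.g. a checkpoint variable named 'logits/weights' (an illustrative
    # name, not one from this model) becomes 'logits/weights:', which
    # matches only that variable and not names like 'logits/weights_1'.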

    ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=FLAGS.tf_initial_checkpoint,
        vars_to_warm_start=checkpoint_vars)

  estimator = tf.estimator.Estimator(
      model_fn, FLAGS.train_logdir, config, warm_start_from=ws)

  with tf.contrib.tfprof.ProfileContext(
      FLAGS.train_logdir, enabled=FLAGS.profile):
    estimator.train(
        model_input.get_input_fn(FLAGS),
        max_steps=FLAGS.training_number_of_steps)


if __name__ == '__main__':
  flags.mark_flag_as_required('train_logdir')
  app.run()
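
# Example invocation (a sketch: the script filename and paths below are
# hypothetical; the flags are the ones defined above):
#
#   python train.py \
#     --train_logdir=/tmp/vis_train \
#     --tf_initial_checkpoint=/path/to/initial/model.ckpt \
#     --batch_size=8 \
#     --training_number_of_steps=1000000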