google-research

Форк
0
/
launcher.py 
783 строки · 31.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Contrastive Learning training/eval code."""
17

18
import os
19

20
from absl import app
21
from absl import flags
22
import tensorflow.compat.v1 as tf
23
from tensorflow.compat.v1 import estimator as tf_estimator
24
import tensorflow.compat.v2 as tf2
25
import tf_slim as slim
26

27
from supcon import enums
28
from supcon import hparams as hparams_lib
29
from supcon import hparams_flags
30
from supcon import inputs
31
from supcon import losses
32
from supcon import models
33
from supcon import preprocessing
34
from supcon import utils
35

36
flags.DEFINE_string(
    'hparams', None,
    'A serialized hparams string representing the hyperparameters to use. If '
    'not set, fall back to using the individual hyperparameter flags defined '
    'in hparams_flags.py')
flags.DEFINE_enum(
    'mode', 'train', ['train', 'eval', 'train_then_eval'],
    'The mode for this job, either "train", "eval", or '
    '"train_then_eval".')
flags.DEFINE_string(
    'model_dir', '', 'Root of the tree containing all files for the current '
    'model.')
flags.DEFINE_string('master', None, 'Address of the TensorFlow runtime.')
flags.DEFINE_integer('summary_interval_steps', 100,
                     'Number of steps in between logging training summaries.')
flags.DEFINE_integer('save_interval_steps', 1000,
                     'Number of steps in between saving model checkpoints.')
flags.DEFINE_integer('max_checkpoints_to_keep', 5,
                     'Maximum number of recent checkpoints to keep.')
flags.DEFINE_float(
    'keep_checkpoint_interval_secs',
    60 * 60 * 1000 * 10,  # 10,000 hours
    'Number of seconds in between permanently retained checkpoints.')
flags.DEFINE_integer(
    'steps_per_loop', 1000,
    'Number of steps to execute on TPU before returning control to the '
    'coordinator. Checkpoints will be taken at least these many steps apart.')
flags.DEFINE_boolean('use_tpu', True, 'Whether this is running on a TPU.')
flags.DEFINE_integer('eval_interval_secs', 60, 'Time interval between evals.')
flags.DEFINE_string(
    'reference_ckpt', '',
    '[Optional] If set, attempt to initialize the model using the latest '
    'checkpoint in this directory.')
flags.DEFINE_string(
    'data_dir', None,
    'The directory that will be passed as the `data_dir` argument to '
    '`tfds.load`.')
# The TPU flags below were previously registered through `tf.flags`, which
# re-exports absl flags; `flags` is used here for consistency with the rest of
# this module.
flags.DEFINE_string(
    'tpu_name', None,
    'The Cloud TPU to use for training. This should be either the name '
    'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
    'url.')
flags.DEFINE_string(
    'tpu_zone', None,
    '[Optional] GCE zone where the Cloud TPU is located in. If not '
    'specified, we will attempt to automatically detect the GCE zone from '
    'metadata.')
flags.DEFINE_string(
    'gcp_project', None,
    '[Optional] Project name for the Cloud TPU-enabled project. If not '
    'specified, we will attempt to automatically detect the GCE project from '
    'metadata.')

FLAGS = flags.FLAGS

# Learning rate hparam values are defined with respect to this batch size. The
# true learning rate will be scaled by batch_size/BASE_BATCH_SIZE.
BASE_BATCH_SIZE = 256
94

95

96
class ContrastiveTrainer:
  """Encapsulates the train, eval, and inference logic of a contrastive model.

  Upon construction of this class, the model graph is created. In train and eval
  mode, the loss computation graph is also created at construction time.

  Attrs:
    model_inputs: The inputs to the model. These should be Tensors with shape
      [batch_size, side_length, side_length, 3 * views] with values in range
      [-1, 1] and dtype tf.float32 or tf.bfloat16. Currently views must be 1 or
      2. Train and eval modes require exactly 2 views (6 channels).
    labels: The labels corresponding to `model_inputs`. A Tensor of shape
      [batch_size] with integer dtype.
    train_global_batch_size: The global training batch size. Used to compute
      steps per epoch and to scale the learning rate by
      batch_size / BASE_BATCH_SIZE.
    hparams: A hparams.HParams instance, reflecting the hyperparameters of the
      model and its training.
    mode: An enums.ModelMode value.
    num_classes: The cardinality of the labelset, and also the number of output
      logits of the classification head.
    training_set_size: The number of samples in the training set.
    is_tpu: Whether this is running on a TPU.
  """

  def __init__(self,
               model_inputs,
               labels,
               train_global_batch_size,
               hparams,
               mode,
               num_classes,
               training_set_size,
               is_tpu=False):
    # Validate explicitly rather than with `assert`, which is stripped under -O.
    if not isinstance(mode, enums.ModelMode):
      raise TypeError(f'`mode` must be an enums.ModelMode, got {mode!r}.')
    self.model_inputs = model_inputs
    self.labels = labels
    self.train_global_batch_size = train_global_batch_size
    self.hparams = hparams
    self.mode = mode
    self.num_classes = num_classes
    self.training_set_size = training_set_size
    self.is_tpu = is_tpu
    # Tensors collected by `_add_scalar_summary` and written in `host_call`.
    self._summary_dict = {}

    if not self.inference:
      # Two views, each with 3 channels, stacked along the last axis.
      if tf.compat.dimension_at_index(self.model_inputs.shape, -1) != 6:
        raise ValueError(
            'Both train and eval modes must have 2 views provided, '
            'concatenated in the channels dimension.')

    # NCHW is only used for train/eval when a GPU is visible; otherwise NHWC.
    self.data_format = ('channels_first' if not self.inference and
                        tf.config.list_physical_devices('GPU') else
                        'channels_last')

    if self.eval:
      self._summary_update_ops = []

    self.model = self._create_model()

    is_bn_train_mode = (
        # We intentionally run with batch norm in train mode for inference. We
        # call model() a second time with training=False for inference mode
        # below, and include both in the inference graph and SavedModel.
        not self.eval and (not FLAGS.reference_ckpt or
                           self.hparams.warm_start.batch_norm_in_train_mode))
    (self.unnormalized_embedding, self.normalized_embedding, self.projection,
     self.logits) = self._call_model(training=is_bn_train_mode)

    if self.inference:
      (self.unnormalized_embedding_eval, self.normalized_embedding_eval,
       self.projection_eval,
       self.logits_eval) = self._call_model(training=False)
      return

    self._encoder_weights = self._compute_weights('Encoder')
    self._projection_head_weights = self._compute_weights('ProjectionHead')
    self._classification_head_weights = self._compute_weights(
        'ClassificationHead')

    self.contrastive_loss = self._compute_contrastive_loss()
    self.cross_entropy_loss = self._compute_cross_entropy_loss()

  @property
  def train(self):
    """Whether this trainer was built for enums.ModelMode.TRAIN."""
    return self.mode == enums.ModelMode.TRAIN

  @property
  def eval(self):
    """Whether this trainer was built for enums.ModelMode.EVAL."""
    return self.mode == enums.ModelMode.EVAL

  @property
  def inference(self):
    """Whether this trainer was built for enums.ModelMode.INFERENCE."""
    return self.mode == enums.ModelMode.INFERENCE

  def _add_scalar_summary(self, name, tensor):
    """Collects tensors that should be written as summaries in `host_call`."""
    self._summary_dict[name] = tensor

  def _create_model(self):
    """Creates the model, but does not build it or create variables.

    Returns:
      A callable Keras layer that implements the model architecture.
    """
    arch_hparams = self.hparams.architecture
    # Both branches pass an initializer *instance* so the classification head
    # receives the same kind of object regardless of the hparam.
    if arch_hparams.zero_initialize_classifier:
      classifier_initializer = tf.initializers.zeros()
    else:
      classifier_initializer = tf.initializers.glorot_uniform()
    model = models.ContrastiveModel(
        architecture=arch_hparams.encoder_architecture,
        normalize_projection_head_input=(
            arch_hparams.normalize_projection_head_inputs),
        normalize_classification_head_input=(
            arch_hparams.normalize_classifier_inputs),
        stop_gradient_before_classification_head=(
            arch_hparams.stop_gradient_before_classification_head),
        stop_gradient_before_projection_head=(
            arch_hparams.stop_gradient_before_projection_head),
        encoder_kwargs={
            'depth': arch_hparams.encoder_depth,
            'width': arch_hparams.encoder_width,
            'first_conv_kernel_size': arch_hparams.first_conv_kernel_size,
            'first_conv_stride': arch_hparams.first_conv_stride,
            'data_format': self.data_format,
            'use_initial_max_pool': arch_hparams.use_initial_max_pool,
            'use_global_batch_norm': arch_hparams.use_global_batch_norm,
        },
        projection_head_kwargs={
            'feature_dims':
                arch_hparams.projection_head_layers,
            'normalize_output':
                True,
            'use_batch_norm':
                arch_hparams.projection_head_use_batch_norm,
            'use_batch_norm_beta':
                arch_hparams.projection_head_use_batch_norm_beta,
            'use_global_batch_norm':
                arch_hparams.use_global_batch_norm,
        },
        classification_head_kwargs={
            'num_classes': self.num_classes,
            'kernel_initializer': classifier_initializer,
        })

    return model

  def _call_model(self, training):
    """Passes data through the model.

    Manipulates the input data to get it ready for passing into the model,
    including applying some data augmentation that is more efficient to apply on
    the TPU than on the host. It then passes it into the model, which will first
    build the model and create its variables.

    Args:
      training: Whether the model should be run in training mode.

    Returns:
      A tuple of the model outputs (as Tensors):
      * unnormalized_embedding: The output of the encoder, not including
        normalization, which is sometimes applied before this gets passed into
        the projection and classification heads.
      * normalized_embedding: A normalized version of `unnormalized_embedding`.
      * projection: The output of the projection head.
      * logits: The output of the classification head.
    """
    with tf.name_scope('call_model'):
      model_inputs = self.model_inputs

      # In most cases, the data format NCHW instead of NHWC should be used for a
      # significant performance boost on GPU. NHWC should be used only if the
      # network needs to be run on CPU since the pooling operations are only
      # supported on NHWC. TPU uses XLA compiler to figure out best layout.
      if self.data_format == 'channels_first':
        model_inputs = tf.transpose(model_inputs, [0, 3, 1, 2])

      channels_index = 1 if self.data_format == 'channels_first' else -1
      inputs_are_multiview = tf.compat.dimension_value(
          model_inputs.shape[channels_index]) > 3
      if inputs_are_multiview:
        # Move the stacked views from the channel axis into the batch axis.
        model_inputs = utils.stacked_multiview_image_channels_to_batch(
            model_inputs, self.data_format)

      # Perform blur augmentations here, since they're faster on TPU than CPU.
      preprocessing_hparams = self.hparams.input_data.preprocessing
      if (preprocessing_hparams.augmentation_type in (
          enums.AugmentationType.SIMCLR,
          enums.AugmentationType.STACKED_RANDAUGMENT) and
          preprocessing_hparams.blur_probability > 0. and
          preprocessing_hparams.defer_blurring and self.train):
        # NOTE(review): shape[1] is the spatial side length only for
        # channels_last inputs; with channels_first it is the channel axis.
        # Confirm batch_random_blur's expected layout before using deferred
        # blurring on GPU.
        model_inputs = preprocessing.batch_random_blur(
            model_inputs,
            tf.compat.dimension_value(model_inputs.shape[1]),
            blur_probability=preprocessing_hparams.blur_probability)

      with tf.tpu.bfloat16_scope():
        model_outputs = self.model(model_inputs, training)

      if inputs_are_multiview:
        # Move the views back from the batch axis into the feature axis.
        model_outputs = [
            utils.stacked_multiview_embeddings_to_channel(
                tf.cast(x, tf.float32)) if x is not None else x
            for x in model_outputs
        ]

      (unnormalized_embedding, normalized_embedding, projection,
       logits) = model_outputs

      if inputs_are_multiview:
        # If we keep everything in batch dimension then we don't need this. In
        # cross_entropy mode we should just stop generating the 2nd
        # augmentation.
        logits = tf.split(logits, 2, axis=1)[0]

      return unnormalized_embedding, normalized_embedding, projection, logits

  def _compute_cross_entropy_loss(self):
    """Computes the cross-entropy loss on the logits.

    Returns:
      In train mode, the scalar mean loss (and accuracy/loss summaries are
      recorded). Otherwise, the per-example loss Tensor.
    """
    with tf.name_scope('cross_entropy_loss'):
      one_hot_labels = tf.one_hot(self.labels, self.num_classes)
      cross_entropy = tf.losses.softmax_cross_entropy(
          logits=self.logits,
          onehot_labels=one_hot_labels,
          label_smoothing=(
              self.hparams.loss_all_stages.cross_entropy.label_smoothing),
          reduction=tf.losses.Reduction.NONE)

    if self.train:
      in_top_1 = tf.cast(
          tf.nn.in_top_k(self.logits, self.labels, 1), tf.float32)
      in_top_5 = tf.cast(
          tf.nn.in_top_k(self.logits, self.labels, 5), tf.float32)
      self._add_scalar_summary('top_1_accuracy', in_top_1)
      self._add_scalar_summary('top_5_accuracy', in_top_5)
      self._add_scalar_summary('loss/cross_entropy_loss', cross_entropy)
      cross_entropy = tf.reduce_mean(cross_entropy)

    return cross_entropy

  def _compute_contrastive_loss(self):
    """Computes the contrastive loss on the projection.

    Returns:
      In train mode, the scalar mean loss (and a loss summary is recorded).
      Otherwise, the unreduced loss Tensor returned by losses.contrastive_loss.
    """
    with tf.name_scope('contrastive_loss'):
      contrastive_params = self.hparams.loss_all_stages.contrastive
      labels = (
          tf.one_hot(self.labels, self.num_classes)
          if contrastive_params.use_labels else None)
      # The two views' projections are concatenated on the last axis; stack
      # them on a new axis 1 for the loss.
      projection_view_1, projection_view_2 = tf.split(
          self.projection, 2, axis=-1)
      contrastive_loss = losses.contrastive_loss(
          tf.stack([projection_view_1, projection_view_2], axis=1),
          labels=labels,
          temperature=contrastive_params.temperature,
          contrast_mode=contrastive_params.contrast_mode,
          summation_location=contrastive_params.summation_location,
          denominator_mode=contrastive_params.denominator_mode,
          positives_cap=contrastive_params.positives_cap,
          scale_by_temperature=contrastive_params.scale_by_temperature)

    if self.train:
      self._add_scalar_summary('loss/contrastive_loss', contrastive_loss)
      contrastive_loss = tf.reduce_mean(contrastive_loss)

    return contrastive_loss

  def _compute_stage_weight_decay(self, stage_params, stage_name):
    """Computes and returns the weight decay loss for a single training stage.

    Args:
      stage_params: An instance of hparams.Stage.
      stage_name: A string name for this stage.

    Returns:
      A scalar Tensor representing the weight decay loss for this stage.
    """
    with tf.name_scope(f'{stage_name}_weight_decay'):
      # Don't use independent weight decay with LARS optimizer, since it handles
      # it internally.
      weight_decay_coeff = (
          stage_params.loss.weight_decay_coeff if
          stage_params.training.optimizer is not enums.Optimizer.LARS else 0.)
      # Each sub-network's penalty is switched on/off via a 0/1 multiplier.
      weights = (
          self._encoder_weights *
          float(stage_params.loss.use_encoder_weight_decay) +
          self._projection_head_weights *
          float(stage_params.loss.use_projection_head_weight_decay) +
          self._classification_head_weights *
          float(stage_params.loss.use_classification_head_weight_decay))
      weight_decay_loss = weight_decay_coeff * weights
    with tf.name_scope(''):
      self._add_scalar_summary(f'weight_decay/{stage_name}_weight_decay_loss',
                               weight_decay_loss)
    return weight_decay_loss

  def _compute_weights(self, scope_name):
    """Computes the L2 penalty of the trainable weights inside a scope.

    A trainable variable is included when its name contains `scope_name`, does
    not contain 'batch_normalization', and does not contain 'bias' (unless
    `hparams.loss_all_stages.include_bias_in_weight_decay` is set).

    Args:
      scope_name: Substring identifying the sub-network, e.g. 'Encoder'.

    Returns:
      A scalar Tensor: the sum of `tf.nn.l2_loss` (half the squared L2 norm)
      over the matching variables, or 0 if no variable matches.
    """
    include_bias = self.hparams.loss_all_stages.include_bias_in_weight_decay

    def is_valid_weight(v):
      return (scope_name in v.name and
              'batch_normalization' not in v.name and
              ('bias' not in v.name or include_bias))

    with tf.name_scope(f'sum_{scope_name}_weights'):
      valid_weights = [v for v in tf.trainable_variables() if is_valid_weight(v)]
      if valid_weights:
        sum_of_weights = tf.add_n([tf.nn.l2_loss(v) for v in valid_weights])
      else:
        # tf.add_n rejects an empty list; a scope with no matching weights
        # contributes nothing to weight decay.
        sum_of_weights = tf.constant(0., dtype=tf.float32)

    self._add_scalar_summary(f'weight_decay/{scope_name}_weights',
                             sum_of_weights)
    return sum_of_weights

  def train_op(self):
    """Creates the Op for training this network.

    Computes learning rates, builds optimizers, and constructs the train ops to
    minimize the losses. Stage 1 is active while the global step is below the
    number of stage-1 steps, and stage 2 afterwards.

    Returns:
      A TensorFlow Op that will run one step of training when executed.
    """
    with tf.name_scope('train'):
      batch_size = self.train_global_batch_size
      steps_per_epoch = self.training_set_size / batch_size
      stage_1_epochs = self.hparams.stage_1.training.train_epochs
      stage_1_steps = int(stage_1_epochs * steps_per_epoch)
      stage_2_epochs = self.hparams.stage_2.training.train_epochs
      global_step = tf.train.get_or_create_global_step()
      stage_1_indicator = tf.math.less(global_step, stage_1_steps)
      stage_2_indicator = tf.math.logical_not(stage_1_indicator)

      def stage_learning_rate(stage_training_params, start_epoch, end_epoch):
        """Builds one stage's LR schedule, scaled by batch size."""
        schedule_kwargs = {}
        if (stage_training_params.learning_rate_decay in (
            enums.DecayType.PIECEWISE_LINEAR, enums.DecayType.EXPONENTIAL)):
          schedule_kwargs['decay_rate'] = stage_training_params.decay_rate
          if (stage_training_params.learning_rate_decay ==
              enums.DecayType.PIECEWISE_LINEAR):
            schedule_kwargs['boundary_epochs'] = (
                stage_training_params.decay_boundary_epochs)
          if (stage_training_params.learning_rate_decay ==
              enums.DecayType.EXPONENTIAL):
            schedule_kwargs['epochs_per_decay'] = (
                stage_training_params.epochs_per_decay)

        return utils.build_learning_rate_schedule(
            learning_rate=(stage_training_params.base_learning_rate *
                           (batch_size / BASE_BATCH_SIZE)),
            decay_type=stage_training_params.learning_rate_decay,
            warmup_start_epoch=start_epoch,
            max_learning_rate_epoch=(
                start_epoch +
                stage_training_params.learning_rate_warmup_epochs),
            decay_end_epoch=end_epoch,
            global_step=global_step,
            steps_per_epoch=steps_per_epoch,
            **schedule_kwargs)

      # Each stage's learning rate is masked to 0 outside of that stage.
      stage_1_learning_rate = stage_learning_rate(
          self.hparams.stage_1.training,
          start_epoch=0,
          end_epoch=stage_1_epochs) * tf.cast(stage_1_indicator, tf.float32)
      stage_2_learning_rate = stage_learning_rate(
          self.hparams.stage_2.training,
          start_epoch=stage_1_epochs,
          end_epoch=stage_1_epochs + stage_2_epochs) * tf.cast(
              stage_2_indicator, tf.float32)

      def stage_optimizer(learning_rate, stage_params, stage_name):
        """Builds one stage's optimizer, mirroring its weight-decay settings."""
        lars_exclude_from_weight_decay = ['batch_normalization']
        if not self.hparams.loss_all_stages.include_bias_in_weight_decay:
          lars_exclude_from_weight_decay.append('bias')
        if not stage_params.loss.use_encoder_weight_decay:
          lars_exclude_from_weight_decay.append('Encoder')
        if not stage_params.loss.use_projection_head_weight_decay:
          lars_exclude_from_weight_decay.append('ProjectionHead')
        if not stage_params.loss.use_classification_head_weight_decay:
          lars_exclude_from_weight_decay.append('ClassificationHead')

        return utils.build_optimizer(
            learning_rate,
            optimizer_type=stage_params.training.optimizer,
            lars_weight_decay=stage_params.loss.weight_decay_coeff,
            lars_exclude_from_weight_decay=lars_exclude_from_weight_decay,
            epsilon=stage_params.training.rmsprop_epsilon,
            is_tpu=self.is_tpu,
            name=stage_name)

      stage_1_optimizer = stage_optimizer(stage_1_learning_rate,
                                          self.hparams.stage_1, 'stage1')
      stage_2_optimizer = stage_optimizer(stage_2_learning_rate,
                                          self.hparams.stage_2, 'stage2')

      def stage_loss(stage_params, stage_name):
        """Weighted sum of the losses plus weight decay for one stage."""
        return (
            stage_params.loss.contrastive_weight * self.contrastive_loss +
            stage_params.loss.cross_entropy_weight * self.cross_entropy_loss +
            self._compute_stage_weight_decay(stage_params, stage_name))

      stage_1_loss = stage_loss(self.hparams.stage_1, 'stage1')
      stage_2_loss = stage_loss(self.hparams.stage_2, 'stage2')

      def stage_train_op_fn(loss, optimizer, stage_params):
        """Returns a zero-arg callable (for tf.cond) building a train op."""
        # NOTE(review): update_ops=None presumably lets create_train_op use its
        # default update ops (e.g. batch-norm statistics) while [] skips them;
        # confirm against utils.create_train_op.
        update_ops = (
            None if stage_params.training.update_encoder_batch_norm else [])

        def build():
          return utils.create_train_op(
              loss, optimizer, update_ops=update_ops)

        return build

      train_op = tf.cond(
          stage_1_indicator,
          stage_train_op_fn(stage_1_loss, stage_1_optimizer,
                            self.hparams.stage_1),
          stage_train_op_fn(stage_2_loss, stage_2_optimizer,
                            self.hparams.stage_2))

    self._add_scalar_summary('stage_1_learning_rate', stage_1_learning_rate)
    self._add_scalar_summary('stage_2_learning_rate', stage_2_learning_rate)
    self._add_scalar_summary('current_epoch',
                             tf.cast(global_step, tf.float32) / steps_per_epoch)

    return train_op

  def host_call(self, summary_dir):
    """Creates a host call that writes the collected summaries.

    Args:
      summary_dir: Directory the summary file writer writes to.

    Returns:
      A (host_call_fn, tensor_dict) tuple. host_call_fn writes the mean of each
      tensor as a scalar summary at the passed-in global step.
    """

    # Ensure that all host_call inputs have batch dimensions, since they get
    # concatenated from all cores along the batch dimension.
    summary_dict = {
        k: tf.expand_dims(v, axis=0) if v.shape.rank == 0 else v
        for k, v in self._summary_dict.items()
    }

    # Pass in the global step, since otherwise we might use a stale copy of the
    # variable from the host.
    global_step_key = 'global_step_is_not_a_summary'
    summary_dict[global_step_key] = tf.expand_dims(
        tf.train.get_or_create_global_step(), axis=0)

    def host_call_fn(**kwargs):
      step = kwargs[global_step_key][0]
      del kwargs[global_step_key]
      writer = tf2.summary.create_file_writer(summary_dir, max_queue=1000)
      always_record = tf2.summary.record_if(True)
      with writer.as_default(), always_record:
        for name, scalar in kwargs.items():
          tf2.summary.scalar(name, tf.reduce_mean(scalar), step)
      return tf.summary.all_v2_summary_ops()

    return host_call_fn, summary_dict

  def eval_metrics(self):
    """Returns the eval metric_fn and the tensors it consumes."""

    def metric_fn(logits, labels, contrastive_loss, cross_entropy_loss):
      metrics = {}
      in_top_1 = tf.cast(tf.nn.in_top_k(logits, labels, 1), tf.float32)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      metrics['top_1_accuracy'] = tf.metrics.mean(in_top_1)
      metrics['top_5_accuracy'] = tf.metrics.mean(in_top_5)
      metrics['loss/contrastive_loss'] = tf.metrics.mean(contrastive_loss)
      metrics['loss/cross_entropy_loss'] = tf.metrics.mean(cross_entropy_loss)
      return metrics

    metrics = {
        'logits': self.logits,
        'labels': self.labels,
        'contrastive_loss': self.contrastive_loss,
        'cross_entropy_loss': self.cross_entropy_loss,
    }

    return metric_fn, metrics

  def signature_def_map(self):
    """Returns a SignatureDef map that can be used to produce SavedModels.

    'contrastive_train' holds the outputs of the first model call and
    'contrastive_eval' those of the training=False call, which only exist in
    inference mode.
    """
    signature_def_map = {}
    signature_def_map['contrastive_train'] = {
        'embeddings': self.normalized_embedding,
        'unnormalized_embeddings': self.unnormalized_embedding,
        'projection': self.projection,
        'logits': self.logits,
    }
    signature_def_map['contrastive_eval'] = {
        'embeddings': self.normalized_embedding_eval,
        'unnormalized_embeddings': self.unnormalized_embedding_eval,
        'projection': self.projection_eval,
        'logits': self.logits_eval,
    }
    return signature_def_map

  def scaffold_fn(self):
    """Creates a function that produces a tf.train.Scaffold for custom init.

    When --reference_ckpt is set, the scaffold restores the selected parts of
    the model (per hparams.warm_start) from the latest checkpoint in that
    directory. Optimizer slot variables are never restored.

    Returns:
      A function that produces a tf.train.Scaffold.

    Raises (from the returned function):
      ValueError: If --reference_ckpt is set but contains no checkpoint.
    """

    def var_matches_patterns(var, patterns):
      return any(pattern in var.name for pattern in patterns)

    def scaffold_fn():
      """Scaffold function."""
      var_init_fn = None
      if FLAGS.reference_ckpt:
        warm_start_hparams = self.hparams.warm_start
        with tf.name_scope('warm_start'):
          include_pattern_list = []
          if warm_start_hparams.warm_start_encoder:
            include_pattern_list.append('ContrastiveModel/Encoder')
          if warm_start_hparams.warm_start_projection_head:
            include_pattern_list.append('ContrastiveModel/ProjectionHead')
          if warm_start_hparams.warm_start_classifier:
            include_pattern_list.append('ContrastiveModel/ClassificationHead')
          # This needs to be updated if new optimizers are added.
          exclude_pattern_list = [
              'Optimizer', 'Momentum', 'RMSProp', 'LARSOptimizer'
          ]
          variables = [
              v for v in tf.global_variables()
              if var_matches_patterns(v, include_pattern_list) and
              not var_matches_patterns(v, exclude_pattern_list)
          ]
          checkpoint_path = tf.train.latest_checkpoint(FLAGS.reference_ckpt)
          if checkpoint_path is None:
            # latest_checkpoint returns None when the directory has no
            # checkpoint; fail with a clear message instead of deep in slim.
            raise ValueError(
                'No checkpoint found in --reference_ckpt directory: '
                f'{FLAGS.reference_ckpt}')
          var_init_fn = slim.assign_from_checkpoint_fn(
              checkpoint_path,
              variables,
              ignore_missing_vars=(
                  warm_start_hparams.ignore_missing_checkpoint_vars),
              reshape_variables=True)

      def init_fn(scaffold, sess):
        del scaffold  # unused.

        if var_init_fn is not None:
          var_init_fn(sess)

      return tf.train.Scaffold(init_fn=init_fn)

    return scaffold_fn
638

639

640
def model_fn(features, labels, mode, params):
  """Contrastive model function for TPUEstimator.

  Args:
    features: Input Tensor, passed to ContrastiveTrainer as `model_inputs`.
    labels: Label Tensor, passed to ContrastiveTrainer.
    mode: A tf_estimator.ModeKeys value.
    params: Dict providing at least the 'hparams' and 'use_tpu' entries.

  Returns:
    A tf_estimator.tpu.TPUEstimatorSpec for `mode`.
  """
  trainer_mode = utils.estimator_mode_to_model_mode(mode)
  hparams = params['hparams']

  trainer = ContrastiveTrainer(
      model_inputs=features,
      labels=labels,
      train_global_batch_size=hparams.bs,
      hparams=hparams,
      mode=trainer_mode,
      num_classes=inputs.get_num_classes(hparams),
      training_set_size=inputs.get_num_train_images(hparams),
      is_tpu=params['use_tpu'])

  if mode == tf_estimator.ModeKeys.PREDICT:
    signatures = trainer.signature_def_map()
    export_outputs = {}
    for signature_name, signature_outputs in signatures.items():
      export_outputs[signature_name] = tf_estimator.export.PredictOutput(
          signature_outputs)
    # Export a default SignatureDef to keep the API happy; it aliases the
    # outputs of the training=False model call.
    export_outputs[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
        export_outputs['contrastive_eval'])
    return tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=signatures['contrastive_eval'],
        export_outputs=export_outputs)

  # We directly write summaries for the relevant losses, so just hard-code a
  # dummy value to keep the Estimator API happy.
  dummy_loss = tf.constant(0.)

  if mode == tf_estimator.ModeKeys.EVAL:
    return tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode, loss=dummy_loss, eval_metrics=trainer.eval_metrics())

  # Only TRAIN remains. train_op() must run before host_call(), because it
  # registers the learning-rate and epoch summaries that host_call exports.
  return tf_estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      train_op=trainer.train_op(),
      loss=dummy_loss,
      scaffold_fn=trainer.scaffold_fn(),
      host_call=trainer.host_call(FLAGS.model_dir))
687

688

689
def main(_):
  """Builds a TPUEstimator and runs training and/or evaluation.

  Behavior depends on --mode:
    * 'train': trains until the combined step count of both stages.
    * 'eval': evaluates each new checkpoint in --model_dir and exports a
      SavedModel for it. It stops once a checkpoint reaches the final step,
      or when no new checkpoint appears within an hour.
    * 'train_then_eval': trains, then evaluates and exports once.

  Raises:
    ValueError: If hparams.input_data.input_fn names no function in `inputs`.
  """
  tf.disable_v2_behavior()
  tf.enable_resource_variables()

  if FLAGS.hparams is None:
    hparams = hparams_flags.hparams_from_flags()
  else:
    hparams = hparams_lib.HParams(FLAGS.hparams)

  cluster = None
  if FLAGS.use_tpu and FLAGS.master is None:
    if FLAGS.tpu_name:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
      tf.config.experimental_connect_to_cluster(cluster)
      tf.tpu.experimental.initialize_tpu_system(cluster)

  session_config = tf.ConfigProto()
  # Workaround for https://github.com/tensorflow/tensorflow/issues/26411 where
  # convolutions (used in blurring) get confused about data-format when used
  # inside a tf.data pipeline that is run on GPU.
  if (tf.test.is_built_with_cuda() and
      not hparams.input_data.preprocessing.defer_blurring):
    # RewriterConfig.OFF = 2
    session_config.graph_options.rewrite_options.layout_optimizer = 2
  run_config = tf_estimator.tpu.RunConfig(
      master=FLAGS.master,
      cluster=cluster,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=FLAGS.save_interval_steps,
      keep_checkpoint_max=FLAGS.max_checkpoints_to_keep,
      keep_checkpoint_every_n_hours=(FLAGS.keep_checkpoint_interval_secs /
                                     (60.0 * 60.0)),
      log_step_count_steps=100,
      session_config=session_config,
      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.steps_per_loop,
          per_host_input_for_training=True,
          experimental_host_call_every_n_steps=FLAGS.summary_interval_steps,
          tpu_job_name='train_tpu_worker' if FLAGS.mode == 'train' else None,
          eval_training_input_configuration=(
              tf_estimator.tpu.InputPipelineConfig.SLICED if FLAGS.use_tpu else
              tf_estimator.tpu.InputPipelineConfig.PER_HOST_V1)))
  params = {
      'hparams': hparams,
      'use_tpu': FLAGS.use_tpu,
      'data_dir': FLAGS.data_dir,
  }
  estimator = tf_estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params=params,
      train_batch_size=hparams.bs,
      eval_batch_size=hparams.eval.batch_size)

  input_fn_name = hparams.input_data.input_fn
  if not hasattr(inputs, input_fn_name):
    # f-string so the offending name actually appears in the message.
    raise ValueError(f'Unknown input_fn: {input_fn_name}')
  input_fn = getattr(inputs, input_fn_name)

  training_set_size = inputs.get_num_train_images(hparams)
  steps_per_epoch = training_set_size / hparams.bs
  stage_1_epochs = hparams.stage_1.training.train_epochs
  stage_2_epochs = hparams.stage_2.training.train_epochs
  total_steps = int((stage_1_epochs + stage_2_epochs) * steps_per_epoch)

  num_eval_examples = inputs.get_num_eval_images(hparams)
  eval_steps = num_eval_examples // hparams.eval.batch_size

  export_dir = os.path.join(FLAGS.model_dir, 'exports')

  def serving_input_fn():
    return input_fn(tf_estimator.ModeKeys.PREDICT, params)

  if FLAGS.mode == 'eval':
    for ckpt_str in tf.train.checkpoints_iterator(
        FLAGS.model_dir,
        min_interval_secs=FLAGS.eval_interval_secs,
        timeout=60 * 60):
      result = estimator.evaluate(
          input_fn=input_fn, checkpoint_path=ckpt_str, steps=eval_steps)
      estimator.export_saved_model(
          export_dir, serving_input_fn, checkpoint_path=ckpt_str)
      if result['global_step'] >= total_steps:
        return
  else:  # 'train' or 'train_then_eval'.
    estimator.train(input_fn=input_fn, max_steps=total_steps)
    if FLAGS.mode == 'train_then_eval':
      estimator.evaluate(input_fn=input_fn, steps=eval_steps)
      estimator.export_saved_model(export_dir, serving_input_fn)
780

781

782
if __name__ == '__main__':
  # absl's app.run parses command-line flags and then invokes main(argv).
  app.run(main=main)
784

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.