# (Removed web-scrape artifact: repository name and file-size banner
# "google-research / 783 строки · 31.1 Кб" — not part of the source file.)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Contrastive Learning training/eval code."""
17
18import os19
20from absl import app21from absl import flags22import tensorflow.compat.v1 as tf23from tensorflow.compat.v1 import estimator as tf_estimator24import tensorflow.compat.v2 as tf225import tf_slim as slim26
27from supcon import enums28from supcon import hparams as hparams_lib29from supcon import hparams_flags30from supcon import inputs31from supcon import losses32from supcon import models33from supcon import preprocessing34from supcon import utils35
# Command-line flags controlling job mode, checkpointing cadence, summary
# writing, and TPU cluster selection. Hyperparameters themselves come either
# from the serialized --hparams string or from the flags in hparams_flags.py.
flags.DEFINE_string(
    'hparams', None,
    'A serialized hparams string representing the hyperparameters to use. If '
    'not set, fall back to using the individual hyperparameter flags defined '
    'in hparams_flags.py')
flags.DEFINE_enum(
    'mode', 'train', ['train', 'eval', 'train_then_eval'],
    'The mode for this job, either "train", "eval", or '
    '"train_then_eval".')
flags.DEFINE_string(
    'model_dir', '', 'Root of the tree containing all files for the current '
    'model.')
flags.DEFINE_string('master', None, 'Address of the TensorFlow runtime.')
flags.DEFINE_integer('summary_interval_steps', 100,
                     'Number of steps in between logging training summaries.')
flags.DEFINE_integer('save_interval_steps', 1000,
                     'Number of steps in between saving model checkpoints.')
flags.DEFINE_integer('max_checkpoints_to_keep', 5,
                     'Maximum number of recent checkpoints to keep.')
flags.DEFINE_float(
    'keep_checkpoint_interval_secs',
    60 * 60 * 1000 * 10,  # 10,000 hours
    'Number of seconds in between permanently retained checkpoints.')
flags.DEFINE_integer(
    'steps_per_loop', 1000,
    'Number of steps to execute on TPU before returning control to the '
    'coordinator. Checkpoints will be taken at least these many steps apart.')
flags.DEFINE_boolean('use_tpu', True, 'Whether this is running on a TPU.')
flags.DEFINE_integer('eval_interval_secs', 60, 'Time interval between evals.')
flags.DEFINE_string(
    'reference_ckpt', '',
    '[Optional] If set, attempt to initialize the model using the latest '
    'checkpoint in this directory.')
flags.DEFINE_string(
    'data_dir', None,
    'The directory that will be passed as the `data_dir` argument to '
    '`tfds.load`.')
tf.flags.DEFINE_string(
    'tpu_name', None,
    'The Cloud TPU to use for training. This should be either the name '
    'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
    'url.')
tf.flags.DEFINE_string(
    'tpu_zone', None,
    '[Optional] GCE zone where the Cloud TPU is located in. If not '
    'specified, we will attempt to automatically detect the GCE project from '
    'metadata.')
tf.flags.DEFINE_string(
    'gcp_project', None,
    '[Optional] Project name for the Cloud TPU-enabled project. If not '
    'specified, we will attempt to automatically detect the GCE project from '
    'metadata.')

FLAGS = flags.FLAGS

# Learning rate hparam values are defined with respect to this batch size. The
# true learning rate will be scaled by batch_size/BASE_BATCH_SIZE.
BASE_BATCH_SIZE = 256
95
class ContrastiveTrainer:
  """Encapsulates the train, eval, and inference logic of a contrastive model.

  Upon construction of this class, the model graph is created. In train and
  eval mode, the loss computation graph is also created at construction time.

  Attrs:
    model_inputs: The inputs to the model. These should be Tensors with shape
      [batch_size, side_length, side_length, 3 * views] with values in range
      [-1, 1] and dtype tf.float32 or tf.bfloat16. Currently views must be 1 or
      2.
    labels: The labels corresponding to `model_inputs`. A Tensor of shape
      [batch_size] with integer dtype.
    hparams: A hparams.HParams instance, reflecting the hyperparameters of the
      model and its training.
    mode: An enums.ModelMode value.
    num_classes: The cardinality of the labelset, and also the number of output
      logits of the classification head.
    training_set_size: The number of samples in the training set.
    is_tpu: Whether this is running on a TPU.
  """

  def __init__(self,
               model_inputs,
               labels,
               train_global_batch_size,
               hparams,
               mode,
               num_classes,
               training_set_size,
               is_tpu=False):
    """Builds the model graph and, in train/eval mode, the loss graph.

    Args:
      model_inputs: Input image Tensor; see class docstring for shape/dtype.
      labels: Integer label Tensor of shape [batch_size].
      train_global_batch_size: Global (cross-replica) training batch size,
        used to derive steps-per-epoch and learning-rate scaling.
      hparams: An hparams.HParams instance.
      mode: An enums.ModelMode value.
      num_classes: Number of classification-head output logits.
      training_set_size: Number of samples in the training set.
      is_tpu: Whether this is running on a TPU.
    """
    self.model_inputs = model_inputs
    self.labels = labels
    self.train_global_batch_size = train_global_batch_size
    self.hparams = hparams
    self.mode = mode
    assert isinstance(mode, enums.ModelMode)
    self.num_classes = num_classes
    self.training_set_size = training_set_size
    self.is_tpu = is_tpu
    # Scalar summaries accumulated by _add_scalar_summary and emitted through
    # host_call().
    self._summary_dict = {}

    if not self.inference:
      # Train and eval inputs concatenate 2 RGB views along channels (3*2=6).
      if tf.compat.dimension_at_index(self.model_inputs.shape, -1) != 6:
        raise ValueError(
            'Both train and eval modes must have 2 views provided, '
            'concatenated in the channels dimension.')

    self.data_format = ('channels_first' if not self.inference and
                        tf.config.list_physical_devices('GPU') else
                        'channels_last')

    if self.eval:
      self._summary_update_ops = []

    self.model = self._create_model()

    is_bn_train_mode = (
        # We intentionally run with batch norm in train mode for inference. We
        # call model() a second time with training=False for inference mode
        # below, and include both in the inference graph and SavedModel.
        not self.eval and (not FLAGS.reference_ckpt or
                           self.hparams.warm_start.batch_norm_in_train_mode))
    (self.unnormalized_embedding, self.normalized_embedding, self.projection,
     self.logits) = self._call_model(training=is_bn_train_mode)

    if self.inference:
      # Inference additionally exposes training=False outputs so SavedModels
      # contain both batch-norm behaviors.
      (self.unnormalized_embedding_eval, self.normalized_embedding_eval,
       self.projection_eval,
       self.logits_eval) = self._call_model(training=False)
      return

    self._encoder_weights = self._compute_weights('Encoder')
    self._projection_head_weights = self._compute_weights('ProjectionHead')
    self._classification_head_weights = self._compute_weights(
        'ClassificationHead')

    self.contrastive_loss = self._compute_contrastive_loss()
    self.cross_entropy_loss = self._compute_cross_entropy_loss()

  @property
  def train(self):
    """True iff this trainer was constructed in TRAIN mode."""
    return self.mode == enums.ModelMode.TRAIN

  @property
  def eval(self):
    """True iff this trainer was constructed in EVAL mode."""
    return self.mode == enums.ModelMode.EVAL

  @property
  def inference(self):
    """True iff this trainer was constructed in INFERENCE mode."""
    return self.mode == enums.ModelMode.INFERENCE

  def _add_scalar_summary(self, name, tensor):
    """Collects tensors that should be written as summaries in `host_call`."""
    self._summary_dict[name] = tensor

  def _create_model(self):
    """Creates the model, but does not build it or create variables.

    Returns:
      A callable Keras layer that implements the model architecture.
    """
    arch_hparams = self.hparams.architecture
    model = models.ContrastiveModel(
        architecture=arch_hparams.encoder_architecture,
        normalize_projection_head_input=(
            arch_hparams.normalize_projection_head_inputs),
        normalize_classification_head_input=(
            arch_hparams.normalize_classifier_inputs),
        stop_gradient_before_classification_head=(
            arch_hparams.stop_gradient_before_classification_head),
        stop_gradient_before_projection_head=(
            arch_hparams.stop_gradient_before_projection_head),
        encoder_kwargs={
            'depth': arch_hparams.encoder_depth,
            'width': arch_hparams.encoder_width,
            'first_conv_kernel_size': arch_hparams.first_conv_kernel_size,
            'first_conv_stride': arch_hparams.first_conv_stride,
            'data_format': self.data_format,
            'use_initial_max_pool': arch_hparams.use_initial_max_pool,
            'use_global_batch_norm': arch_hparams.use_global_batch_norm,
        },
        projection_head_kwargs={
            'feature_dims':
                arch_hparams.projection_head_layers,
            'normalize_output':
                True,
            'use_batch_norm':
                arch_hparams.projection_head_use_batch_norm,
            'use_batch_norm_beta':
                arch_hparams.projection_head_use_batch_norm_beta,
            'use_global_batch_norm':
                arch_hparams.use_global_batch_norm,
        },
        classification_head_kwargs={
            'num_classes':
                self.num_classes,
            # Fix: instantiate the glorot_uniform initializer, consistent
            # with the zeros() branch. Previously the class itself was passed
            # and instantiation was left to Keras' initializer resolution.
            'kernel_initializer': (tf.initializers.zeros()
                                   if arch_hparams.zero_initialize_classifier
                                   else tf.initializers.glorot_uniform())
        })

    return model

  def _call_model(self, training):
    """Passes data through the model.

    Manipulates the input data to get it ready for passing into the model,
    including applying some data augmentation that is more efficient to apply
    on the TPU than on the host. It then passes it into the model, which will
    first build the model and create its variables.

    Args:
      training: Whether the model should be run in training mode.

    Returns:
      A tuple of the model outputs (as Tensors):
      * unnormalized_embedding: The output of the encoder, not including
        normalization, which is sometimes applied before this gets passed into
        the projection and classification heads.
      * normalized_embedding: A normalized version of
        `unnormalized_embedding`.
      * projection: The output of the projection head.
      * logits: The output of the classification head.
    """
    with tf.name_scope('call_model'):
      model_inputs = self.model_inputs

      # In most cases, the data format NCHW instead of NHWC should be used for
      # a significant performance boost on GPU. NHWC should be used only if
      # the network needs to be run on CPU since the pooling operations are
      # only supported on NHWC. TPU uses XLA compiler to figure out best
      # layout.
      if self.data_format == 'channels_first':
        model_inputs = tf.transpose(model_inputs, [0, 3, 1, 2])

      channels_index = 1 if self.data_format == 'channels_first' else -1
      inputs_are_multiview = tf.compat.dimension_value(
          model_inputs.shape[channels_index]) > 3
      if inputs_are_multiview:
        # Fold the views into the batch dimension so the encoder sees
        # single-view images.
        model_inputs = utils.stacked_multiview_image_channels_to_batch(
            model_inputs, self.data_format)

      # Perform blur augmentations here, since they're faster on TPU than CPU.
      # NOTE(review): shape[1] is passed as the image side length; under
      # 'channels_first' shape[1] is the channel dim — presumably blurring and
      # channels_first never co-occur (TPU uses channels_last). TODO: confirm.
      if (self.hparams.input_data.preprocessing.augmentation_type in (
          enums.AugmentationType.SIMCLR,
          enums.AugmentationType.STACKED_RANDAUGMENT) and
          self.hparams.input_data.preprocessing.blur_probability > 0. and
          self.hparams.input_data.preprocessing.defer_blurring and
          self.train):
        model_inputs = preprocessing.batch_random_blur(
            model_inputs,
            tf.compat.dimension_value(model_inputs.shape[1]),
            blur_probability=(
                self.hparams.input_data.preprocessing.blur_probability))

      # Run the network body in bfloat16 on TPU for speed/memory.
      with tf.tpu.bfloat16_scope():
        model_outputs = self.model(model_inputs, training)

      if inputs_are_multiview:
        # Undo the view-folding: move views back to the channels dimension,
        # casting outputs back to float32.
        model_outputs = [
            utils.stacked_multiview_embeddings_to_channel(
                tf.cast(x, tf.float32)) if x is not None else x
            for x in model_outputs
        ]

      (unnormalized_embedding, normalized_embedding, projection,
       logits) = model_outputs

      if inputs_are_multiview:
        # If we keep everything in batch dimension then we don't need this. In
        # cross_entropy mode we should just stop generating the 2nd
        # augmentation.
        logits = tf.split(logits, 2, axis=1)[0]

      return unnormalized_embedding, normalized_embedding, projection, logits

  def _compute_cross_entropy_loss(self):
    """Computes and returns the cross-entropy loss on the logits."""
    with tf.name_scope('cross_entropy_loss'):
      one_hot_labels = tf.one_hot(self.labels, self.num_classes)
      # Reduction.NONE keeps per-sample losses so they can be summarized
      # before averaging.
      cross_entropy = tf.losses.softmax_cross_entropy(
          logits=self.logits,
          onehot_labels=one_hot_labels,
          label_smoothing=(
              self.hparams.loss_all_stages.cross_entropy.label_smoothing),
          reduction=tf.losses.Reduction.NONE)

    if self.train:
      in_top_1 = tf.cast(
          tf.nn.in_top_k(self.logits, self.labels, 1), tf.float32)
      in_top_5 = tf.cast(
          tf.nn.in_top_k(self.logits, self.labels, 5), tf.float32)
      self._add_scalar_summary('top_1_accuracy', in_top_1)
      self._add_scalar_summary('top_5_accuracy', in_top_5)
      self._add_scalar_summary('loss/cross_entropy_loss', cross_entropy)
      cross_entropy = tf.reduce_mean(cross_entropy)

    return cross_entropy

  def _compute_contrastive_loss(self):
    """Computes and returns the contrastive loss on the projection."""
    with tf.name_scope('contrastive_loss'):
      contrastive_params = self.hparams.loss_all_stages.contrastive
      # Labels are only used for supervised-contrastive mode.
      labels = (
          tf.one_hot(self.labels, self.num_classes)
          if contrastive_params.use_labels else None)
      projection = self.projection
      # The two views were concatenated along channels; split them back out
      # and stack into a [batch, 2, dim] views tensor.
      projection_view_1, projection_view_2 = tf.split(projection, 2, axis=-1)
      contrastive_loss = losses.contrastive_loss(
          tf.stack([projection_view_1, projection_view_2], axis=1),
          labels=labels,
          temperature=contrastive_params.temperature,
          contrast_mode=contrastive_params.contrast_mode,
          summation_location=contrastive_params.summation_location,
          denominator_mode=contrastive_params.denominator_mode,
          positives_cap=contrastive_params.positives_cap,
          scale_by_temperature=contrastive_params.scale_by_temperature)

    if self.train:
      self._add_scalar_summary('loss/contrastive_loss', contrastive_loss)
      contrastive_loss = tf.reduce_mean(contrastive_loss)

    return contrastive_loss

  def _compute_stage_weight_decay(self, stage_params, stage_name):
    """Computes and returns the weight decay loss for a single training stage.

    Args:
      stage_params: An instance of hparams.Stage.
      stage_name: A string name for this stage.

    Returns:
      A scalar Tensor representing the weight decay loss for this stage.
    """
    with tf.name_scope(f'{stage_name}_weight_decay'):
      # Don't use independent weight decay with LARS optimizer, since it
      # handles it internally.
      weight_decay_coeff = (
          stage_params.loss.weight_decay_coeff if
          stage_params.training.optimizer is not enums.Optimizer.LARS else 0.)
      weights = (
          self._encoder_weights *
          float(stage_params.loss.use_encoder_weight_decay) +
          self._projection_head_weights *
          float(stage_params.loss.use_projection_head_weight_decay) +
          self._classification_head_weights *
          float(stage_params.loss.use_classification_head_weight_decay))
      weight_decay_loss = weight_decay_coeff * weights
      with tf.name_scope(''):
        self._add_scalar_summary(f'weight_decay/{stage_name}_weight_decay_loss',
                                 weight_decay_loss)
      return weight_decay_loss

  def _compute_weights(self, scope_name):
    """Computes the sum of the L2 norms of all kernel weights inside a scope."""

    def is_valid_weight(v):
      # Batch-norm parameters are never decayed; biases only when configured.
      if (scope_name in v.name and 'batch_normalization' not in v.name and
          ('bias' not in v.name or
           self.hparams.loss_all_stages.include_bias_in_weight_decay)):
        return True
      return False

    with tf.name_scope(f'sum_{scope_name}_weights'):
      valid_weights = filter(is_valid_weight, tf.trainable_variables())
      sum_of_weights = tf.add_n([tf.nn.l2_loss(v) for v in valid_weights])

    self._add_scalar_summary(f'weight_decay/{scope_name}_weights',
                             sum_of_weights)
    return sum_of_weights

  def train_op(self):
    """Creates the Op for training this network.

    Computes learning rates, builds optimizers, and constructs the train ops
    to minimize the losses.

    Returns:
      A TensorFlow Op that will run one step of training when executed.
    """
    with tf.name_scope('train'):
      batch_size = self.train_global_batch_size
      steps_per_epoch = self.training_set_size / batch_size
      stage_1_epochs = self.hparams.stage_1.training.train_epochs
      stage_1_steps = int(stage_1_epochs * steps_per_epoch)
      stage_2_epochs = self.hparams.stage_2.training.train_epochs
      global_step = tf.train.get_or_create_global_step()
      # Stage 1 runs for stage_1_steps, then stage 2 takes over.
      stage_1_indicator = tf.math.less(global_step, stage_1_steps)
      stage_2_indicator = tf.math.logical_not(stage_1_indicator)

      def stage_learning_rate(stage_training_params, start_epoch, end_epoch):
        """Builds the LR schedule for one stage over [start, end] epochs."""
        schedule_kwargs = {}
        if (stage_training_params.learning_rate_decay in (
            enums.DecayType.PIECEWISE_LINEAR, enums.DecayType.EXPONENTIAL)):
          schedule_kwargs['decay_rate'] = stage_training_params.decay_rate
        if (stage_training_params.learning_rate_decay ==
            enums.DecayType.PIECEWISE_LINEAR):
          schedule_kwargs['boundary_epochs'] = (
              stage_training_params.decay_boundary_epochs)
        if (stage_training_params.learning_rate_decay ==
            enums.DecayType.EXPONENTIAL):
          schedule_kwargs['epochs_per_decay'] = (
              stage_training_params.epochs_per_decay)

        return utils.build_learning_rate_schedule(
            # Linear LR scaling relative to the reference batch size.
            learning_rate=(stage_training_params.base_learning_rate *
                           (batch_size / BASE_BATCH_SIZE)),
            decay_type=stage_training_params.learning_rate_decay,
            warmup_start_epoch=start_epoch,
            max_learning_rate_epoch=(
                start_epoch +
                stage_training_params.learning_rate_warmup_epochs),
            decay_end_epoch=end_epoch,
            global_step=global_step,
            steps_per_epoch=steps_per_epoch,
            **schedule_kwargs)

      # Each stage's LR is zeroed out while the other stage is active.
      stage_1_learning_rate = stage_learning_rate(
          self.hparams.stage_1.training,
          start_epoch=0,
          end_epoch=stage_1_epochs) * tf.cast(stage_1_indicator, tf.float32)
      stage_2_learning_rate = stage_learning_rate(
          self.hparams.stage_2.training,
          start_epoch=stage_1_epochs,
          end_epoch=stage_1_epochs + stage_2_epochs) * tf.cast(
              stage_2_indicator, tf.float32)

      def stage_optimizer(stage_learning_rate, stage_params, stage_name):
        """Builds the optimizer for one stage."""
        lars_exclude_from_weight_decay = ['batch_normalization']
        if not self.hparams.loss_all_stages.include_bias_in_weight_decay:
          lars_exclude_from_weight_decay.append('bias')
        if not stage_params.loss.use_encoder_weight_decay:
          lars_exclude_from_weight_decay.append('Encoder')
        if not stage_params.loss.use_projection_head_weight_decay:
          lars_exclude_from_weight_decay.append('ProjectionHead')
        if not stage_params.loss.use_classification_head_weight_decay:
          lars_exclude_from_weight_decay.append('ClassificationHead')

        return utils.build_optimizer(
            stage_learning_rate,
            optimizer_type=stage_params.training.optimizer,
            lars_weight_decay=stage_params.loss.weight_decay_coeff,
            lars_exclude_from_weight_decay=lars_exclude_from_weight_decay,
            epsilon=stage_params.training.rmsprop_epsilon,
            is_tpu=self.is_tpu,
            name=stage_name)

      stage_1_optimizer = stage_optimizer(stage_1_learning_rate,
                                          self.hparams.stage_1, 'stage1')
      stage_2_optimizer = stage_optimizer(stage_2_learning_rate,
                                          self.hparams.stage_2, 'stage2')

      def stage_loss(stage_params, stage_name):
        """Weighted sum of contrastive, cross-entropy and decay losses."""
        return (
            stage_params.loss.contrastive_weight * self.contrastive_loss +
            stage_params.loss.cross_entropy_weight * self.cross_entropy_loss +
            self._compute_stage_weight_decay(stage_params, stage_name))

      stage_1_loss = stage_loss(self.hparams.stage_1, 'stage1')
      stage_2_loss = stage_loss(self.hparams.stage_2, 'stage2')

      def stage_1_train_op():
        return utils.create_train_op(
            stage_1_loss,
            stage_1_optimizer,
            update_ops=(None if
                        self.hparams.stage_1.training.update_encoder_batch_norm
                        else []))

      def stage_2_train_op():
        return utils.create_train_op(
            stage_2_loss,
            stage_2_optimizer,
            update_ops=(None if
                        self.hparams.stage_2.training.update_encoder_batch_norm
                        else []))

      train_op = tf.cond(stage_1_indicator, stage_1_train_op,
                         stage_2_train_op)

      self._add_scalar_summary('stage_1_learning_rate', stage_1_learning_rate)
      self._add_scalar_summary('stage_2_learning_rate', stage_2_learning_rate)
      self._add_scalar_summary(
          'current_epoch', tf.cast(global_step, tf.float32) / steps_per_epoch)

      return train_op

  def host_call(self, summary_dir):
    """Creates a host call to write summaries."""

    # Ensure that all host_call inputs have batch dimensions, since they get
    # concatenated from all cores along the batch dimension.
    summary_dict = {
        k: tf.expand_dims(v, axis=0) if v.shape.rank == 0 else v
        for k, v in self._summary_dict.items()
    }

    # Pass in the global step, since otherwise we might use a stale copy of
    # the variable from the host.
    global_step_key = 'global_step_is_not_a_summary'
    summary_dict[global_step_key] = tf.expand_dims(
        tf.train.get_or_create_global_step(), axis=0)

    def host_call_fn(**kwargs):
      step = kwargs[global_step_key][0]
      del kwargs[global_step_key]
      writer = tf2.summary.create_file_writer(summary_dir, max_queue=1000)
      always_record = tf2.summary.record_if(True)
      with writer.as_default(), always_record:
        for name, scalar in kwargs.items():
          # Per-core values were concatenated on the batch dim; average them.
          tf2.summary.scalar(name, tf.reduce_mean(scalar), step)
        return tf.summary.all_v2_summary_ops()

    return host_call_fn, summary_dict

  def eval_metrics(self):
    """Returns eval metric_fn and metrics."""

    def metric_fn(logits, labels, contrastive_loss, cross_entropy_loss):
      """Aggregates per-batch values into streaming eval metrics."""
      metrics = {}
      in_top_1 = tf.cast(tf.nn.in_top_k(logits, labels, 1), tf.float32)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      metrics['top_1_accuracy'] = tf.metrics.mean(in_top_1)
      metrics['top_5_accuracy'] = tf.metrics.mean(in_top_5)
      metrics['loss/contrastive_loss'] = tf.metrics.mean(contrastive_loss)
      metrics['loss/cross_entropy_loss'] = tf.metrics.mean(cross_entropy_loss)
      return metrics

    metrics = {
        'logits': self.logits,
        'labels': self.labels,
        'contrastive_loss': self.contrastive_loss,
        'cross_entropy_loss': self.cross_entropy_loss,
    }

    return metric_fn, metrics

  def signature_def_map(self):
    """Returns a SignatureDef map that can be used to produce SavedModels."""
    signature_def_map = {}
    # Outputs with batch norm in train mode (see __init__).
    signature_def_map['contrastive_train'] = {
        'embeddings': self.normalized_embedding,
        'unnormalized_embeddings': self.unnormalized_embedding,
        'projection': self.projection,
        'logits': self.logits,
    }
    # Outputs from the training=False pass.
    signature_def_map['contrastive_eval'] = {
        'embeddings': self.normalized_embedding_eval,
        'unnormalized_embeddings': self.unnormalized_embedding_eval,
        'projection': self.projection_eval,
        'logits': self.logits_eval,
    }
    return signature_def_map

  def scaffold_fn(self):
    """Creates a function that produces a tf.train.Scaffold for custom init.

    When appropriate, it restores all or some of the weights from a checkpoint
    at model initialization.

    Returns:
      A function that produces a tf.train.Scaffold.
    """

    def var_matches_patterns(var, patterns):
      return any(pattern in var.name for pattern in patterns)

    def scaffold_fn():
      """Scaffold function."""
      warm_start_hparams = self.hparams.warm_start
      if FLAGS.reference_ckpt:
        with tf.name_scope('warm_start'):
          include_pattern_list = []
          if warm_start_hparams.warm_start_encoder:
            include_pattern_list.append('ContrastiveModel/Encoder')
          if warm_start_hparams.warm_start_projection_head:
            include_pattern_list.append('ContrastiveModel/ProjectionHead')
          if warm_start_hparams.warm_start_classifier:
            include_pattern_list.append('ContrastiveModel/ClassificationHead')
          # This needs to be updated if new optimizers are added.
          exclude_pattern_list = [
              'Optimizer', 'Momentum', 'RMSProp', 'LARSOptimizer'
          ]
          variables = filter(
              lambda v: var_matches_patterns(v, include_pattern_list),
              tf.global_variables())
          variables = filter(
              lambda v: not var_matches_patterns(v, exclude_pattern_list),
              variables)
          var_init_fn = slim.assign_from_checkpoint_fn(
              tf.train.latest_checkpoint(FLAGS.reference_ckpt),
              list(variables),
              ignore_missing_vars=(
                  warm_start_hparams.ignore_missing_checkpoint_vars),
              reshape_variables=True)

      def init_fn(scaffold, sess):
        del scaffold  # unused.

        if FLAGS.reference_ckpt:
          var_init_fn(sess)

      return tf.train.Scaffold(init_fn=init_fn)

    return scaffold_fn
639
def model_fn(features, labels, mode, params):
  """Contrastive model function for TPUEstimator.

  Builds a ContrastiveTrainer for the requested mode and wraps its outputs in
  a TPUEstimatorSpec.

  Args:
    features: Input image Tensor batch, passed to the trainer as
      `model_inputs`.
    labels: Integer label Tensor batch.
    mode: A tf_estimator.ModeKeys value.
    params: Dict containing at least 'hparams' (an hparams.HParams instance)
      and 'use_tpu' (bool).

  Returns:
    A tf_estimator.tpu.TPUEstimatorSpec for the given mode.
  """

  model_mode = utils.estimator_mode_to_model_mode(mode)
  hparams = params['hparams']

  trainer = ContrastiveTrainer(
      model_inputs=features,
      labels=labels,
      train_global_batch_size=hparams.bs,
      hparams=hparams,
      mode=model_mode,
      num_classes=inputs.get_num_classes(hparams),
      training_set_size=inputs.get_num_train_images(hparams),
      is_tpu=params['use_tpu'])

  if mode == tf_estimator.ModeKeys.PREDICT:
    predictions_map = trainer.signature_def_map()
    exports = {
        k: tf_estimator.export.PredictOutput(v)
        for k, v in predictions_map.items()
    }
    # Export a default SignatureDef to keep the API happy.
    exports[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
        exports['contrastive_eval'])
    spec = tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions_map['contrastive_eval'],
        export_outputs=exports)
    return spec

  # We directly write summaries for the relevant losses, so just hard-code a
  # dummy value to keep the Estimator API happy.
  loss = tf.constant(0.)

  if mode == tf_estimator.ModeKeys.EVAL:
    spec = tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=trainer.eval_metrics())
    return spec
  else:  # TRAIN
    spec = tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        train_op=trainer.train_op(),
        loss=loss,
        scaffold_fn=trainer.scaffold_fn(),
        host_call=trainer.host_call(FLAGS.model_dir))
    return spec
688
def main(_):
  """Builds the TPUEstimator and runs the train and/or eval loop."""
  tf.disable_v2_behavior()
  tf.enable_resource_variables()

  # Hyperparameters come either from the serialized --hparams string or from
  # the individual flags defined in hparams_flags.py.
  if FLAGS.hparams is None:
    hparams = hparams_flags.hparams_from_flags()
  else:
    hparams = hparams_lib.HParams(FLAGS.hparams)

  cluster = None
  if FLAGS.use_tpu and FLAGS.master is None:
    if FLAGS.tpu_name:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(cluster)
    tf.tpu.experimental.initialize_tpu_system(cluster)

  session_config = tf.ConfigProto()
  # Workaround for https://github.com/tensorflow/tensorflow/issues/26411 where
  # convolutions (used in blurring) get confused about data-format when used
  # inside a tf.data pipeline that is run on GPU.
  if (tf.test.is_built_with_cuda() and
      not hparams.input_data.preprocessing.defer_blurring):
    # RewriterConfig.OFF = 2
    session_config.graph_options.rewrite_options.layout_optimizer = 2
  run_config = tf_estimator.tpu.RunConfig(
      master=FLAGS.master,
      cluster=cluster,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=FLAGS.save_interval_steps,
      keep_checkpoint_max=FLAGS.max_checkpoints_to_keep,
      keep_checkpoint_every_n_hours=(FLAGS.keep_checkpoint_interval_secs /
                                     (60.0 * 60.0)),
      log_step_count_steps=100,
      session_config=session_config,
      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.steps_per_loop,
          per_host_input_for_training=True,
          experimental_host_call_every_n_steps=FLAGS.summary_interval_steps,
          tpu_job_name='train_tpu_worker' if FLAGS.mode == 'train' else None,
          eval_training_input_configuration=(
              tf_estimator.tpu.InputPipelineConfig.SLICED if FLAGS.use_tpu else
              tf_estimator.tpu.InputPipelineConfig.PER_HOST_V1)))
  params = {
      'hparams': hparams,
      'use_tpu': FLAGS.use_tpu,
      'data_dir': FLAGS.data_dir,
  }
  estimator = tf_estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params=params,
      train_batch_size=hparams.bs,
      eval_batch_size=hparams.eval.batch_size)

  if hparams.input_data.input_fn not in dir(inputs):
    # Bug fix: this message was missing its f-string prefix, so the
    # placeholder used to be printed literally instead of the input_fn name.
    raise ValueError(f'Unknown input_fn: {hparams.input_data.input_fn}')
  input_fn = getattr(inputs, hparams.input_data.input_fn)

  training_set_size = inputs.get_num_train_images(hparams)
  steps_per_epoch = training_set_size / hparams.bs
  stage_1_epochs = hparams.stage_1.training.train_epochs
  stage_2_epochs = hparams.stage_2.training.train_epochs
  total_steps = int((stage_1_epochs + stage_2_epochs) * steps_per_epoch)

  num_eval_examples = inputs.get_num_eval_images(hparams)
  eval_steps = num_eval_examples // hparams.eval.batch_size

  if FLAGS.mode == 'eval':
    # Continuous eval: evaluate and export each new checkpoint until training
    # has produced total_steps steps (or the iterator times out).
    for ckpt_str in tf.train.checkpoints_iterator(
        FLAGS.model_dir,
        min_interval_secs=FLAGS.eval_interval_secs,
        timeout=60 * 60):
      result = estimator.evaluate(
          input_fn=input_fn, checkpoint_path=ckpt_str, steps=eval_steps)
      estimator.export_saved_model(
          os.path.join(FLAGS.model_dir, 'exports'),
          lambda: input_fn(tf_estimator.ModeKeys.PREDICT, params),
          checkpoint_path=ckpt_str)
      if result['global_step'] >= total_steps:
        return
  else:  # 'train' or 'train_then_eval'.
    estimator.train(input_fn=input_fn, max_steps=total_steps)
    if FLAGS.mode == 'train_then_eval':
      result = estimator.evaluate(input_fn=input_fn, steps=eval_steps)
      estimator.export_saved_model(
          os.path.join(FLAGS.model_dir, 'exports'),
          lambda: input_fn(tf_estimator.ModeKeys.PREDICT, params))
781
# Script entry point: absl handles flag parsing before invoking main.
if __name__ == '__main__':
  app.run(main)