google-research
225 lines · 7.0 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Finetunes the pre-trained model on the target set."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import re23
24from absl import app25from absl import flags26import model27import model_utils28import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import29from tensorflow import estimator as tf_estimator30from tensorflow.python.estimator import estimator31import tensorflow_datasets as tfds32
# Command-line configuration for the finetuning run.
flags.DEFINE_string(
    'model_dir',
    None,
    help=('The directory where the model and training/evaluation summaries are'
          ' stored.'))
flags.DEFINE_string(
    'warm_start_ckpt_path', None, 'The path to the checkpoint '
    'that will be used before training.')
flags.DEFINE_integer(
    'log_step_count_steps', 200, 'The number of steps at '
    'which the global step information is logged.')
flags.DEFINE_integer('train_steps', 100, 'Number of steps for training.')
flags.DEFINE_float('target_base_learning_rate', 0.001,
                   'Target base learning rate.')
flags.DEFINE_integer('train_batch_size', 256,
                     'The batch size for the target dataset.')
flags.DEFINE_float('weight_decay', 0.0005, 'The value for weight decay.')

FLAGS = flags.FLAGS
53
54def lr_schedule():55"""Learning rate scheduling."""56target_lr = FLAGS.target_base_learning_rate57current_step = tf.train.get_global_step()58
59if FLAGS.target_dataset == 'mnist':60return tf.train.piecewise_constant(current_step, [61500,621500,63], [target_lr, target_lr * 0.1, target_lr * 0.01])64else:65return tf.train.piecewise_constant(current_step, [66800,67], [target_lr, target_lr * 0.1])68
69
70def get_model_fn():71"""Returns the model definition."""72
73def model_fn(features, labels, mode, params):74"""Returns the model function."""75feature = features['feature']76labels = labels['label']77one_hot_labels = model_utils.get_label(78labels,79params,80FLAGS.src_num_classes,81batch_size=FLAGS.train_batch_size)82
83def get_logits():84"""Return the logits."""85avg_pool = model.conv_model(86feature,87mode,88target_dataset=FLAGS.target_dataset,89src_hw=FLAGS.src_hw,90target_hw=FLAGS.target_hw)91name = 'final_dense_dst'92with tf.variable_scope('target_CLS'):93logits = tf.layers.dense(94inputs=avg_pool,95units=FLAGS.src_num_classes,96name=name,97kernel_initializer=tf.random_normal_initializer(stddev=.05),98)99return logits100
101logits = get_logits()102logits = tf.cast(logits, tf.float32)103
104dst_loss = tf.losses.softmax_cross_entropy(105logits=logits,106onehot_labels=one_hot_labels,107)108dst_l2_loss = FLAGS.weight_decay * tf.add_n([109tf.nn.l2_loss(v)110for v in tf.trainable_variables()111if 'batch_normalization' not in v.name and 'kernel' in v.name112])113
114loss = dst_loss + dst_l2_loss115
116train_op = None117if mode == tf_estimator.ModeKeys.TRAIN:118cur_finetune_step = tf.train.get_global_step()119update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)120with tf.control_dependencies(update_ops):121finetune_learning_rate = lr_schedule()122optimizer = tf.train.MomentumOptimizer(123learning_rate=finetune_learning_rate,124momentum=0.9,125use_nesterov=True)126train_op = tf.contrib.slim.learning.create_train_op(loss, optimizer)127with tf.variable_scope('finetune'):128train_op = optimizer.minimize(loss, cur_finetune_step)129else:130train_op = None131
132eval_metrics = None133if mode == tf_estimator.ModeKeys.EVAL:134eval_metrics = model_utils.metric_fn(labels, logits)135
136if mode == tf_estimator.ModeKeys.TRAIN:137with tf.control_dependencies([train_op]):138tf.summary.scalar('classifier/finetune_lr', finetune_learning_rate)139else:140train_op = None141
142return tf_estimator.EstimatorSpec(143mode=mode,144loss=loss,145train_op=train_op,146eval_metric_ops=eval_metrics,147)148
149return model_fn150
151
152def main(unused_argv):153tf.set_random_seed(FLAGS.random_seed)154
155save_checkpoints_steps = 100156run_config_args = {157'model_dir': FLAGS.model_dir,158'save_checkpoints_steps': save_checkpoints_steps,159'log_step_count_steps': FLAGS.log_step_count_steps,160'keep_checkpoint_max': 200,161}162
163config = tf_estimator.RunConfig(**run_config_args)164
165if FLAGS.warm_start_ckpt_path:166var_names = []167checkpoint_path = FLAGS.warm_start_ckpt_path168reader = tf.train.NewCheckpointReader(checkpoint_path)169for key in reader.get_variable_to_shape_map():170keep_str = 'Momentum|global_step|finetune_global_step|Adam|final_dense_dst'171if not re.findall('({})'.format(keep_str,), key):172var_names.append(key)173
174tf.logging.info('Warm-starting tensors: %s', sorted(var_names))175
176vars_to_warm_start = var_names177warm_start_settings = tf_estimator.WarmStartSettings(178ckpt_to_initialize_from=checkpoint_path,179vars_to_warm_start=vars_to_warm_start)180else:181warm_start_settings = None182
183classifier = tf_estimator.Estimator(184get_model_fn(), config=config, warm_start_from=warm_start_settings)185
186def _merge_datasets(train_batch):187feature, label = train_batch['image'], train_batch['label'],188features = {189'feature': feature,190}191labels = {192'label': label,193}194return (features, labels)195
196def get_dataset(dataset_split):197"""Returns dataset creation function."""198
199def make_input_dataset():200"""Returns input dataset."""201train_data = tfds.load(name=FLAGS.target_dataset, split=dataset_split)202train_data = train_data.shuffle(1024).repeat().batch(203FLAGS.train_batch_size)204dataset = tf.data.Dataset.zip((train_data,))205dataset = dataset.map(_merge_datasets)206dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)207return dataset208
209return make_input_dataset210
211# pylint: disable=protected-access212current_step = estimator._load_global_step_from_checkpoint_dir(213FLAGS.model_dir)214
215train_steps = FLAGS.train_steps216while current_step < train_steps:217print('Run {}'.format(current_step))218next_checkpoint = current_step + 500219classifier.train(input_fn=get_dataset('train'), max_steps=next_checkpoint)220current_step = next_checkpoint221
222
223if __name__ == '__main__':224tf.logging.set_verbosity(tf.logging.INFO)225app.run(main)226