google-research / finetuning.py · 225 lines · 7.0 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Finetunes the pre-trained model on the target set."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from absl import app
from absl import flags
import model
import model_utils
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tensorflow import estimator as tf_estimator
from tensorflow.python.estimator import estimator
import tensorflow_datasets as tfds

flags.DEFINE_string(
    'model_dir',
    None,
    help=('The directory where the model and training/evaluation summaries are'
          ' stored.'))
flags.DEFINE_string(
    'warm_start_ckpt_path', None, 'The path to the checkpoint '
    'that will be used before training.')
flags.DEFINE_integer(
    'log_step_count_steps', 200, 'The number of steps at '
    'which the global step information is logged.')
flags.DEFINE_integer('train_steps', 100, 'Number of steps for training.')
flags.DEFINE_float('target_base_learning_rate', 0.001,
                   'Target base learning rate.')
flags.DEFINE_integer('train_batch_size', 256,
                     'The batch size for the target dataset.')
flags.DEFINE_float('weight_decay', 0.0005, 'The value for weight decay.')

FLAGS = flags.FLAGS
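
# NOTE: the code below also reads FLAGS.target_dataset, FLAGS.src_num_classes,
# FLAGS.src_hw, FLAGS.target_hw, and FLAGS.random_seed, which are not defined
# in this file; they are presumably registered by the imported `model` /
# `model_utils` modules.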


def lr_schedule():
  """Learning rate scheduling."""
  target_lr = FLAGS.target_base_learning_rate
  current_step = tf.train.get_global_step()

  if FLAGS.target_dataset == 'mnist':
    return tf.train.piecewise_constant(current_step, [
        500,
        1500,
    ], [target_lr, target_lr * 0.1, target_lr * 0.01])
  else:
    return tf.train.piecewise_constant(current_step, [
        800,
    ], [target_lr, target_lr * 0.1])
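
# For example, with the default target_base_learning_rate=0.001 on mnist the
# schedule above yields roughly lr=0.001 for the first 500 steps, 0.0001 until
# step 1500, and 0.00001 afterwards; on other target datasets the rate drops
# once, from 0.001 to 0.0001, around step 800.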


def get_model_fn():
  """Returns the model definition."""

  def model_fn(features, labels, mode, params):
    """Returns the model function."""
    feature = features['feature']
    labels = labels['label']
    one_hot_labels = model_utils.get_label(
        labels,
        params,
        FLAGS.src_num_classes,
        batch_size=FLAGS.train_batch_size)

    def get_logits():
      """Return the logits."""
      avg_pool = model.conv_model(
          feature,
          mode,
          target_dataset=FLAGS.target_dataset,
          src_hw=FLAGS.src_hw,
          target_hw=FLAGS.target_hw)
      name = 'final_dense_dst'
      with tf.variable_scope('target_CLS'):
        logits = tf.layers.dense(
            inputs=avg_pool,
            units=FLAGS.src_num_classes,
            name=name,
            kernel_initializer=tf.random_normal_initializer(stddev=.05),
        )
      return logits

    logits = get_logits()
    logits = tf.cast(logits, tf.float32)

    dst_loss = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
    )
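    # Weight decay applies only to kernel (weight) variables; batch-norm
    # parameters and biases are excluded from the L2 penalty.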
    dst_l2_loss = FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name and 'kernel' in v.name
    ])

    loss = dst_loss + dst_l2_loss

    train_op = None
    if mode == tf_estimator.ModeKeys.TRAIN:
      cur_finetune_step = tf.train.get_global_step()
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        finetune_learning_rate = lr_schedule()
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=finetune_learning_rate,
            momentum=0.9,
            use_nesterov=True)
        with tf.variable_scope('finetune'):
          train_op = optimizer.minimize(loss, cur_finetune_step)

    eval_metrics = None
    if mode == tf_estimator.ModeKeys.EVAL:
      eval_metrics = model_utils.metric_fn(labels, logits)

    if mode == tf_estimator.ModeKeys.TRAIN:
      with tf.control_dependencies([train_op]):
        tf.summary.scalar('classifier/finetune_lr', finetune_learning_rate)

    return tf_estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metrics,
    )

  return model_fn


def main(unused_argv):
  tf.set_random_seed(FLAGS.random_seed)

  save_checkpoints_steps = 100
  run_config_args = {
      'model_dir': FLAGS.model_dir,
      'save_checkpoints_steps': save_checkpoints_steps,
      'log_step_count_steps': FLAGS.log_step_count_steps,
      'keep_checkpoint_max': 200,
  }

  config = tf_estimator.RunConfig(**run_config_args)
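
  # Warm-start everything from the pre-trained checkpoint except optimizer
  # slot variables (Momentum/Adam), step counters, and the new target
  # classification head ('final_dense_dst'), which is trained from scratch.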
  if FLAGS.warm_start_ckpt_path:
    var_names = []
    checkpoint_path = FLAGS.warm_start_ckpt_path
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    for key in reader.get_variable_to_shape_map():
      keep_str = 'Momentum|global_step|finetune_global_step|Adam|final_dense_dst'
      if not re.findall('({})'.format(keep_str), key):
        var_names.append(key)

    tf.logging.info('Warm-starting tensors: %s', sorted(var_names))

    vars_to_warm_start = var_names
    warm_start_settings = tf_estimator.WarmStartSettings(
        ckpt_to_initialize_from=checkpoint_path,
        vars_to_warm_start=vars_to_warm_start)
  else:
    warm_start_settings = None

  classifier = tf_estimator.Estimator(
      get_model_fn(), config=config, warm_start_from=warm_start_settings)
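
  # The input pipeline below repackages each TFDS batch into the
  # {'feature': ...} and {'label': ...} dicts that model_fn expects.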
  def _merge_datasets(train_batch):
    feature, label = train_batch['image'], train_batch['label']
    features = {
        'feature': feature,
    }
    labels = {
        'label': label,
    }
    return (features, labels)

  def get_dataset(dataset_split):
    """Returns dataset creation function."""

    def make_input_dataset():
      """Returns input dataset."""
      train_data = tfds.load(name=FLAGS.target_dataset, split=dataset_split)
      train_data = train_data.shuffle(1024).repeat().batch(
          FLAGS.train_batch_size)
      dataset = tf.data.Dataset.zip((train_data,))
      dataset = dataset.map(_merge_datasets)
      dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
      return dataset

    return make_input_dataset

  # pylint: disable=protected-access
  current_step = estimator._load_global_step_from_checkpoint_dir(
      FLAGS.model_dir)
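
  # Train in 500-step segments, checkpointing between segments. Note that with
  # the default train_steps=100 a single 500-step segment still runs, since
  # each pass extends max_steps by 500.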
  train_steps = FLAGS.train_steps
  while current_step < train_steps:
    print('Run {}'.format(current_step))
    next_checkpoint = current_step + 500
    classifier.train(input_fn=get_dataset('train'), max_steps=next_checkpoint)
    current_step = next_checkpoint


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  app.run(main)
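
# Example invocation (hypothetical paths; --target_dataset and the other
# model/dataset flags are assumed to be registered by model.py or
# model_utils.py):
#
#   python finetuning.py \
#     --model_dir=/tmp/finetune_model \
#     --warm_start_ckpt_path=/tmp/pretrain/model.ckpt-10000 \
#     --train_batch_size=256 \
#     --train_steps=2000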