google-research / train_l2tl.py · 431 lines · 15.1 KB

# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trains an L2TL model jointly on the source and target datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from absl import app
from absl import flags
import model
import model_utils
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tensorflow import estimator as tf_estimator
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'model_dir',
    None,
    help=('The directory where the model and training/evaluation summaries are'
          ' stored.'))
flags.DEFINE_integer(
    'log_step_count_steps', 200, 'The number of steps at '
    'which the global step information is logged.')
flags.DEFINE_string(
    'warm_start_ckpt_path', None, 'The path to the checkpoint '
    'that will be used before training.')
flags.DEFINE_integer('train_steps', 120000, 'Number of total training steps.')
flags.DEFINE_integer('num_choices', 100,
                     'Number of actions for the scaling variable.')
flags.DEFINE_float('base_learning_rate_scale', 0.001,
                   'The value of the learning rate.')
flags.DEFINE_float('dst_weight_decay', 0.0005,
                   'Weight decay for the target dataset.')
flags.DEFINE_integer('save_checkpoints_steps', 100,
                     'Number of steps for each checkpoint saving.')
flags.DEFINE_float('rl_learning_rate', 0.001, 'Learning rate for RL updates.')
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for l2tl.')
flags.DEFINE_integer('target_num_classes', 10,
                     'The number of classes in the target dataset.')
flags.DEFINE_integer('train_batch_size', 128, 'The batch size during training.')
flags.DEFINE_integer(
    'source_train_batch_multiplier', 5,
    'The multiplier will be used to increase the batch size '
    'to sample more examples.')
flags.DEFINE_float('loss_weight_scale', 1000.0, 'Scaling of the loss weight.')
flags.DEFINE_integer('first_pretrain_steps', 0,
                     'Number of steps for pretraining.')
flags.DEFINE_integer('target_val_batch_multiplier', 4,
                     'Multiplier for the target evaluation batch size.')
flags.DEFINE_integer('target_train_batch_multiplier', 1,
                     'Multiplier for the target training batch size.')
flags.DEFINE_integer('uniform_weight', 0,
                     'Use of uniform weight in the ablation studies.')
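
# The functions below also reference FLAGS.source_dataset, FLAGS.target_dataset,
# FLAGS.src_num_classes, FLAGS.src_hw, FLAGS.target_hw and FLAGS.random_seed,
# which are not defined anywhere in this listing. The definitions that follow
# are an added sketch so the script parses on its own; the default values are
# assumptions (an MNIST-like, single-channel source and an SVHN-like target),
# not the authors' published settings. Drop them if these flags are already
# defined elsewhere (e.g. in model.py).
flags.DEFINE_string('source_dataset', 'mnist',
                    'Name of the source dataset (a TFDS dataset name).')
flags.DEFINE_string('target_dataset', 'svhn_cropped',
                    'Name of the target dataset (a TFDS dataset name).')
flags.DEFINE_integer('src_num_classes', 10,
                     'The number of classes in the source dataset.')
flags.DEFINE_integer('src_hw', 28, 'Height/width of the source images.')
flags.DEFINE_integer('target_hw', 32, 'Height/width of the target images.')
flags.DEFINE_integer('random_seed', 1, 'Random seed used for training.')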


def get_global_step(name):
  """Returns the global step variable."""
  global_step = tf.get_variable(
      name,
      shape=[],
      dtype=tf.int64,
      initializer=tf.initializers.zeros(),
      trainable=False,
      collections=[tf.GraphKeys.GLOBAL_VARIABLES])
  return global_step


def get_src_train_op(loss):  # pylint: disable=unused-argument
  """Returns the source training op."""
  global_step = tf.train.get_global_step()
  src_learning_rate = FLAGS.learning_rate
  src_learning_rate = tf.train.piecewise_constant(global_step, [
      800,
  ], [FLAGS.learning_rate, FLAGS.learning_rate * 0.1])
  optimizer = tf.train.MomentumOptimizer(
      learning_rate=src_learning_rate, momentum=0.9, use_nesterov=True)
  with tf.variable_scope('src'):
    return optimizer.minimize(loss, global_step), src_learning_rate


def meta_train_op(acc, rl_entropy, log_prob, rl_scope, params):  # pylint: disable=unused-argument
  """Returns the target training op.

  Update the control variables using policy gradient.

  Args:
    acc: reward on validation set. In our case, the reward is the top-1 acc;
    rl_entropy: entropy of action logits;
    log_prob: log prob of the action;
    rl_scope: variable scope;
    params: other params;

  Returns:
    target_train_op: train op;
    rl_learning_rate: lr;
    out_metric: metric dict;
  """
  target_global_step = get_global_step('train_rl_global_step')
  rl_reward = acc
  rl_step_baseline = rl_reward
  rl_baseline_momentum = 0.9
  rl_entropy_regularization = 0.001

  def update_rl_baseline():
    return model_utils.update_exponential_moving_average(
        rl_step_baseline, momentum=rl_baseline_momentum)

  rl_baseline = update_rl_baseline()

  rl_advantage = rl_reward - rl_baseline
  rl_empirical_loss = -tf.stop_gradient(rl_advantage) * log_prob

  rl_entropy_loss = -rl_entropy_regularization * rl_entropy

  enable_rl_optimizer = tf.cast(
      tf.greater_equal(target_global_step, FLAGS.first_pretrain_steps),
      tf.float32)
  rl_learning_rate = FLAGS.rl_learning_rate * enable_rl_optimizer
  rl_learning_rate = tf.train.piecewise_constant(target_global_step, [
      800,
  ], [rl_learning_rate, rl_learning_rate * 0.1])

  optimizer = tf.train.AdamOptimizer(rl_learning_rate)
  target_train_op = optimizer.minimize(
      rl_empirical_loss,
      target_global_step,
      var_list=tf.trainable_variables(rl_scope.name))

  out_metric = {
      'rl_empirical_loss': rl_empirical_loss,
      'rl_entropy_loss': rl_entropy_loss,
      'rl_reward': rl_reward,
      'rl_step_baseline': rl_step_baseline,
      'rl_baseline': rl_baseline,
      'rl_advantage': rl_advantage,
      'log_prob': log_prob,
  }
  return target_train_op, rl_learning_rate, out_metric
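
# Note: meta_train_op() above is a REINFORCE-style policy-gradient step. With
# reward R = top-1 accuracy on the target split, moving-average baseline b, and
# log-probability log_prob of the sampled controller actions, it minimizes
#   rl_empirical_loss = -(R - b) * log_prob,
# pushing the controller toward actions whose reward beats the baseline. The
# RL learning rate is zeroed until first_pretrain_steps, and the entropy term
# is reported in out_metric but not added to the minimized objective.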


def get_logits(feature, mode, dataset_name, reuse=None):
  """Returns the network logits."""
  avg_pool = model.conv_model(
      feature,
      mode,
      target_dataset=FLAGS.target_dataset,
      src_hw=FLAGS.src_hw,
      target_hw=FLAGS.target_hw,
      dataset_name=dataset_name,
      reuse=reuse)
  return avg_pool


def do_cls(avg_pool, num_classes, name='dense'):
  """Applies classification."""
  with tf.variable_scope('target_CLS', reuse=tf.AUTO_REUSE):
    logits = tf.layers.dense(
        inputs=avg_pool,
        units=num_classes,
        kernel_initializer=tf.random_normal_initializer(stddev=.05),
        name=name)
    return logits


def get_model_logits(src_features, finetune_features, mode, num_classes,
                     target_num_classes):
  """Gets the logits from different models."""
  src_avg_pool = get_logits(
      src_features, mode, FLAGS.source_dataset, reuse=None)
  dst_avg_pool = get_logits(
      finetune_features, mode, FLAGS.target_dataset, reuse=True)

  src_logits = do_cls(src_avg_pool, num_classes, name='final_dense_dst')
  dst_logits = do_cls(
      dst_avg_pool, target_num_classes, name='final_target_dense')
  return src_logits, dst_logits


def get_final_loss(src_logits, src_one_hot_labels, dst_logits,
                   finetune_one_hot_labels, global_step, loss_weights,
                   inst_weights):
  """Gets the final loss for l2tl."""
  if FLAGS.uniform_weight:
    inst_weights = 1.0

  def get_loss(logits, inst_weights, one_hot_labels):
    """Returns the loss function."""
    loss = tf.losses.softmax_cross_entropy(
        logits=logits, weights=inst_weights, onehot_labels=one_hot_labels)
    return loss

  src_loss = get_loss(src_logits, inst_weights, src_one_hot_labels)
  dst_loss = get_loss(dst_logits, 1., finetune_one_hot_labels)
  l2_loss = []
  for v in tf.trainable_variables():
    if 'batch_normalization' not in v.name and 'rl_controller' not in v.name:
      l2_loss.append(tf.nn.l2_loss(v))
  l2_loss = FLAGS.dst_weight_decay * tf.add_n(l2_loss)

  enable_pretrain = tf.cast(
      tf.greater_equal(global_step, FLAGS.first_pretrain_steps), tf.float32)

  loss = src_loss * tf.stop_gradient(loss_weights) * enable_pretrain
  loss += dst_loss + l2_loss

  return tf.identity(loss), src_loss, dst_loss
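
# Note: the objective assembled above is
#   loss = stop_gradient(loss_weights) * src_loss
#          * 1[global_step >= first_pretrain_steps]
#          + dst_loss + dst_weight_decay * sum_v ||v||^2 / 2,
# where src_loss is the source cross-entropy weighted per example by
# inst_weights, dst_loss is the target cross-entropy, and the L2 sum skips
# batch-norm and rl_controller variables. loss_weights is sampled by the
# controller and treated as a constant here via stop_gradient.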


def train_model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """Defines the model function."""
  target_num_classes = FLAGS.target_num_classes
  global_step = tf.train.get_global_step()

  src_features, src_labels = features['src'], tf.cast(labels['src'], tf.int64)
  finetune_features = features['finetune']
  target_features = features['target']

  num_classes = FLAGS.src_num_classes

  finetune_one_hot_labels = tf.one_hot(
      tf.cast(labels['finetune'], tf.int64), target_num_classes)
  target_one_hot_labels = tf.one_hot(
      tf.cast(labels['target'], tf.int64), target_num_classes)

  with tf.variable_scope('rl_controller') as rl_scope:
    # Creates `rl_scope`, which is reused by the controller ops below.
    pass
  rl_entropy, label_weights, log_prob = rl_label_weights(rl_scope)
  loss_entropy, loss_weights, loss_log_prob = get_loss_weights(rl_scope)

  def gather_init_weights():
    inst_weights = tf.stop_gradient(tf.gather(label_weights, src_labels))
    return inst_weights

  inst_weights = gather_init_weights()
  bs = FLAGS.train_batch_size
  hw = FLAGS.src_hw
  inst_weights, indices = tf.nn.top_k(
      inst_weights,
      k=bs,
      sorted=True,
  )

  src_features = tf.reshape(src_features, [
      bs * FLAGS.source_train_batch_multiplier,
      hw,
      hw,
      1,
  ])
  src_features = tf.gather(src_features, indices, axis=0)
  src_features = tf.stop_gradient(src_features)

  src_labels = tf.gather(src_labels, indices)

  inst_weights = bs * inst_weights / tf.reduce_sum(inst_weights)

  src_one_hot_labels = tf.one_hot(tf.cast(src_labels, tf.int64), num_classes)

  src_logits, dst_logits = get_model_logits(src_features, finetune_features,
                                            mode, num_classes,
                                            target_num_classes)

  loss, _, _ = get_final_loss(src_logits, src_one_hot_labels, dst_logits,
                              finetune_one_hot_labels, global_step,
                              loss_weights, inst_weights)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

  with tf.control_dependencies(update_ops):
    src_train_op, _ = get_src_train_op(loss)
    with tf.control_dependencies([src_train_op]):
      target_avg_pool = get_logits(
          target_features, mode, FLAGS.target_dataset, reuse=True)
      target_logits = do_cls(
          target_avg_pool, target_num_classes, name='final_target_dense')
      is_prediction_correct = tf.equal(
          tf.argmax(tf.identity(target_logits), axis=1),
          tf.argmax(target_one_hot_labels, axis=1))
      acc = tf.reduce_mean(tf.cast(is_prediction_correct, tf.float32))

      entropy = loss_entropy + rl_entropy
      log_prob = loss_log_prob + log_prob
      train_op, _, _ = meta_train_op(acc, entropy, log_prob, rl_scope, params)

  return tf_estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
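
# Note: the control_dependencies nesting above fixes the per-step ordering of
# L2TL: the classifier weights are first updated with the weighted source +
# target loss (src_train_op), the accuracy on the target split is then computed
# under a dependency on that update, and meta_train_op() uses that accuracy as
# the reward for the policy-gradient update of the rl_controller variables.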


def rl_label_weights(name=None):
  """Returns the weight for importance."""
  with tf.variable_scope(name, 'rl_op_selection'):
    num_classes = FLAGS.src_num_classes
    num_choices = FLAGS.num_choices

    logits = tf.get_variable(
        name='logits_rl_w',
        initializer=tf.initializers.zeros(),
        shape=[num_classes, num_choices],
        dtype=tf.float32)
    dist = tfp.distributions.Categorical(logits=logits)
    dist_entropy = tf.reduce_sum(dist.entropy())

    sample = dist.sample()
    sample_masks = 1. * tf.cast(sample, tf.float32) / num_choices
    sample_log_prob = tf.reduce_mean(dist.log_prob(sample))

  return (dist_entropy, sample_masks, sample_log_prob)


def get_loss_weights(name=None):
  """Returns the weight for loss."""
  with tf.variable_scope(name, 'rl_op_selection'):

    logits = tf.get_variable(
        name='loss_logits_rl_w',
        initializer=tf.initializers.zeros(),
        shape=[
            FLAGS.num_choices,
        ],
        dtype=tf.float32)
    dist = tfp.distributions.Categorical(logits=logits)
    dist_entropy = tf.reduce_sum(dist.entropy())

    sample = dist.sample()
    sample_masks = 1. * tf.cast(sample, tf.float32) / FLAGS.loss_weight_scale
    sample_log_prob = tf.reduce_mean(dist.log_prob(sample))

  return (dist_entropy, sample_masks, sample_log_prob)
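
# Note: both controllers sample a discrete action from a Categorical over
# `num_choices` logits and map it to a scalar: rl_label_weights() maps each
# per-class action a to an importance weight a / num_choices in [0, 1), while
# get_loss_weights() maps its single action to a source-loss weight
# a / loss_weight_scale. The returned log-probabilities feed the
# policy-gradient update in meta_train_op().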


def main(unused_argv):
  tf.set_random_seed(FLAGS.random_seed)

  run_config_args = {
      'model_dir': FLAGS.model_dir,
      'save_checkpoints_steps': FLAGS.save_checkpoints_steps,
      'log_step_count_steps': FLAGS.log_step_count_steps,
      'keep_checkpoint_max': 100,
  }
  config = tf.contrib.tpu.RunConfig(**run_config_args)

  if FLAGS.warm_start_ckpt_path:
    var_names = []
    checkpoint_path = FLAGS.warm_start_ckpt_path
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    for key in reader.get_variable_to_shape_map():
      keep_str = 'Momentum|global_step|finetune_global_step'
      if not re.findall('({})'.format(keep_str), key):
        var_names.append(key)

    tf.logging.info('Warm-starting tensors: %s', sorted(var_names))

    vars_to_warm_start = var_names
    warm_start_settings = tf_estimator.WarmStartSettings(
        ckpt_to_initialize_from=checkpoint_path,
        vars_to_warm_start=vars_to_warm_start)
  else:
    warm_start_settings = None

  l2tl_classifier = tf_estimator.Estimator(
      train_model_fn, config=config, warm_start_from=warm_start_settings)

  def make_input_dataset():
    """Return input dataset."""

    def _merge_datasets(train_batch, finetune_batch, target_batch):
      """Merge different splits."""
      train_features, train_labels = train_batch['image'], train_batch['label']
      finetune_features, finetune_labels = finetune_batch[
          'image'], finetune_batch['label']
      target_features, target_labels = target_batch['image'], target_batch[
          'label']
      features = {
          'src': train_features,
          'finetune': finetune_features,
          'target': target_features
      }
      labels = {
          'src': train_labels,
          'finetune': finetune_labels,
          'target': target_labels
      }
      return (features, labels)

    source_train_batch_size = int(
        round(FLAGS.train_batch_size * FLAGS.source_train_batch_multiplier))

    train_data = tfds.load(name=FLAGS.source_dataset, split='train')
    train_data = train_data.shuffle(512).repeat().batch(source_train_batch_size)

    target_train_batch_size = int(
        round(FLAGS.train_batch_size * FLAGS.target_train_batch_multiplier))
    finetune_data = tfds.load(name=FLAGS.target_dataset, split='train')
    finetune_data = finetune_data.shuffle(512).repeat().batch(
        target_train_batch_size)

    target_val_batch_size = int(
        round(FLAGS.train_batch_size * FLAGS.target_val_batch_multiplier))

    target_data = tfds.load(name=FLAGS.target_dataset, split='validation')
    target_data = target_data.shuffle(512).repeat().batch(target_val_batch_size)

    dataset = tf.data.Dataset.zip((train_data, finetune_data, target_data))
    dataset = dataset.map(_merge_datasets)
    # tf.data.AUTOTUNE exists only in TF2; this TF1-style script (tf.contrib,
    # tf.get_variable) needs the tf.data.experimental.AUTOTUNE alias instead.
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

  max_train_steps = FLAGS.train_steps
  l2tl_classifier.train(make_input_dataset, max_steps=max_train_steps)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  app.run(main)
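
# Example invocation (illustrative only; the dataset names match the sketch
# flag defaults added above and the paths are placeholders, not settings
# prescribed by the authors):
#   python train_l2tl.py \
#     --model_dir=/tmp/l2tl_model \
#     --source_dataset=mnist \
#     --target_dataset=svhn_cropped \
#     --warm_start_ckpt_path=/tmp/pretrained/model.ckpt \
#     --train_steps=120000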