# google-research
# 431 lines · 15.1 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Trains an L2TL model jointly on the source and target datasets."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import re
23
24from absl import app
25from absl import flags
26import model
27import model_utils
28import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
29from tensorflow import estimator as tf_estimator
30import tensorflow_datasets as tfds
31import tensorflow_probability as tfp
32
33FLAGS = flags.FLAGS
34
35flags.DEFINE_string(
36'model_dir',
37None,
38help=('The directory where the model and training/evaluation summaries are'
39' stored.'))
40flags.DEFINE_integer(
41'log_step_count_steps', 200, 'The number of steps at '
42'which the global step information is logged.')
43flags.DEFINE_string(
44'warm_start_ckpt_path', None, 'The path to the checkpoint '
45'that will be used before training.')
46flags.DEFINE_integer('train_steps', 120000, 'Number of total training steps.')
47flags.DEFINE_integer('num_choices', 100,
48'Number of actions for the scaling variable.')
49flags.DEFINE_float('base_learning_rate_scale', 0.001,
50'The value of the learning rate')
51flags.DEFINE_float('dst_weight_decay', 0.0005,
52'Weight decay for the target dataset.')
53flags.DEFINE_integer('save_checkpoints_steps', 100,
54'Number of steps for each checkpoint saving.')
55flags.DEFINE_float('rl_learning_rate', 0.001, 'Learning rate for RL updates.')
56flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for l2tl.')
57flags.DEFINE_integer('target_num_classes', 10,
58'The number of classes in the target dataset.')
59flags.DEFINE_integer('train_batch_size', 128, 'The batch size during training.')
60flags.DEFINE_integer(
61'source_train_batch_multiplier', 5,
62'The multiplier will be used to increase the batch size '
63'to sample more examples.')
64flags.DEFINE_float('loss_weight_scale', 1000.0, 'Scaling of the loss weight.')
65flags.DEFINE_integer('first_pretrain_steps', 0,
66'Number of steps for pretraining.')
67flags.DEFINE_integer('target_val_batch_multiplier', 4,
68'Multiplier for the target evaluation batch size.')
69flags.DEFINE_integer('target_train_batch_multiplier', 1,
70'Multiplier for the target evaluation train batch size.')
71flags.DEFINE_integer('uniform_weight', 0,
72'Use of uniform weight in the ablation studies.')
73
74
def get_global_step(name):
  """Creates and returns a named, non-trainable int64 step counter.

  The variable lives in the GLOBAL_VARIABLES collection so it is saved in
  checkpoints, but it is excluded from gradient updates.

  Args:
    name: variable name for the step counter.

  Returns:
    A scalar int64 `tf.Variable` initialized to zero.
  """
  return tf.get_variable(
      name,
      shape=[],
      dtype=tf.int64,
      initializer=tf.initializers.zeros(),
      trainable=False,
      collections=[tf.GraphKeys.GLOBAL_VARIABLES])
85
86
def get_src_train_op(loss):  # pylint: disable=unused-argument
  """Returns the source-dataset training op and its learning rate.

  Uses SGD with Nesterov momentum and a piecewise-constant schedule that
  decays the learning rate by 10x after 800 global steps.

  Args:
    loss: scalar loss tensor to minimize.

  Returns:
    A `(train_op, learning_rate)` tuple.
  """
  global_step = tf.train.get_global_step()
  # Fix: the original assigned `src_learning_rate = FLAGS.learning_rate` and
  # then immediately overwrote it with the schedule below; the dead
  # assignment is removed.
  src_learning_rate = tf.train.piecewise_constant(
      global_step, [800], [FLAGS.learning_rate, FLAGS.learning_rate * 0.1])
  optimizer = tf.train.MomentumOptimizer(
      learning_rate=src_learning_rate, momentum=0.9, use_nesterov=True)
  with tf.variable_scope('src'):
    return optimizer.minimize(loss, global_step), src_learning_rate
98
99
def meta_train_op(acc, rl_entropy, log_prob, rl_scope, params):  # pylint: disable=unused-argument
  """Returns the target training op.

  Update the control variables using policy gradient.
  Args:
    acc: reward on validation set. In our case, the reward is the top-1 acc;
    rl_entropy: entropy of action logits;
    log_prob: log prob of the action;
    rl_scope: variable scope;
    params: other params;

  Returns:
    target_train_op: train op;
    rl_learning_rate: lr;
    out_metric: metric dict;
  """
  # Separate step counter for the RL controller, independent of the main
  # training global step.
  target_global_step = get_global_step('train_rl_global_step')
  rl_reward = acc
  rl_step_baseline = rl_reward
  rl_baseline_momentum = 0.9
  rl_entropy_regularization = 0.001

  def update_rl_baseline():
    # Exponential moving average of the reward, used as the REINFORCE
    # baseline to reduce gradient variance.
    return model_utils.update_exponential_moving_average(
        rl_step_baseline, momentum=rl_baseline_momentum)

  rl_baseline = update_rl_baseline()

  # REINFORCE: loss = -advantage * log_prob; the advantage is treated as a
  # constant (stop_gradient) so gradients flow only through log_prob.
  rl_advantage = rl_reward - rl_baseline
  rl_empirical_loss = -tf.stop_gradient(rl_advantage) * log_prob

  # NOTE(review): this entropy bonus is computed and reported in out_metric
  # but is never added to rl_empirical_loss, so it does not affect the
  # optimization — confirm whether that is intended.
  rl_entropy_loss = -rl_entropy_regularization * rl_entropy

  # Zero out the RL learning rate until first_pretrain_steps have elapsed,
  # i.e. the controller is frozen during pretraining.
  enable_rl_optimizer = tf.cast(
      tf.greater_equal(target_global_step, FLAGS.first_pretrain_steps),
      tf.float32)
  rl_learning_rate = FLAGS.rl_learning_rate * enable_rl_optimizer
  # 10x decay after 800 controller steps, mirroring the source-model schedule.
  rl_learning_rate = tf.train.piecewise_constant(target_global_step, [
      800,
  ], [rl_learning_rate, rl_learning_rate * 0.1])

  optimizer = tf.train.AdamOptimizer(rl_learning_rate)
  # Only the controller's variables (under rl_scope) are updated here.
  target_train_op = optimizer.minimize(
      rl_empirical_loss,
      target_global_step,
      var_list=tf.trainable_variables(rl_scope.name))

  out_metric = {
      'rl_empirical_loss': rl_empirical_loss,
      'rl_entropy_loss': rl_entropy_loss,
      'rl_reward': rl_reward,
      'rl_step_baseline': rl_step_baseline,
      'rl_baseline': rl_baseline,
      'rl_advantage': rl_advantage,
      'log_prob': log_prob,
  }
  return target_train_op, rl_learning_rate, out_metric
157
158
def get_logits(feature, mode, dataset_name, reuse=None):
  """Runs the shared conv backbone and returns its average-pooled output.

  Args:
    feature: input image tensor.
    mode: estimator mode key (train/eval/predict).
    dataset_name: which dataset the features come from; forwarded to the
      backbone so it can pick the matching input pathway.
    reuse: variable-reuse flag; pass True to share weights with an earlier
      call.

  Returns:
    The backbone's average-pool features.
  """
  return model.conv_model(
      feature,
      mode,
      target_dataset=FLAGS.target_dataset,
      src_hw=FLAGS.src_hw,
      target_hw=FLAGS.target_hw,
      dataset_name=dataset_name,
      reuse=reuse)
170
171
def do_cls(avg_pool, num_classes, name='dense'):
  """Attaches a dense classification head under the shared 'target_CLS' scope.

  AUTO_REUSE lets repeated calls with the same `name` share one set of
  weights (e.g. the training and evaluation passes of the target head).

  Args:
    avg_pool: pooled feature tensor.
    num_classes: output dimensionality of the head.
    name: layer name, used to distinguish the source and target heads.

  Returns:
    The classification logits.
  """
  with tf.variable_scope('target_CLS', reuse=tf.AUTO_REUSE):
    return tf.layers.dense(
        inputs=avg_pool,
        units=num_classes,
        kernel_initializer=tf.random_normal_initializer(stddev=.05),
        name=name)
181
182
def get_model_logits(src_features, finetune_features, mode, num_classes,
                     target_num_classes):
  """Computes source and target logits from the shared backbone.

  The source pass creates the backbone variables (reuse=None); the target
  pass reuses them (reuse=True). Each branch then gets its own dense head.

  Args:
    src_features: source-dataset images.
    finetune_features: target-dataset training images.
    mode: estimator mode key.
    num_classes: number of source classes.
    target_num_classes: number of target classes.

  Returns:
    A `(src_logits, dst_logits)` tuple.
  """
  src_pool = get_logits(src_features, mode, FLAGS.source_dataset, reuse=None)
  dst_pool = get_logits(
      finetune_features, mode, FLAGS.target_dataset, reuse=True)
  return (do_cls(src_pool, num_classes, name='final_dense_dst'),
          do_cls(dst_pool, target_num_classes, name='final_target_dense'))
195
196
def get_final_loss(src_logits, src_one_hot_labels, dst_logits,
                   finetune_one_hot_labels, global_step, loss_weights,
                   inst_weights):
  """Combines source, target, and L2 losses into the final L2TL loss.

  Args:
    src_logits: logits for the source batch.
    src_one_hot_labels: one-hot labels for the source batch.
    dst_logits: logits for the target batch.
    finetune_one_hot_labels: one-hot labels for the target batch.
    global_step: the main training step counter.
    loss_weights: RL-sampled scalar weight applied to the source loss.
    inst_weights: per-example importance weights for the source batch.

  Returns:
    A `(loss, src_loss, dst_loss)` tuple.
  """
  # Ablation mode: ignore the learned per-example weights.
  if FLAGS.uniform_weight:
    inst_weights = 1.0

  def weighted_xent(logits, weights, one_hot_labels):
    """Weighted softmax cross-entropy."""
    return tf.losses.softmax_cross_entropy(
        logits=logits, weights=weights, onehot_labels=one_hot_labels)

  src_loss = weighted_xent(src_logits, inst_weights, src_one_hot_labels)
  dst_loss = weighted_xent(dst_logits, 1., finetune_one_hot_labels)

  # L2 regularization over all trainable weights except batch-norm
  # parameters and the RL controller's own variables.
  l2_loss = FLAGS.dst_weight_decay * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name and 'rl_controller' not in v.name
  ])

  # The source term is disabled until first_pretrain_steps have elapsed.
  enable_pretrain = tf.cast(
      tf.greater_equal(global_step, FLAGS.first_pretrain_steps), tf.float32)

  # stop_gradient: the RL-sampled loss weight is a constant w.r.t. the
  # model parameters; it is trained separately via policy gradient.
  loss = src_loss * tf.stop_gradient(loss_weights) * enable_pretrain
  loss += dst_loss + l2_loss

  return tf.identity(loss), src_loss, dst_loss
225
226
def train_model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """Defines the model function.

  Builds the joint L2TL graph: an RL controller samples per-class example
  weights and a loss weight, the model trains on weighted source data plus
  target data, and the controller is updated by policy gradient using the
  target-validation accuracy as the reward.

  Args:
    features: dict with 'src', 'finetune', and 'target' image batches.
    labels: dict with matching 'src', 'finetune', and 'target' label batches.
    mode: estimator mode key.
    params: unused estimator params.

  Returns:
    A `tf_estimator.EstimatorSpec` with the combined loss and train op.
  """
  target_num_classes = FLAGS.target_num_classes
  global_step = tf.train.get_global_step()

  src_features, src_labels = features['src'], tf.cast(labels['src'], tf.int64)
  finetune_features = features['finetune']
  target_features = features['target']

  num_classes = FLAGS.src_num_classes

  finetune_one_hot_labels = tf.one_hot(
      tf.cast(labels['finetune'], tf.int64), target_num_classes)
  target_one_hot_labels = tf.one_hot(
      tf.cast(labels['target'], tf.int64), target_num_classes)

  with tf.variable_scope('rl_controller') as rl_scope:
    # It creates a `rl_scope` which will be used for ops.
    pass
  # Sample per-class importance weights and the scalar loss weight from the
  # controller's categorical distributions.
  rl_entropy, label_weights, log_prob = rl_label_weights(rl_scope)
  loss_entropy, loss_weights, loss_log_prob = get_loss_weights(rl_scope)

  def gather_init_weights():
    # Look up each source example's weight by its class label; treated as a
    # constant w.r.t. the model parameters.
    inst_weights = tf.stop_gradient(tf.gather(label_weights, src_labels))
    return inst_weights

  inst_weights = gather_init_weights()
  bs = FLAGS.train_batch_size
  hw = FLAGS.src_hw
  # Keep only the top-`bs` highest-weighted examples from the oversampled
  # source batch (bs * source_train_batch_multiplier examples).
  inst_weights, indices = tf.nn.top_k(
      inst_weights,
      k=bs,
      sorted=True,
  )

  # NOTE(review): the reshape assumes single-channel (grayscale) hw x hw
  # source images — confirm against the source dataset.
  src_features = tf.reshape(src_features, [
      bs * FLAGS.source_train_batch_multiplier,
      hw,
      hw,
      1,
  ])
  src_features = tf.gather(src_features, indices, axis=0)
  src_features = tf.stop_gradient(src_features)

  src_labels = tf.gather(src_labels, indices)

  # Normalize the selected weights so they sum to the batch size.
  inst_weights = bs * inst_weights / tf.reduce_sum(inst_weights)

  src_one_hot_labels = tf.one_hot(tf.cast(src_labels, tf.int64), num_classes)

  src_logits, dst_logits = get_model_logits(src_features, finetune_features,
                                            mode, num_classes,
                                            target_num_classes)

  loss, _, _ = get_final_loss(src_logits, src_one_hot_labels, dst_logits,
                              finetune_one_hot_labels, global_step,
                              loss_weights, inst_weights)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

  # Ordering matters: run batch-norm updates, then the model's train step,
  # then evaluate target accuracy on the post-update weights, which becomes
  # the RL reward.
  with tf.control_dependencies(update_ops):
    src_train_op, _ = get_src_train_op(loss)
    with tf.control_dependencies([src_train_op]):
      target_avg_pool = get_logits(
          target_features, mode, FLAGS.target_dataset, reuse=True)
      target_logits = do_cls(
          target_avg_pool, target_num_classes, name='final_target_dense')
      is_prediction_correct = tf.equal(
          tf.argmax(tf.identity(target_logits), axis=1),
          tf.argmax(target_one_hot_labels, axis=1))
      acc = tf.reduce_mean(tf.cast(is_prediction_correct, tf.float32))

      # Combine both controller heads' entropies and log-probs for a single
      # policy-gradient update.
      entropy = loss_entropy + rl_entropy
      log_prob = loss_log_prob + log_prob
      train_op, _, _ = meta_train_op(acc, entropy, log_prob, rl_scope, params)

  return tf_estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
304
305
def rl_label_weights(name=None):
  """Samples per-class importance weights from the RL controller.

  A trainable [num_classes, num_choices] logits table defines one
  categorical distribution per source class; each class's sampled action is
  mapped to a weight in [0, 1) by dividing by num_choices.

  Args:
    name: variable scope (or scope name); defaults to 'rl_op_selection'.

  Returns:
    A `(entropy, weights, log_prob)` tuple: total entropy over all class
    distributions, the per-class weight vector, and the mean log-probability
    of the sampled actions.
  """
  with tf.variable_scope(name, 'rl_op_selection'):
    action_logits = tf.get_variable(
        name='logits_rl_w',
        initializer=tf.initializers.zeros(),
        shape=[FLAGS.src_num_classes, FLAGS.num_choices],
        dtype=tf.float32)
    policy = tfp.distributions.Categorical(logits=action_logits)
    entropy = tf.reduce_sum(policy.entropy())

    actions = policy.sample()
    weights = 1. * tf.cast(actions, tf.float32) / FLAGS.num_choices
    mean_log_prob = tf.reduce_mean(policy.log_prob(actions))

    return (entropy, weights, mean_log_prob)
325
326
def get_loss_weights(name=None):
  """Samples the scalar source-loss weight from the RL controller.

  A trainable [num_choices] logits vector defines a single categorical
  distribution; the sampled action is scaled down by loss_weight_scale to
  produce the loss weight.

  Args:
    name: variable scope (or scope name); defaults to 'rl_op_selection'.

  Returns:
    A `(entropy, weight, log_prob)` tuple: the distribution's entropy, the
    sampled scalar weight, and the log-probability of the sampled action.
  """
  with tf.variable_scope(name, 'rl_op_selection'):
    action_logits = tf.get_variable(
        name='loss_logits_rl_w',
        initializer=tf.initializers.zeros(),
        shape=[
            FLAGS.num_choices,
        ],
        dtype=tf.float32)
    policy = tfp.distributions.Categorical(logits=action_logits)
    entropy = tf.reduce_sum(policy.entropy())

    action = policy.sample()
    weight = 1. * tf.cast(action, tf.float32) / FLAGS.loss_weight_scale
    mean_log_prob = tf.reduce_mean(policy.log_prob(action))

    return (entropy, weight, mean_log_prob)
346
347
def main(unused_argv):
  """Sets up the estimator, input pipeline, and runs L2TL training."""
  tf.set_random_seed(FLAGS.random_seed)

  run_config_args = {
      'model_dir': FLAGS.model_dir,
      'save_checkpoints_steps': FLAGS.save_checkpoints_steps,
      'log_step_count_steps': FLAGS.log_step_count_steps,
      'keep_checkpoint_max': 100,
  }
  config = tf.contrib.tpu.RunConfig(**run_config_args)

  if FLAGS.warm_start_ckpt_path:
    var_names = []
    checkpoint_path = FLAGS.warm_start_ckpt_path
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    for key in reader.get_variable_to_shape_map():
      # Skip optimizer slot variables and step counters; warm-start all
      # remaining model weights from the checkpoint.
      keep_str = 'Momentum|global_step|finetune_global_step'
      if not re.findall('({})'.format(keep_str), key):
        var_names.append(key)

    tf.logging.info('Warm-starting tensors: %s', sorted(var_names))

    warm_start_settings = tf_estimator.WarmStartSettings(
        ckpt_to_initialize_from=checkpoint_path,
        vars_to_warm_start=var_names)
  else:
    warm_start_settings = None

  l2tl_classifier = tf_estimator.Estimator(
      train_model_fn, config=config, warm_start_from=warm_start_settings)

  def make_input_dataset():
    """Return input dataset."""

    def _merge_datasets(train_batch, finetune_batch, target_batch):
      """Merge different splits."""
      train_features, train_labels = train_batch['image'], train_batch['label']
      finetune_features, finetune_labels = finetune_batch[
          'image'], finetune_batch['label']
      target_features, target_labels = target_batch['image'], target_batch[
          'label']
      features = {
          'src': train_features,
          'finetune': finetune_features,
          'target': target_features
      }
      labels = {
          'src': train_labels,
          'finetune': finetune_labels,
          'target': target_labels
      }
      return (features, labels)

    # Oversample the source split so the model_fn can pick the top-weighted
    # examples within each batch.
    source_train_batch_size = int(
        round(FLAGS.train_batch_size * FLAGS.source_train_batch_multiplier))

    train_data = tfds.load(name=FLAGS.source_dataset, split='train')
    train_data = train_data.shuffle(512).repeat().batch(source_train_batch_size)

    target_train_batch_size = int(
        round(FLAGS.train_batch_size * FLAGS.target_train_batch_multiplier))
    finetune_data = tfds.load(name=FLAGS.target_dataset, split='train')
    finetune_data = finetune_data.shuffle(512).repeat().batch(
        target_train_batch_size)

    target_val_batch_size = int(
        round(FLAGS.train_batch_size * FLAGS.target_val_batch_multiplier))

    # The validation split supplies the RL reward (target accuracy).
    target_data = tfds.load(name=FLAGS.target_dataset, split='validation')
    target_data = target_data.shuffle(512).repeat().batch(target_val_batch_size)

    dataset = tf.data.Dataset.zip((train_data, finetune_data, target_data))
    dataset = dataset.map(_merge_datasets)
    # Fix: `tf.data.AUTOTUNE` only exists in TF >= 2.3, but this file uses
    # the TF1-only `tf.contrib.tpu.RunConfig`; the `experimental` alias is
    # available in both TF 1.13+ and TF 2.x.
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

  max_train_steps = FLAGS.train_steps
  l2tl_classifier.train(make_input_dataset, max_steps=max_train_steps)
427
428
if __name__ == '__main__':
  # Enable INFO logging before absl parses flags and dispatches to main().
  tf.logging.set_verbosity(tf.logging.INFO)
  app.run(main)
432