# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2020 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Trains MentorMix models.

See the README.md file for compilation and running instructions.
"""

import os
import time

import cifar_data_provider
import numpy as np
import resnet_model
import tensorflow as tf
import tensorflow.contrib.slim as slim
import utils

flags = tf.app.flags

flags.DEFINE_integer('batch_size', 128, 'The number of images in each batch.')

flags.DEFINE_string('master', None, 'BNS name of the TensorFlow master to use.')

flags.DEFINE_string('data_dir', '', 'Data directory.')

flags.DEFINE_string('train_log_dir', '', 'Directory to save the trained model.')

flags.DEFINE_string('dataset_name', 'cifar100', 'cifar10 or cifar100')

flags.DEFINE_string('studentnet', 'resnet32', 'Network backbone.')
flags.DEFINE_float('learning_rate', 0.1, 'The learning rate.')

flags.DEFINE_float('learning_rate_decay_factor', 0.9,
                   'Learning rate decay factor.')

flags.DEFINE_float('num_epochs_per_decay', 3,
                   'Number of epochs after which learning rate decays.')

flags.DEFINE_integer(
    'save_summaries_secs', 120,
    'The frequency with which summaries are saved, in seconds.')

flags.DEFINE_integer(
    'save_interval_secs', 1200,
    'The frequency with which the model is saved, in seconds.')

flags.DEFINE_integer('max_number_of_steps', 100000,
                     'The maximum number of gradient steps.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

flags.DEFINE_string('device_id', '0', 'GPU device ID to run the job.')

# Learned MentorNet location
flags.DEFINE_string('trained_mentornet_dir', '',
                    'Directory where to find the trained MentorNet model.')

flags.DEFINE_list('example_dropout_rates', '0.0, 100',
                  'Comma-separated list indicating the example drop-out rate. '
                  'This has little impact on the performance.')
# Hyper-parameters for MentorMix to tune
flags.DEFINE_integer('burn_in_epoch', 0, 'Number of first epochs to perform '
                     'burn-in. In the burn-in period, every sample has a '
                     'fixed 1.0 weight.')

flags.DEFINE_float('loss_p_percentile', 0.7, 'p-percentile used to compute '
                   'the loss moving average.')

flags.DEFINE_float('mixup_alpha', 8.0, 'Alpha parameter for the beta '
                   'distribution to sample during mixup.')

flags.DEFINE_bool('second_reweight', True, 'Whether to weight the mixed-up '
                  'examples again with MentorNet.')

FLAGS = flags.FLAGS
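
# A hypothetical invocation (script name and paths are placeholders, not taken
# from the original README):
#   python cifar_train_mentormix.py --dataset_name=cifar100 \
#     --data_dir=/tmp/cifar --train_log_dir=/tmp/mentormix_log \
#     --trained_mentornet_dir=/path/to/mentornet --device_id=0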

# Print INFO-level logs so training progress is visible.
tf.logging.set_verbosity(tf.logging.INFO)


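# A lightly modified version of slim.learning.train_step: it reads the
# MentorMix weighted total loss from the 'total_loss' collection instead of
# the value returned by the train op.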
def resnet_train_step(sess, train_op, global_step, train_step_kwargs):
  """Function that takes a gradient step and specifies whether to stop.

  Args:
    sess: The current session.
    train_op: An `Operation` that evaluates the gradients and returns the
      total loss.
    global_step: A `Tensor` representing the global training step.
    train_step_kwargs: A dictionary of keyword arguments.

  Returns:
    The total loss and a boolean indicating whether or not to stop training.
  """
  start_time = time.time()

  total_loss = tf.get_collection('total_loss')[0]

  _, np_global_step, total_loss_val = sess.run(
      [train_op, global_step, total_loss])

  time_elapsed = time.time() - start_time

  if 'should_log' in train_step_kwargs:
    if sess.run(train_step_kwargs['should_log']):
      tf.logging.info('global step %d: loss = %.4f (%.3f sec/step)',
                      np_global_step, total_loss_val, time_elapsed)

  if 'should_stop' in train_step_kwargs:
    should_stop = sess.run(train_step_kwargs['should_stop'])
  else:
    should_stop = False
  # Return the evaluated loss value (not the tensor) so the caller can format
  # and check it.
  return total_loss_val, should_stop


def train_resnet_mentormix(max_step_run):
  """Trains the student resnet model with MentorMix.

  Args:
    max_step_run: The maximum number of gradient steps.
  """
  if not os.path.exists(FLAGS.train_log_dir):
    os.makedirs(FLAGS.train_log_dir)
  g = tf.Graph()

  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      tf_global_step = tf.train.get_or_create_global_step()

      (images, one_hot_labels, num_samples_per_epoch,
       num_of_classes) = cifar_data_provider.provide_resnet_data(
           FLAGS.dataset_name,
           'train',
           FLAGS.batch_size,
           dataset_dir=FLAGS.data_dir)

      hps = resnet_model.HParams(
          batch_size=FLAGS.batch_size,
          num_classes=num_of_classes,
          min_lrn_rate=0.0001,
          lrn_rate=FLAGS.learning_rate,
          num_residual_units=5,
          use_bottleneck=False,
          weight_decay_rate=0.0002,
          relu_leakiness=0.1,
          optimizer='mom')

      images.set_shape([FLAGS.batch_size, 32, 32, 3])

      # Define the model:
      resnet = resnet_model.ResNet(hps, images, one_hot_labels, mode='train')
      with tf.variable_scope('ResNet32'):
        logits = resnet.build_model()

      # Specify the loss function:
      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=one_hot_labels, logits=logits)

      dropout_rates = utils.parse_dropout_rate_list(
          FLAGS.example_dropout_rates)
      example_dropout_rates = tf.convert_to_tensor(
          dropout_rates, np.float32, name='example_dropout_rates')

      loss_p_percentile = tf.convert_to_tensor(
          np.array([FLAGS.loss_p_percentile] * 100),
          np.float32,
          name='loss_p_percentile')

      loss = tf.reshape(loss, [-1, 1])

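      # Training progress is discretized into 100 buckets over the whole run;
      # the curriculum inputs above (percentile curve, example dropout rates)
      # are indexed by this bucket, so `epoch_step` effectively means
      # "percent of training completed".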
      epoch_step = tf.to_int32(
          tf.floor(tf.divide(tf_global_step, max_step_run) * 100))

      zero_labels = tf.zeros([tf.shape(loss)[0], 1], tf.float32)

      mentornet_net_hparams = utils.get_mentornet_network_hyperparameter(
          FLAGS.trained_mentornet_dir)

      # In the simplest case, this function can be replaced with a
      # thresholding function. See loss_thresholding_function in utils.py.
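      # `v` is a per-example weight in [0, 1] that MentorNet predicts from the
      # example loss, the loss percentile, and the training progress.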
      v = utils.mentornet(
          epoch_step,
          loss,
          zero_labels,
          loss_p_percentile,
          example_dropout_rates,
          burn_in_epoch=FLAGS.burn_in_epoch,
          mentornet_net_hparams=mentornet_net_hparams,
          avg_name='individual')

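      # Treat the weights and the clean-pass loss/logits as constants, so
      # gradients only flow through the second (mixed-example) forward pass.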
      v = tf.stop_gradient(v)
      loss = tf.stop_gradient(tf.identity(loss))
      logits = tf.stop_gradient(tf.identity(logits))

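      # MentorMix: each example is mixed with a partner sampled according to
      # the weights v, using a Beta(alpha, alpha) mixing coefficient
      # (see utils.mentor_mix_up).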
      # Perform MentorMix
      images_mix, labels_mix = utils.mentor_mix_up(
          images, one_hot_labels, v, FLAGS.mixup_alpha)
      resnet = resnet_model.ResNet(hps, images_mix, labels_mix, mode='train')
      with tf.variable_scope('ResNet32', reuse=True):
        logits_mix = resnet.build_model()

      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=labels_mix, logits=logits_mix)
      decay_loss = resnet.decay()

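      # Optionally run MentorNet once more on the mixed examples (tracked
      # under a separate 'mixed' moving average) and reweight their losses.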
      # Second weighting
      if FLAGS.second_reweight:
        loss = tf.reshape(loss, [-1, 1])
        v = utils.mentornet(
            epoch_step,
            loss,
            zero_labels,
            loss_p_percentile,
            example_dropout_rates,
            burn_in_epoch=FLAGS.burn_in_epoch,
            mentornet_net_hparams=mentornet_net_hparams,
            avg_name='mixed')
        v = tf.stop_gradient(v)
        weighted_loss_vector = tf.multiply(loss, v)
        loss = tf.reduce_mean(weighted_loss_vector)
        # Recompute the decay loss from the regularization collection and
        # scale it by the mean example weight.
        decay_loss = tf.losses.get_regularization_loss()
        decay_loss = decay_loss * (tf.reduce_sum(v) / FLAGS.batch_size)

      # Log data utilization
      data_util = utils.summarize_data_utilization(v, tf_global_step,
                                                   FLAGS.batch_size)

      loss = tf.reduce_mean(loss)
      slim.summaries.add_scalar_summary(loss, 'mentormix/mix_loss')

      weighted_total_loss = loss + decay_loss

      slim.summaries.add_scalar_summary(weighted_total_loss, 'total_loss')
      tf.add_to_collection('total_loss', weighted_total_loss)

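      # Keep an exponential moving average of the student variables for
      # evaluation; the MentorNet variables stay fixed (restored from a
      # checkpoint below), so they are excluded.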
      # Set up the moving averages:
      moving_average_variables = tf.trainable_variables()
      moving_average_variables = tf.contrib.framework.filter_variables(
          moving_average_variables, exclude_patterns=['mentornet'])

      variable_averages = tf.train.ExponentialMovingAverage(
          0.9999, tf_global_step)
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                           variable_averages.apply(moving_average_variables))

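      # Staircase exponential decay: the learning rate is multiplied by
      # learning_rate_decay_factor every num_epochs_per_decay epochs. For
      # example, assuming the 50,000-image CIFAR train split and the default
      # flags, decay_steps = 3 * 50000 / 128, roughly 1172 gradient steps.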
      decay_steps = (FLAGS.num_epochs_per_decay * num_samples_per_epoch /
                     FLAGS.batch_size)
      lr = tf.train.exponential_decay(
          FLAGS.learning_rate,
          tf_global_step,
          decay_steps,
          FLAGS.learning_rate_decay_factor,
          staircase=True)
      lr = tf.squeeze(lr)
      slim.summaries.add_scalar_summary(lr, 'learning_rate')

      # Specify the optimization scheme:
      with tf.control_dependencies([weighted_total_loss, data_util]):
        # Set up training.
        trainable_variables = tf.trainable_variables()
        trainable_variables = tf.contrib.framework.filter_variables(
            trainable_variables, exclude_patterns=['mentornet'])

        grads = tf.gradients(weighted_total_loss, trainable_variables)
        optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9)

        apply_op = optimizer.apply_gradients(
            zip(grads, trainable_variables),
            global_step=tf_global_step,
            name='train_step')

        train_ops = [apply_op] + resnet.extra_train_ops + tf.get_collection(
            tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(*train_ops)

    # Parameter restore setup
    if FLAGS.trained_mentornet_dir:
      ckpt_model = FLAGS.trained_mentornet_dir
      if os.path.isdir(FLAGS.trained_mentornet_dir):
        ckpt_model = tf.train.latest_checkpoint(ckpt_model)

      # Fix the MentorNet parameters
      variables_to_restore = slim.get_variables_to_restore(
          include=['mentornet', 'mentornet_inputs'])
      iassign_op1, ifeed_dict1 = tf.contrib.framework.assign_from_checkpoint(
          ckpt_model, variables_to_restore)

      # Create an initial assignment function.
      def init_assign_fn(sess):
        tf.logging.info('Restore using custom initializer %s', '.' * 10)
        sess.run(iassign_op1, ifeed_dict1)
    else:
      init_assign_fn = None

    tf.logging.info('-' * 20 + 'MentorMix' + '-' * 20)
    tf.logging.info('loss_p_percentile=%.3f', FLAGS.loss_p_percentile)
    tf.logging.info('mixup_alpha=%.2f', FLAGS.mixup_alpha)
    tf.logging.info('-' * 20)

    saver = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=24)

    # Run training.
    slim.learning.train(
        train_op=train_op,
        train_step_fn=resnet_train_step,
        logdir=FLAGS.train_log_dir,
        master=FLAGS.master,
        is_chief=FLAGS.task == 0,
        saver=saver,
        number_of_steps=max_step_run,
        init_fn=init_assign_fn,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)


def main(_):
  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.device_id

  if FLAGS.studentnet == 'resnet32':
    train_resnet_mentormix(FLAGS.max_number_of_steps)
  else:
    tf.logging.error('Unknown backbone student network %s', FLAGS.studentnet)


if __name__ == '__main__':
  tf.app.run()