google-research / cifar_train_mentormix.py

# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2020 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Trains MentorMix models.

See the README.md file for compilation and running instructions.
"""

import os
import time
import cifar_data_provider
import numpy as np
import resnet_model
import tensorflow as tf
import tensorflow.contrib.slim as slim
import utils

flags = tf.app.flags

flags.DEFINE_integer('batch_size', 128, 'The number of images in each batch.')

flags.DEFINE_string('master', None, 'BNS name of the TensorFlow master to use.')

flags.DEFINE_string('data_dir', '', 'Data dir.')

flags.DEFINE_string('train_log_dir', '', 'Directory to save the trained model.')

flags.DEFINE_string('dataset_name', 'cifar100', 'cifar10 or cifar100')

flags.DEFINE_string('studentnet', 'resnet32', 'Network backbone.')

flags.DEFINE_float('learning_rate', 0.1, 'The learning rate.')
flags.DEFINE_float('learning_rate_decay_factor', 0.9,
                   'Learning rate decay factor.')

flags.DEFINE_float('num_epochs_per_decay', 3,
                   'Number of epochs after which the learning rate decays.')

flags.DEFINE_integer(
    'save_summaries_secs', 120,
    'The frequency with which summaries are saved, in seconds.')

flags.DEFINE_integer(
    'save_interval_secs', 1200,
    'The frequency with which the model is saved, in seconds.')

flags.DEFINE_integer('max_number_of_steps', 100000,
                     'The maximum number of gradient steps.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

flags.DEFINE_string('device_id', '0', 'GPU device ID to run the job.')

# Learned MentorNet location
flags.DEFINE_string('trained_mentornet_dir', '',
                    'Directory where to find the trained MentorNet model.')

flags.DEFINE_list('example_dropout_rates', '0.0, 100',
                  'Comma-separated list indicating the example drop-out rate. '
                  'This has little impact on the performance.')

# Hyper-parameters for MentorMix to tune
flags.DEFINE_integer('burn_in_epoch', 0, 'Number of first epochs to perform '
                     'burn-in. In the burn-in period, every sample has a '
                     'fixed 1.0 weight.')

flags.DEFINE_float('loss_p_percentile', 0.7, 'p-percentile used to compute '
                   'the loss moving average.')

flags.DEFINE_float('mixup_alpha', 8.0, 'Alpha parameter for the beta '
                   'distribution to sample during mixup.')

flags.DEFINE_bool('second_reweight', True, 'Whether to weight the mixed-up '
                  'examples again with MentorNet.')
FLAGS = flags.FLAGS

# Turn this on if there are no log outputs.
tf.logging.set_verbosity(tf.logging.INFO)


def resnet_train_step(sess, train_op, global_step, train_step_kwargs):
  """Function that takes a gradient step and specifies whether to stop.

  This function is passed to slim.learning.train as `train_step_fn`; the
  total loss is read from the 'total_loss' graph collection populated in
  train_resnet_mentormix.

  Args:
    sess: The current session.
    train_op: An `Operation` that evaluates the gradients and returns the
      total loss.
    global_step: A `Tensor` representing the global training step.
    train_step_kwargs: A dictionary of keyword arguments.

  Returns:
    The total loss and a boolean indicating whether or not to stop training.
  """
  start_time = time.time()

  total_loss = tf.get_collection('total_loss')[0]

  _, np_global_step, total_loss_val = sess.run(
      [train_op, global_step, total_loss])

  time_elapsed = time.time() - start_time

  if 'should_log' in train_step_kwargs:
    if sess.run(train_step_kwargs['should_log']):
      tf.logging.info('global step %d: loss = %.4f (%.3f sec/step)',
                      np_global_step, total_loss_val, time_elapsed)

  if 'should_stop' in train_step_kwargs:
    should_stop = sess.run(train_step_kwargs['should_stop'])
  else:
    should_stop = False
  return total_loss_val, should_stop


def train_resnet_mentormix(max_step_run):
  """Trains the mentornet with the student resnet model.

  Args:
    max_step_run: The maximum number of gradient steps.
  """
  if not os.path.exists(FLAGS.train_log_dir):
    os.makedirs(FLAGS.train_log_dir)
  g = tf.Graph()

  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      tf_global_step = tf.train.get_or_create_global_step()

      (images, one_hot_labels, num_samples_per_epoch,
       num_of_classes) = cifar_data_provider.provide_resnet_data(
           FLAGS.dataset_name,
           'train',
           FLAGS.batch_size,
           dataset_dir=FLAGS.data_dir)

      hps = resnet_model.HParams(
          batch_size=FLAGS.batch_size,
          num_classes=num_of_classes,
          min_lrn_rate=0.0001,
          lrn_rate=FLAGS.learning_rate,
          num_residual_units=5,
          use_bottleneck=False,
          weight_decay_rate=0.0002,
          relu_leakiness=0.1,
          optimizer='mom')

      images.set_shape([FLAGS.batch_size, 32, 32, 3])

      # Define the model:
      resnet = resnet_model.ResNet(hps, images, one_hot_labels, mode='train')
      with tf.variable_scope('ResNet32'):
        logits = resnet.build_model()

      # Specify the loss function:
      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=one_hot_labels, logits=logits)
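
      # Note: softmax_cross_entropy_with_logits returns one loss value per
      # example; this per-example loss vector is what MentorNet scores below.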

      dropout_rates = utils.parse_dropout_rate_list(FLAGS.example_dropout_rates)
      example_dropout_rates = tf.convert_to_tensor(
          dropout_rates, np.float32, name='example_dropout_rates')

      loss_p_percentile = tf.convert_to_tensor(
          np.array([FLAGS.loss_p_percentile] * 100),
          np.float32,
          name='loss_p_percentile')

      loss = tf.reshape(loss, [-1, 1])

      epoch_step = tf.to_int32(
          tf.floor(tf.divide(tf_global_step, max_step_run) * 100))

      zero_labels = tf.zeros([tf.shape(loss)[0], 1], tf.float32)

      mentornet_net_hparams = utils.get_mentornet_network_hyperparameter(
          FLAGS.trained_mentornet_dir)
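
      # MentorNet maps each example's loss features to a per-example weight v
      # (in [0, 1] per the MentorNet design); epoch_step above expresses
      # training progress as a 0-100 percentage of max_step_run.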

      # In the simplest case, this function can be replaced with a thresholding
      # function. See loss_thresholding_function in utils.py.
      v = utils.mentornet(
          epoch_step,
          loss,
          zero_labels,
          loss_p_percentile,
          example_dropout_rates,
          burn_in_epoch=FLAGS.burn_in_epoch,
          mentornet_net_hparams=mentornet_net_hparams,
          avg_name='individual')

      v = tf.stop_gradient(v)
      loss = tf.stop_gradient(tf.identity(loss))
      logits = tf.stop_gradient(tf.identity(logits))
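
      # The first forward pass is used only to score examples: gradients do
      # not flow through v, the raw loss, or these logits. Training gradients
      # come solely from the mixed examples below.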

      # Perform MentorMix
      images_mix, labels_mix = utils.mentor_mix_up(
          images, one_hot_labels, v, FLAGS.mixup_alpha)
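      # Roughly, per the MentorMix recipe (Jiang et al., 2020): each example's
      # mixing partner is sampled within the minibatch with probability
      # proportional to v, lambda is drawn from Beta(mixup_alpha, mixup_alpha),
      # and images/labels are convexly combined so that likely-clean examples
      # dominate the mix. See utils.mentor_mix_up for the exact formulation.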
      resnet = resnet_model.ResNet(hps, images_mix, labels_mix, mode='train')
      with tf.variable_scope('ResNet32', reuse=True):
        logits_mix = resnet.build_model()

      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=labels_mix, logits=logits_mix)
      decay_loss = resnet.decay()

      # Second reweighting
      if FLAGS.second_reweight:
        loss = tf.reshape(loss, [-1, 1])
        v = utils.mentornet(
            epoch_step,
            loss,
            zero_labels,
            loss_p_percentile,
            example_dropout_rates,
            burn_in_epoch=FLAGS.burn_in_epoch,
            mentornet_net_hparams=mentornet_net_hparams,
            avg_name='mixed')
        v = tf.stop_gradient(v)
        weighted_loss_vector = tf.multiply(loss, v)
        loss = tf.reduce_mean(weighted_loss_vector)
        # Reproduced with the following decay loss, which should be 0.
        decay_loss = tf.losses.get_regularization_loss()
        decay_loss = decay_loss * (tf.reduce_sum(v) / FLAGS.batch_size)
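
      # This second MentorNet pass reweights the losses of the mixed-up
      # examples; avg_name='mixed' presumably keeps its loss moving-average
      # statistics separate from the 'individual' pass above. The decay loss
      # is also scaled by the mean weight, so regularization shrinks when many
      # examples are down-weighted.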

      # Log data utilization
      data_util = utils.summarize_data_utilization(v, tf_global_step,
                                                   FLAGS.batch_size)

      loss = tf.reduce_mean(loss)
      slim.summaries.add_scalar_summary(
          tf.reduce_mean(loss), 'mentormix/mix_loss')

      weighted_total_loss = loss + decay_loss

      slim.summaries.add_scalar_summary(weighted_total_loss, 'total_loss')
      tf.add_to_collection('total_loss', weighted_total_loss)
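
      # The total loss is stashed in a graph collection so that
      # resnet_train_step above can fetch and report it each step.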

      # Set up the moving averages:
      moving_average_variables = tf.trainable_variables()
      moving_average_variables = tf.contrib.framework.filter_variables(
          moving_average_variables, exclude_patterns=['mentornet'])

      variable_averages = tf.train.ExponentialMovingAverage(
          0.9999, tf_global_step)
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                           variable_averages.apply(moving_average_variables))
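
      # MentorNet variables are excluded from the averages (and from the
      # gradient updates below): they are restored from a pre-trained
      # checkpoint and kept fixed throughout training.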

      decay_steps = (FLAGS.num_epochs_per_decay * num_samples_per_epoch /
                     FLAGS.batch_size)
      lr = tf.train.exponential_decay(
          FLAGS.learning_rate,
          tf_global_step,
          decay_steps,
          FLAGS.learning_rate_decay_factor,
          staircase=True)
      lr = tf.squeeze(lr)
      slim.summaries.add_scalar_summary(lr, 'learning_rate')
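
      # With staircase=True the rate drops by learning_rate_decay_factor every
      # decay_steps steps. For example, with CIFAR's 50,000 training images,
      # batch_size=128, and num_epochs_per_decay=3, the rate decays roughly
      # every 1,172 steps.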

      # Specify the optimization scheme:
      with tf.control_dependencies([weighted_total_loss, data_util]):
        # Set up training.
        trainable_variables = tf.trainable_variables()
        trainable_variables = tf.contrib.framework.filter_variables(
            trainable_variables, exclude_patterns=['mentornet'])

        grads = tf.gradients(weighted_total_loss, trainable_variables)
        optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9)

        apply_op = optimizer.apply_gradients(
            zip(grads, trainable_variables),
            global_step=tf_global_step,
            name='train_step')

        train_ops = [apply_op] + resnet.extra_train_ops + tf.get_collection(
            tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(*train_ops)
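
      # train_op groups the gradient update with resnet.extra_train_ops
      # (batch-norm statistic updates in this ResNet implementation) and the
      # UPDATE_OPS collection, which includes the EMA update registered above.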

      # Parameter restore setup
      if FLAGS.trained_mentornet_dir is not None:
        ckpt_model = FLAGS.trained_mentornet_dir
        if os.path.isdir(FLAGS.trained_mentornet_dir):
          ckpt_model = tf.train.latest_checkpoint(ckpt_model)

        # Fix the mentornet parameters
        variables_to_restore = slim.get_variables_to_restore(
            include=['mentornet', 'mentornet_inputs'])
        iassign_op1, ifeed_dict1 = tf.contrib.framework.assign_from_checkpoint(
            ckpt_model, variables_to_restore)

        # Create an initial assignment function.
        def init_assign_fn(sess):
          tf.logging.info('Restore using custom initializer %s', '.' * 10)
          sess.run(iassign_op1, ifeed_dict1)
      else:
        init_assign_fn = None

      tf.logging.info('-' * 20 + 'MentorMix' + '-' * 20)
      tf.logging.info('loss_p_percentile=%.3f', FLAGS.loss_p_percentile)
      tf.logging.info('mixup_alpha=%.2f', FLAGS.mixup_alpha)
      tf.logging.info('-' * 20)

      saver = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=24)

      # Run training.
      slim.learning.train(
          train_op=train_op,
          train_step_fn=resnet_train_step,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          saver=saver,
          number_of_steps=max_step_run,
          init_fn=init_assign_fn,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)


def main(_):
  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.device_id

  if FLAGS.studentnet == 'resnet32':
    train_resnet_mentormix(FLAGS.max_number_of_steps)
  else:
    tf.logging.error('unknown backbone student network %s', FLAGS.studentnet)


if __name__ == '__main__':
  tf.app.run()