# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Train DNN of a specified architecture on a specified data set."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import sys
import time

from absl import app
from absl import flags
from absl import logging

import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow.io import gfile
import tensorflow_datasets as tfds

FLAGS = flags.FLAGS
CNN_KERNEL_SIZE = 3

flags.DEFINE_integer('num_layers', 3, 'Number of layers in the network.')
flags.DEFINE_integer('num_units', 16, 'Number of units in a dense layer.')
flags.DEFINE_integer('batchsize', 512, 'Size of the mini-batch.')
flags.DEFINE_float(
    'train_fraction', 1.0, 'How much of the dataset to use for '
    'training [as fraction]: e.g. 0.15, 0.5, 1.0')
flags.DEFINE_integer('epochs', 18, 'How many epochs to train for.')
flags.DEFINE_integer('epochs_between_checkpoints', 6,
                     'How many epochs to train between creating checkpoints.')
flags.DEFINE_integer('random_seed', 42, 'Random seed.')
flags.DEFINE_integer('cnn_stride', 2, 'Stride of the CNN.')
flags.DEFINE_float('dropout', 0.0, 'Dropout rate.')
flags.DEFINE_float('l2reg', 0.0, 'L2 regularization strength.')
flags.DEFINE_float('init_std', 0.05, 'Standard deviation of the initializer.')
flags.DEFINE_float('learning_rate', 0.01, 'Learning rate.')
flags.DEFINE_string('optimizer', 'sgd',
                    'Optimizer algorithm: sgd / adam / momentum.')
flags.DEFINE_string('activation', 'relu',
                    'Nonlinear activation: relu / tanh / sigmoid / selu.')
flags.DEFINE_string(
    'w_init', 'he_normal', 'Initialization for weights. '
    'See tf.keras.initializers for options.')
flags.DEFINE_string(
    'b_init', 'zero', 'Initialization for biases. '
    'See tf.keras.initializers for options.')
flags.DEFINE_boolean('grayscale', True, 'Convert input images to grayscale.')
flags.DEFINE_boolean('augment_traindata', False, 'Augment the training data.')
flags.DEFINE_boolean('reduce_learningrate', False,
                     'Reduce LR towards end of training.')
flags.DEFINE_string('dataset', 'mnist', 'Name of the dataset compatible '
                    'with TFDS.')
flags.DEFINE_string('dnn_architecture', 'cnn',
                    'Architecture of the DNN [fcn, cnn, cnnbn].')
flags.DEFINE_string(
    'workdir', '/tmp/dnn_science_workdir',
    'Base working directory for storing checkpoints, summaries, etc.')
flags.DEFINE_integer('verbose', 0, 'Verbosity.')
flags.DEFINE_bool('use_tpu', False, 'Whether running on TPU or not.')
flags.DEFINE_string('master', 'local',
                    'Name of the TensorFlow master to use. "local" for GPU.')
flags.DEFINE_string(
    'tpu_job_name', 'tpu_worker',
    'Name of the TPU worker job. This is required when having multiple TPU '
    'worker jobs.')


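# A minimal example invocation; the script name and flag values below are
# illustrative only (any TFDS dataset name compatible with the flags above
# works):
#
#   python train.py --dataset=fashion_mnist --dnn_architecture=cnn \
#     --num_layers=3 --num_units=32 --epochs=18 \
#     --workdir=/tmp/dnn_science_workdir/run0

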
def _get_workunit_params():
  """Get command line parameters of the current process as dict."""
  main_flags = FLAGS.get_key_flags_for_module(sys.argv[0])
  params = {'config.' + k.name: k.value for k in main_flags}
  return params


def store_results(info_dict, filepath):
  """Save results in the json file."""
  with gfile.GFile(filepath, 'w') as json_fp:
    json.dump(info_dict, json_fp)


def restore_results(filepath):
  """Retrieve results from the json file."""
  with gfile.GFile(filepath, 'r') as json_fp:
    info = json.load(json_fp)
  return info


def _preprocess_batch(batch,
                      normalize,
                      to_grayscale,
                      augment=False):
  """Preprocessing function for each batch of data."""
  min_out = -1.0
  max_out = 1.0
  image = tf.cast(batch['image'], tf.float32)
  image /= 255.0

  if augment:
    shape = image.shape
    image = tf.image.resize_with_crop_or_pad(image, shape[1] + 2, shape[2] + 2)
    image = tf.image.random_crop(image, size=shape)

    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_hue(image, 0.08)
    image = tf.image.random_saturation(image, 0.6, 1.6)
    image = tf.image.random_brightness(image, 0.05)
    image = tf.image.random_contrast(image, 0.7, 1.3)

  if normalize:
    image = min_out + image * (max_out - min_out)
  if to_grayscale:
    image = tf.math.reduce_mean(image, axis=-1, keepdims=True)
  return image, batch['label']


def get_dataset(dataset,
                batchsize,
                to_grayscale=True,
                train_fraction=1.0,
                shuffle_buffer=1024,
                random_seed=None,
                normalize=True,
                augment=False):
  """Load and preprocess the dataset.

  Args:
    dataset: The dataset name. Either 'toy' or a TFDS dataset
    batchsize: the desired batch size
    to_grayscale: if True, all images will be converted into grayscale
    train_fraction: what fraction of the overall training set should we use
    shuffle_buffer: size of the shuffle buffer for tf.data.Dataset.shuffle
    random_seed: random seed for shuffling operations
    normalize: whether to normalize the data into [-1, 1]
    augment: use data augmentation on the training set.

  Returns:
    tuple (training_dataset, test_dataset, info), where info is a dictionary
    with some relevant information about the dataset.
  """
  data_tr, ds_info = tfds.load(dataset, split='train', with_info=True)
  effective_train_size = ds_info.splits['train'].num_examples

  if train_fraction < 1.0:
    effective_train_size = int(effective_train_size * train_fraction)
    data_tr = data_tr.shuffle(shuffle_buffer, seed=random_seed)
    data_tr = data_tr.take(effective_train_size)

  fn_tr = lambda b: _preprocess_batch(b, normalize, to_grayscale, augment)
  data_tr = data_tr.shuffle(shuffle_buffer, seed=random_seed)
  data_tr = data_tr.batch(batchsize, drop_remainder=True)
  data_tr = data_tr.map(fn_tr, tf.data.experimental.AUTOTUNE)
  data_tr = data_tr.prefetch(tf.data.experimental.AUTOTUNE)

  fn_te = lambda b: _preprocess_batch(b, normalize, to_grayscale, False)
  data_te = tfds.load(dataset, split='test')
  data_te = data_te.batch(batchsize)
  data_te = data_te.map(fn_te, tf.data.experimental.AUTOTUNE)
  data_te = data_te.prefetch(tf.data.experimental.AUTOTUNE)

  dataset_info = {
      'num_classes': ds_info.features['label'].num_classes,
      'data_shape': ds_info.features['image'].shape,
      'train_num_examples': effective_train_size
  }
  return data_tr, data_te, dataset_info
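
# Example usage of get_dataset (illustrative only; 'fashion_mnist' is just one
# TFDS name that works with the defaults above):
#
#   data_tr, data_te, ds_info = get_dataset(
#       'fashion_mnist', batchsize=128, train_fraction=0.5, random_seed=0)
#   print(ds_info['num_classes'], ds_info['train_num_examples'])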


def build_cnn(n_layers, n_hidden, n_outputs, dropout_rate, activation, stride,
              w_regularizer, w_init, b_init, use_batchnorm):
  """Convolutional deep neural network."""
  model = tf.keras.Sequential()
  for _ in range(n_layers):
    model.add(
        tf.keras.layers.Conv2D(
            n_hidden,
            kernel_size=CNN_KERNEL_SIZE,
            strides=stride,
            activation=activation,
            kernel_regularizer=w_regularizer,
            kernel_initializer=w_init,
            bias_initializer=b_init))
    if dropout_rate > 0.0:
      model.add(tf.keras.layers.Dropout(dropout_rate))
    if use_batchnorm:
      model.add(tf.keras.layers.BatchNormalization())
  model.add(tf.keras.layers.GlobalAveragePooling2D())
  model.add(
      tf.keras.layers.Dense(
          n_outputs,
          kernel_regularizer=w_regularizer,
          kernel_initializer=w_init,
          bias_initializer=b_init))
  return model


def build_fcn(n_layers, n_hidden, n_outputs, dropout_rate, activation,
              w_regularizer, w_init, b_init, use_batchnorm):
  """Fully Connected deep neural network."""
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Flatten())
  for _ in range(n_layers):
    model.add(
        tf.keras.layers.Dense(
            n_hidden,
            activation=activation,
            kernel_regularizer=w_regularizer,
            kernel_initializer=w_init,
            bias_initializer=b_init))
    if dropout_rate > 0.0:
      model.add(tf.keras.layers.Dropout(dropout_rate))
    if use_batchnorm:
      model.add(tf.keras.layers.BatchNormalization())
  model.add(
      tf.keras.layers.Dense(
          n_outputs,
          kernel_regularizer=w_regularizer,
          kernel_initializer=w_init,
          bias_initializer=b_init))
  return model
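
# A quick sanity check of the model builders; this sketch assumes a 28x28x1
# input (e.g. grayscale MNIST) and uses values close to the flag defaults
# above:
#
#   model = build_cnn(n_layers=3, n_hidden=16, n_outputs=10, dropout_rate=0.0,
#                     activation='relu', stride=2, w_regularizer=None,
#                     w_init='he_normal', b_init='zeros', use_batchnorm=False)
#   model.build((None, 28, 28, 1))
#   model.summary()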


def eval_model(model, data_tr, data_te, info, logger, cur_epoch, workdir):
  """Runs Model Evaluation."""
  # Get train and test set metrics in eval mode (no dropout etc.).
  metrics_te = model.evaluate(data_te, verbose=0)
  res_te = dict(zip(model.metrics_names, metrics_te))
  metrics_tr = model.evaluate(data_tr, verbose=0)
  res_tr = dict(zip(model.metrics_names, metrics_tr))
  metrics = {
      'train_accuracy': res_tr['accuracy'],
      'train_loss': res_tr['loss'],
      'test_accuracy': res_te['accuracy'],
      'test_loss': res_te['loss'],
  }
  for k in metrics:
    info[k][cur_epoch] = float(metrics[k])
  metrics['epoch'] = cur_epoch  # so it's included in the logging output
  print(metrics)
  savepath = os.path.join(workdir, 'permanent_ckpt-%d' % cur_epoch)
  model.save(savepath)


def run(workdir,
        data,
        strategy,
        architecture,
        n_layers,
        n_hiddens,
        activation,
        dropout_rate,
        l2_penalty,
        w_init_name,
        b_init_name,
        optimizer_name,
        learning_rate,
        n_epochs,
        epochs_between_checkpoints,
        init_stddev,
        cnn_stride,
        reduce_learningrate=False,
        verbosity=0):
  """Runs the whole training procedure."""
  data_tr, data_te, dataset_info = data
  n_outputs = dataset_info['num_classes']

  with strategy.scope():
    optimizer = tf.keras.optimizers.get(optimizer_name)
    optimizer.learning_rate = learning_rate
    w_init = tf.keras.initializers.get(w_init_name)
    if w_init_name.lower() in ['truncatednormal', 'randomnormal']:
      w_init.stddev = init_stddev
    b_init = tf.keras.initializers.get(b_init_name)
    if b_init_name.lower() in ['truncatednormal', 'randomnormal']:
      b_init.stddev = init_stddev
    w_reg = tf.keras.regularizers.l2(l2_penalty) if l2_penalty > 0 else None

    if architecture == 'cnn' or architecture == 'cnnbn':
      model = build_cnn(n_layers, n_hiddens, n_outputs, dropout_rate,
                        activation, cnn_stride, w_reg, w_init, b_init,
                        architecture == 'cnnbn')
    elif architecture == 'fcn':
      model = build_fcn(n_layers, n_hiddens, n_outputs, dropout_rate,
                        activation, w_reg, w_init, b_init, False)
    else:
      assert False, 'Unknown architecture: %s' % architecture

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy', 'mse', 'sparse_categorical_crossentropy'])

  # force the model to set input shapes and init weights
  for x, _ in data_tr:
    model.predict(x)
    if verbosity:
      model.summary()
    break

  ckpt = tf.train.Checkpoint(
      step=optimizer.iterations, optimizer=optimizer, model=model)
  ckpt_dir = os.path.join(workdir, 'temporary-ckpt')
  ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
  if ckpt_manager.latest_checkpoint:
    logging.info('restoring checkpoint: %s', ckpt_manager.latest_checkpoint)
    print('restoring from %s' % ckpt_manager.latest_checkpoint)
    with strategy.scope():
      ckpt.restore(ckpt_manager.latest_checkpoint)
    info = restore_results(os.path.join(workdir, '.intermediate-results.json'))
    print(info, flush=True)
  else:
    info = {
        'steps': 0,
        'start_time': time.time(),
        'train_loss': dict(),
        'train_accuracy': dict(),
        'test_loss': dict(),
        'test_accuracy': dict(),
    }
    info.update(_get_workunit_params())  # Add command line parameters.

  logger = None
  starting_epoch = len(info['train_loss'])
  cur_epoch = starting_epoch
  for cur_epoch in range(starting_epoch, n_epochs):
    if reduce_learningrate and cur_epoch == n_epochs - (n_epochs // 10):
      optimizer.learning_rate = learning_rate / 10
    elif reduce_learningrate and cur_epoch == n_epochs - 2:
      optimizer.learning_rate = learning_rate / 100

    # Train until we reach the criterion or get NaNs
    try:
      # always keep checkpoints for the first few epochs
      # we evaluate first and train afterwards so we have the at-init data
      if cur_epoch < 4 or (cur_epoch % epochs_between_checkpoints) == 0:
        eval_model(model, data_tr, data_te, info, logger, cur_epoch, workdir)

      model.fit(data_tr, epochs=1, verbose=verbosity)
      ckpt_manager.save()
      store_results(info, os.path.join(workdir, '.intermediate-results.json'))

      dt = time.time() - info['start_time']
      logging.info('epoch %d (%3.2fs)', cur_epoch, dt)

    except tf.errors.InvalidArgumentError as e:
      # We got NaN in the loss, most likely gradients resulted in NaNs
      logging.info(str(e))
      info['status'] = 'NaN'
      logging.info('Stop training because NaNs encountered')
      break

  eval_model(model, data_tr, data_te, info, logger, cur_epoch + 1, workdir)
  store_results(info, os.path.join(workdir, 'results.json'))

  # we don't need the temporary checkpoints anymore
  gfile.rmtree(os.path.join(workdir, 'temporary-ckpt'))
  gfile.remove(os.path.join(workdir, '.intermediate-results.json'))


def main(unused_argv):
  workdir = FLAGS.workdir

  if not gfile.isdir(workdir):
    gfile.makedirs(workdir)

  tf.random.set_seed(FLAGS.random_seed)
  np.random.seed(FLAGS.random_seed)
  data = get_dataset(
      FLAGS.dataset,
      FLAGS.batchsize,
      to_grayscale=FLAGS.grayscale,
      train_fraction=FLAGS.train_fraction,
      random_seed=FLAGS.random_seed,
      augment=FLAGS.augment_traindata)

  # Figure out TPU related stuff and create distribution strategy
  use_remote_eager = FLAGS.master and FLAGS.master != 'local'
  if FLAGS.use_tpu:
    logging.info("Use TPU at %s with job name '%s'.", FLAGS.master,
                 FLAGS.tpu_job_name)
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.master, job_name=FLAGS.tpu_job_name)
    if use_remote_eager:
      tf.config.experimental_connect_to_cluster(resolver)
      logging.warning('Remote eager configured. Remote eager can be slow.')
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
  else:
    if use_remote_eager:
      tf.config.experimental_connect_to_host(
          FLAGS.master, job_name='gpu_worker')
      logging.warning('Remote eager configured. Remote eager can be slow.')
    gpus = tf.config.experimental.list_logical_devices(device_type='GPU')
    if gpus:
      logging.info('Found GPUs: %s', gpus)
      strategy = tf.distribute.MirroredStrategy()
    else:
      logging.info('Devices: %s', tf.config.list_logical_devices())
      strategy = tf.distribute.OneDeviceStrategy('CPU')
  logging.info('Devices: %s', tf.config.list_logical_devices())
  logging.info('Distribution strategy: %s', strategy)
  logging.info('Model directory: %s', workdir)

  run(workdir,
      data,
      strategy,
      architecture=FLAGS.dnn_architecture,
      n_layers=FLAGS.num_layers,
      n_hiddens=FLAGS.num_units,
      activation=FLAGS.activation,
      dropout_rate=FLAGS.dropout,
      l2_penalty=FLAGS.l2reg,
      w_init_name=FLAGS.w_init,
      b_init_name=FLAGS.b_init,
      optimizer_name=FLAGS.optimizer,
      learning_rate=FLAGS.learning_rate,
      n_epochs=FLAGS.epochs,
      epochs_between_checkpoints=FLAGS.epochs_between_checkpoints,
      init_stddev=FLAGS.init_std,
      cnn_stride=FLAGS.cnn_stride,
      reduce_learningrate=FLAGS.reduce_learningrate,
      verbosity=FLAGS.verbose)


if __name__ == '__main__':
  tf.enable_v2_behavior()
  app.run(main)
