# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Train DNN of a specified architecture on a specified data set."""
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import json23import os24import sys25import time26
27from absl import app28from absl import flags29from absl import logging30
31import numpy as np32import tensorflow.compat.v2 as tf33from tensorflow.io import gfile34import tensorflow_datasets as tfds35
36FLAGS = flags.FLAGS37CNN_KERNEL_SIZE = 338
39flags.DEFINE_integer('num_layers', 3, 'Number of layers in the network.')40flags.DEFINE_integer('num_units', 16, 'Number of units in a dense layer.')41flags.DEFINE_integer('batchsize', 512, 'Size of the mini-batch.')42flags.DEFINE_float(43'train_fraction', 1.0, 'How much of the dataset to use for'44'training [as fraction]: eg. 0.15, 0.5, 1.0')45flags.DEFINE_integer('epochs', 18, 'How many epochs to train for')46flags.DEFINE_integer('epochs_between_checkpoints', 6,47'How many epochs to train between creating checkpoints')48flags.DEFINE_integer('random_seed', 42, 'Random seed.')49flags.DEFINE_integer('cnn_stride', 2, 'Stride of the CNN')50flags.DEFINE_float('dropout', 0.0, 'Dropout Rate')51flags.DEFINE_float('l2reg', 0.0, 'L2 regularization strength')52flags.DEFINE_float('init_std', 0.05, 'Standard deviation of the initializer.')53flags.DEFINE_float('learning_rate', 0.01, 'Learning rate.')54flags.DEFINE_string('optimizer', 'sgd',55'Optimizer algorithm: sgd / adam / momentum.')56flags.DEFINE_string('activation', 'relu',57'Nonlinear activation: relu / tanh / sigmoind / selu.')58flags.DEFINE_string(59'w_init', 'he_normal', 'Initialization for weights. '60'see tf.keras.initializers for options')61flags.DEFINE_string(62'b_init', 'zero', 'Initialization for biases.'63'see tf.keras.initializers for options')64flags.DEFINE_boolean('grayscale', True, 'Convert input images to grayscale.')65flags.DEFINE_boolean('augment_traindata', False, 'Augmenting Training data.')66flags.DEFINE_boolean('reduce_learningrate', False,67'Reduce LR towards end of training.')68flags.DEFINE_string('dataset', 'mnist', 'Name of the dataset compatible '69'with TFDS.')70flags.DEFINE_string('dnn_architecture', 'cnn',71'Architecture of the DNN [fc, cnn, cnnbn]')72flags.DEFINE_string(73'workdir', '/tmp/dnn_science_workdir', 'Base working directory for storing'74'checkpoints, summaries, etc.')75flags.DEFINE_integer('verbose', 0, 'Verbosity')76flags.DEFINE_bool('use_tpu', False, 'Whether running on TPU or not.')77flags.DEFINE_string('master', 'local',78'Name of the TensorFlow master to use. "local" for GPU.')79flags.DEFINE_string(80'tpu_job_name', 'tpu_worker',81'Name of the TPU worker job. This is required when having multiple TPU '82'worker jobs.')83
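
# Example invocation (hypothetical; the script filename and the flag values
# shown here are assumptions, adjust as needed):
#
#   python train.py --dataset=mnist --dnn_architecture=cnn --num_layers=3 \
#       --num_units=16 --epochs=18 --workdir=/tmp/dnn_science_workdir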


def _get_workunit_params():
  """Get command line parameters of the current process as dict."""
  main_flags = FLAGS.get_key_flags_for_module(sys.argv[0])
  params = {'config.' + k.name: k.value for k in main_flags}
  return params


def store_results(info_dict, filepath):
  """Save results in the json file."""
  with gfile.GFile(filepath, 'w') as json_fp:
    json.dump(info_dict, json_fp)


def restore_results(filepath):
  """Retrieve results from the json file."""
  with gfile.GFile(filepath, 'r') as json_fp:
    info = json.load(json_fp)
  return info


def _preprocess_batch(batch,
                      normalize,
                      to_grayscale,
                      augment=False):
  """Preprocessing function for each batch of data."""
  min_out = -1.0
  max_out = 1.0
  image = tf.cast(batch['image'], tf.float32)
  image /= 255.0

  if augment:
    shape = image.shape
    image = tf.image.resize_with_crop_or_pad(image, shape[1] + 2, shape[2] + 2)
    image = tf.image.random_crop(image, size=shape)

    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_hue(image, 0.08)
    image = tf.image.random_saturation(image, 0.6, 1.6)
    image = tf.image.random_brightness(image, 0.05)
    image = tf.image.random_contrast(image, 0.7, 1.3)

  if normalize:
    image = min_out + image * (max_out - min_out)
  if to_grayscale:
    image = tf.math.reduce_mean(image, axis=-1, keepdims=True)
  return image, batch['label']
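
# Illustrative check of _preprocess_batch (hypothetical tensors; normalize
# maps [0, 1] pixels into [-1, 1] and grayscale averages the channel axis):
#   batch = {'image': tf.zeros([8, 28, 28, 3], tf.uint8),
#            'label': tf.zeros([8], tf.int64)}
#   images, labels = _preprocess_batch(batch, normalize=True, to_grayscale=True)
#   # images.shape == (8, 28, 28, 1), all values equal to -1.0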


def get_dataset(dataset,
                batchsize,
                to_grayscale=True,
                train_fraction=1.0,
                shuffle_buffer=1024,
                random_seed=None,
                normalize=True,
                augment=False):
  """Load and preprocess the dataset.

  Args:
    dataset: The dataset name. Either 'toy' or a TFDS dataset.
    batchsize: the desired batch size.
    to_grayscale: if True, all images will be converted into grayscale.
    train_fraction: what fraction of the overall training set should we use.
    shuffle_buffer: size of the shuffle buffer for tf.data.Dataset.shuffle.
    random_seed: random seed for shuffling operations.
    normalize: whether to normalize the data into [-1, 1].
    augment: use data augmentation on the training set.

  Returns:
    tuple (training_dataset, test_dataset, info), where info is a dictionary
    with some relevant information about the dataset.
  """
  data_tr, ds_info = tfds.load(dataset, split='train', with_info=True)
  effective_train_size = ds_info.splits['train'].num_examples

  if train_fraction < 1.0:
    effective_train_size = int(effective_train_size * train_fraction)
    data_tr = data_tr.shuffle(shuffle_buffer, seed=random_seed)
    data_tr = data_tr.take(effective_train_size)

  fn_tr = lambda b: _preprocess_batch(b, normalize, to_grayscale, augment)
  data_tr = data_tr.shuffle(shuffle_buffer, seed=random_seed)
  data_tr = data_tr.batch(batchsize, drop_remainder=True)
  data_tr = data_tr.map(fn_tr, tf.data.experimental.AUTOTUNE)
  data_tr = data_tr.prefetch(tf.data.experimental.AUTOTUNE)

  fn_te = lambda b: _preprocess_batch(b, normalize, to_grayscale, False)
  data_te = tfds.load(dataset, split='test')
  data_te = data_te.batch(batchsize)
  data_te = data_te.map(fn_te, tf.data.experimental.AUTOTUNE)
  data_te = data_te.prefetch(tf.data.experimental.AUTOTUNE)

  dataset_info = {
      'num_classes': ds_info.features['label'].num_classes,
      'data_shape': ds_info.features['image'].shape,
      'train_num_examples': effective_train_size
  }
  return data_tr, data_te, dataset_info
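
# A minimal usage sketch (illustrative only; assumes the TFDS 'mnist' dataset
# can be downloaded or is already cached locally):
#   data_tr, data_te, info = get_dataset('mnist', batchsize=512)
#   for images, labels in data_tr.take(1):
#     print(images.shape, labels.shape, info['num_classes'])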


def build_cnn(n_layers, n_hidden, n_outputs, dropout_rate, activation, stride,
              w_regularizer, w_init, b_init, use_batchnorm):
  """Convolutional deep neural network."""
  model = tf.keras.Sequential()
  for _ in range(n_layers):
    model.add(
        tf.keras.layers.Conv2D(
            n_hidden,
            kernel_size=CNN_KERNEL_SIZE,
            strides=stride,
            activation=activation,
            kernel_regularizer=w_regularizer,
            kernel_initializer=w_init,
            bias_initializer=b_init))
    if dropout_rate > 0.0:
      model.add(tf.keras.layers.Dropout(dropout_rate))
    if use_batchnorm:
      model.add(tf.keras.layers.BatchNormalization())
  model.add(tf.keras.layers.GlobalAveragePooling2D())
  model.add(
      tf.keras.layers.Dense(
          n_outputs,
          kernel_regularizer=w_regularizer,
          kernel_initializer=w_init,
          bias_initializer=b_init))
  return model


def build_fcn(n_layers, n_hidden, n_outputs, dropout_rate, activation,
              w_regularizer, w_init, b_init, use_batchnorm):
  """Fully connected deep neural network."""
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Flatten())
  for _ in range(n_layers):
    model.add(
        tf.keras.layers.Dense(
            n_hidden,
            activation=activation,
            kernel_regularizer=w_regularizer,
            kernel_initializer=w_init,
            bias_initializer=b_init))
    if dropout_rate > 0.0:
      model.add(tf.keras.layers.Dropout(dropout_rate))
    if use_batchnorm:
      model.add(tf.keras.layers.BatchNormalization())
  model.add(
      tf.keras.layers.Dense(
          n_outputs,
          kernel_regularizer=w_regularizer,
          kernel_initializer=w_init,
          bias_initializer=b_init))
  return model
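
# A minimal sketch of constructing one of these models directly (hypothetical
# argument values; run() below fills them in from the command-line flags):
#   model = build_cnn(n_layers=3, n_hidden=16, n_outputs=10, dropout_rate=0.0,
#                     activation='relu', stride=2, w_regularizer=None,
#                     w_init='he_normal', b_init='zero', use_batchnorm=False)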


def eval_model(model, data_tr, data_te, info, logger, cur_epoch, workdir):
  """Runs model evaluation."""
  # Get train and test metrics in eval-mode (no dropout etc.).
  metrics_te = model.evaluate(data_te, verbose=0)
  res_te = dict(zip(model.metrics_names, metrics_te))
  metrics_tr = model.evaluate(data_tr, verbose=0)
  res_tr = dict(zip(model.metrics_names, metrics_tr))
  metrics = {
      'train_accuracy': res_tr['accuracy'],
      'train_loss': res_tr['loss'],
      'test_accuracy': res_te['accuracy'],
      'test_loss': res_te['loss'],
  }
  for k in metrics:
    info[k][cur_epoch] = float(metrics[k])
  metrics['epoch'] = cur_epoch  # So it's included in the logging output.
  print(metrics)
  savepath = os.path.join(workdir, 'permanent_ckpt-%d' % cur_epoch)
  model.save(savepath)


def run(workdir,
        data,
        strategy,
        architecture,
        n_layers,
        n_hiddens,
        activation,
        dropout_rate,
        l2_penalty,
        w_init_name,
        b_init_name,
        optimizer_name,
        learning_rate,
        n_epochs,
        epochs_between_checkpoints,
        init_stddev,
        cnn_stride,
        reduce_learningrate=False,
        verbosity=0):
  """Runs the whole training procedure."""
  data_tr, data_te, dataset_info = data
  n_outputs = dataset_info['num_classes']

  with strategy.scope():
    optimizer = tf.keras.optimizers.get(optimizer_name)
    optimizer.learning_rate = learning_rate
    w_init = tf.keras.initializers.get(w_init_name)
    if w_init_name.lower() in ['truncatednormal', 'randomnormal']:
      w_init.stddev = init_stddev
    b_init = tf.keras.initializers.get(b_init_name)
    if b_init_name.lower() in ['truncatednormal', 'randomnormal']:
      b_init.stddev = init_stddev
    w_reg = tf.keras.regularizers.l2(l2_penalty) if l2_penalty > 0 else None

    if architecture == 'cnn' or architecture == 'cnnbn':
      model = build_cnn(n_layers, n_hiddens, n_outputs, dropout_rate,
                        activation, cnn_stride, w_reg, w_init, b_init,
                        architecture == 'cnnbn')
    elif architecture == 'fcn':
      model = build_fcn(n_layers, n_hiddens, n_outputs, dropout_rate,
                        activation, w_reg, w_init, b_init, False)
    else:
      assert False, 'Unknown architecture: %s' % architecture

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy', 'mse', 'sparse_categorical_crossentropy'])

    # Force the model to set input shapes and init weights.
    for x, _ in data_tr:
      model.predict(x)
      if verbosity:
        model.summary()
      break

  ckpt = tf.train.Checkpoint(
      step=optimizer.iterations, optimizer=optimizer, model=model)
  ckpt_dir = os.path.join(workdir, 'temporary-ckpt')
  ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
  if ckpt_manager.latest_checkpoint:
    logging.info('restoring checkpoint: %s', ckpt_manager.latest_checkpoint)
    print('restoring from %s' % ckpt_manager.latest_checkpoint)
    with strategy.scope():
      ckpt.restore(ckpt_manager.latest_checkpoint)
    info = restore_results(os.path.join(workdir, '.intermediate-results.json'))
    print(info, flush=True)
  else:
    info = {
        'steps': 0,
        'start_time': time.time(),
        'train_loss': dict(),
        'train_accuracy': dict(),
        'test_loss': dict(),
        'test_accuracy': dict(),
    }
    info.update(_get_workunit_params())  # Add command line parameters.

  logger = None
  starting_epoch = len(info['train_loss'])
  cur_epoch = starting_epoch
  for cur_epoch in range(starting_epoch, n_epochs):
    if reduce_learningrate and cur_epoch == n_epochs - (n_epochs // 10):
      optimizer.learning_rate = learning_rate / 10
    elif reduce_learningrate and cur_epoch == n_epochs - 2:
      optimizer.learning_rate = learning_rate / 100

    # Train until we reach the criterion or get NaNs.
    try:
      # Always keep checkpoints for the first few epochs; we evaluate first
      # and train afterwards so we have the at-init data.
      if cur_epoch < 4 or (cur_epoch % epochs_between_checkpoints) == 0:
        eval_model(model, data_tr, data_te, info, logger, cur_epoch, workdir)

      model.fit(data_tr, epochs=1, verbose=verbosity)
      ckpt_manager.save()
      store_results(info, os.path.join(workdir, '.intermediate-results.json'))

      dt = time.time() - info['start_time']
      logging.info('epoch %d (%3.2fs)', cur_epoch, dt)

    except tf.errors.InvalidArgumentError as e:
      # We got NaN in the loss; most likely the gradients produced NaNs.
      logging.info(str(e))
      info['status'] = 'NaN'
      logging.info('Stop training because NaNs encountered')
      break

  eval_model(model, data_tr, data_te, info, logger, cur_epoch + 1, workdir)
  store_results(info, os.path.join(workdir, 'results.json'))

  # We don't need the temporary checkpoints anymore.
  gfile.rmtree(os.path.join(workdir, 'temporary-ckpt'))
  gfile.remove(os.path.join(workdir, '.intermediate-results.json'))


def main(unused_argv):
  workdir = FLAGS.workdir

  if not gfile.isdir(workdir):
    gfile.makedirs(workdir)

  tf.random.set_seed(FLAGS.random_seed)
  np.random.seed(FLAGS.random_seed)
  data = get_dataset(
      FLAGS.dataset,
      FLAGS.batchsize,
      to_grayscale=FLAGS.grayscale,
      train_fraction=FLAGS.train_fraction,
      random_seed=FLAGS.random_seed,
      augment=FLAGS.augment_traindata)

  # Figure out TPU related stuff and create distribution strategy.
  use_remote_eager = FLAGS.master and FLAGS.master != 'local'
  if FLAGS.use_tpu:
    logging.info("Use TPU at %s with job name '%s'.", FLAGS.master,
                 FLAGS.tpu_job_name)
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.master, job_name=FLAGS.tpu_job_name)
    if use_remote_eager:
      tf.config.experimental_connect_to_cluster(resolver)
      logging.warning('Remote eager configured. Remote eager can be slow.')
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
  else:
    if use_remote_eager:
      tf.config.experimental_connect_to_host(
          FLAGS.master, job_name='gpu_worker')
      logging.warning('Remote eager configured. Remote eager can be slow.')
    gpus = tf.config.experimental.list_logical_devices(device_type='GPU')
    if gpus:
      logging.info('Found GPUs: %s', gpus)
      strategy = tf.distribute.MirroredStrategy()
    else:
      logging.info('Devices: %s', tf.config.list_logical_devices())
      strategy = tf.distribute.OneDeviceStrategy('CPU')
  logging.info('Devices: %s', tf.config.list_logical_devices())
  logging.info('Distribution strategy: %s', strategy)
  logging.info('Model directory: %s', workdir)

  run(workdir,
      data,
      strategy,
      architecture=FLAGS.dnn_architecture,
      n_layers=FLAGS.num_layers,
      n_hiddens=FLAGS.num_units,
      activation=FLAGS.activation,
      dropout_rate=FLAGS.dropout,
      l2_penalty=FLAGS.l2reg,
      w_init_name=FLAGS.w_init,
      b_init_name=FLAGS.b_init,
      optimizer_name=FLAGS.optimizer,
      learning_rate=FLAGS.learning_rate,
      n_epochs=FLAGS.epochs,
      epochs_between_checkpoints=FLAGS.epochs_between_checkpoints,
      init_stddev=FLAGS.init_std,
      cnn_stride=FLAGS.cnn_stride,
      reduce_learningrate=FLAGS.reduce_learningrate,
      verbosity=FLAGS.verbose)


if __name__ == '__main__':
  tf.enable_v2_behavior()
  app.run(main)