# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs MCMC methods for ResNet and LSTM models."""
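# Example invocation (illustrative; the exact module path depends on where
# this file lives in the repository, e.g. cold_posterior_bnn/train.py):
#
#   python -m cold_posterior_bnn.train --dataset=cifar10 --model=resnet \
#       --method=sgmcmc --temperature=1.0 --output_dir=/tmp/bnn/experiment/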
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pathlib

from absl import app
from absl import flags
from absl import logging
import pandas as pd
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from cold_posterior_bnn import datasets
from cold_posterior_bnn import models
from cold_posterior_bnn.core import ensemble
from cold_posterior_bnn.core import keras_utils
from cold_posterior_bnn.core import priorfactory
from cold_posterior_bnn.core import sgmcmc
from cold_posterior_bnn.core import statistics as stats
# FLAGS experiment
flags.DEFINE_string('output_dir', '/tmp/bnn/experiment/',
                    'Output directory.')
flags.DEFINE_integer('experiment_id', 0, 'ID of this run.')
flags.DEFINE_bool(
    'write_experiment_metadata_to_csv', False,
    'Write hyperparameters to a csv file, useful for hyperparameter sweeps.')
# FLAGS train
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('train_epochs', 1000, 'Number of training epochs.')
flags.DEFINE_integer('batch_size', 128, 'Batch size.')
flags.DEFINE_integer('pretend_batch_size', -1,
                     'Batch size used for cycle/epoch computations.')

flags.DEFINE_float('init_learning_rate', 0.1, 'Initial learning rate.')
# FLAGS dataset
flags.DEFINE_string('dataset', 'cifar10', 'Dataset, one of: cifar10, imdb.')
flags.DEFINE_bool('cifar_data_augmentation', True,
                  'Whether to use basic data augmentation for CIFAR-10 data.')
flags.DEFINE_integer('subsample_train_size', 0,
                     'Subsample the training set to the given number of '
                     'examples.')
# FLAGS model
flags.DEFINE_string('model', 'resnet',
                    'Model to train, one of: cnnlstm, resnet.')
flags.DEFINE_bool('resnet_use_frn', False,
                  'Use filter response normalization instead of batchnorm.')
flags.DEFINE_bool('resnet_bias', True,
                  'Use biases in ResNet Conv2D layers.')

# FLAGS priors
flags.DEFINE_string('pfac', 'default',
                    'Prior factory to use, one of: "default", "gaussian".')
# FLAGS optimizer
flags.DEFINE_string('method', 'sgmcmc', 'MCMC method, one of: sgmcmc, baoab.')
flags.DEFINE_float('momentum_decay', 0.9,
                   'Momentum decay (used for sgmcmc).')

# FLAGS preconditioning
flags.DEFINE_bool('use_preconditioner', True,
                  'Use preconditioning of gradients (updated every epoch).')
# FLAGS cyclical learning rate ensemble
flags.DEFINE_integer('cycle_start_sampling', 10,
                     'Epoch at which the sampling phase starts.')
flags.DEFINE_integer('cycle_length', 5, 'Length of one cycle (in epochs).')
flags.DEFINE_string('cycle_schedule', 'cosine',
                    'Time stepping schedule ("cosine", "glide", or "flat").')
# FLAGS MCMC
flags.DEFINE_float('temperature', 1.,
                   'Temperature used in the MCMC scheme (applies to sgmcmc '
                   'and baoab).')

FLAGS = flags.FLAGS
DATASET_SEED = 124
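# Note: DATASET_SEED is deliberately fixed and separate from --seed. The data
# pipeline is seeded with DATASET_SEED before loading, so every run sees the
# same data (shuffling, subsampling), while --seed is applied afterwards and
# only varies model initialization and sampler noise.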

# Custom gradient function for SG-MCMC methods
def gradest_train_fn():
  """Function providing a step function for gradient estimation."""
102def gradest_train_fn():103"""Function providing a step function for gradient estimation."""104
  @tf.function
  def gest_step(grad_est, model, images, labels):
    """Custom gradient of log prior + log likelihood."""
    with tf.GradientTape(persistent=True) as tape:
      labels = tf.squeeze(labels)
      logits = model(images)
      ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                          labels=labels)
      ce = tf.reduce_mean(ce)
      prior = sum(model.losses)
      obj = ce + prior
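    # `model.losses` collects the prior terms that the prior factory attaches
    # to the layers as Keras regularizers, so `obj` is a minibatch estimate of
    # the (scaled) negative log joint density, -log p(y|x, w) - log p(w).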
    gradients = tape.gradient(obj, model.trainable_variables)
    grad_est.apply_gradients(zip(gradients, model.trainable_variables))
  def train_step(grad_est, model, data):
    images, labels = data
    gest_step(grad_est, model, images, labels)

  return train_step


def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)

  # Load data
  tf.random.set_seed(DATASET_SEED)

  if FLAGS.dataset == 'cifar10':
    dataset_train, ds_info = datasets.load_cifar10(
        tfds.Split.TRAIN, with_info=True,
        data_augmentation=FLAGS.cifar_data_augmentation,
        subsample_n=FLAGS.subsample_train_size)
    dataset_test = datasets.load_cifar10(tfds.Split.TEST)
    logging.info('CIFAR10 dataset loaded.')
  elif FLAGS.dataset == 'imdb':
    dataset, ds_info = datasets.load_imdb(
        with_info=True, subsample_n=FLAGS.subsample_train_size)
    dataset_train, dataset_test = datasets.get_generators_from_ds(dataset)
    logging.info('IMDB dataset loaded.')
  else:
    raise ValueError('Unknown dataset {}.'.format(FLAGS.dataset))
  # Prepare data for SG-MCMC methods
  dataset_size = ds_info['train_num_examples']
  dataset_size_orig = ds_info.get('train_num_examples_orig', dataset_size)
  dataset_train = dataset_train.repeat().shuffle(10 * FLAGS.batch_size).batch(
      FLAGS.batch_size)
  test_batch_size = 100
  validation_steps = ds_info['test_num_examples'] // test_batch_size
  dataset_test_single = dataset_test.batch(FLAGS.batch_size)
  dataset_test = dataset_test.repeat().batch(test_batch_size)
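  # Two views of the test set: `dataset_test_single` iterates over the test
  # set exactly once (used for the ensemble evaluations below), while
  # `dataset_test` repeats indefinitely and is consumed by Keras validation
  # through `validation_steps`.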
  # If the --pretend_batch_size flag is provided, any cycle/epoch-length
  # computation is done using this pretend batch size; real batches still all
  # have length FLAGS.batch_size. This feature is used in the batch size
  # ablation study.
  #
  # Also, always determine the number of iterations from the original dataset
  # size.
  if FLAGS.pretend_batch_size >= 1:
    steps_per_epoch = dataset_size_orig // FLAGS.pretend_batch_size
  else:
    steps_per_epoch = dataset_size_orig // FLAGS.batch_size
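  # For example, on the full CIFAR-10 training split (50000 examples) with the
  # default --batch_size=128, steps_per_epoch = 50000 // 128 = 390.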
  # Set seed for the experiment
  tf.random.set_seed(FLAGS.seed)

  # Build model using pfac for proper priors
  reg_weight = 1.0 / float(dataset_size)
  if FLAGS.pfac == 'default':
    pfac = priorfactory.DefaultPriorFactory(weight=reg_weight)
  elif FLAGS.pfac == 'gaussian':
    pfac = priorfactory.GaussianPriorFactory(prior_stddev=1.0,
                                             weight=reg_weight)
  else:
    raise ValueError('Choose pfac from: default, gaussian.')

  input_shape = ds_info['input_shape']
  if FLAGS.model == 'cnnlstm':
    assert FLAGS.dataset == 'imdb'
    model = models.build_cnnlstm(ds_info['num_words'],
                                 ds_info['sequence_length'],
                                 pfac)
  elif FLAGS.model == 'resnet':
    assert FLAGS.dataset == 'cifar10'
    model = models.build_resnet_v1(
        input_shape=input_shape,
        depth=20,
        num_classes=ds_info['num_classes'],
        pfac=pfac,
        use_frn=FLAGS.resnet_use_frn,
        use_internal_bias=FLAGS.resnet_bias)
  else:
    raise ValueError('Choose model from: cnnlstm, resnet.')

  model.summary()
  # Setup callbacks executed during the Keras fit loop
  callbacks = []

  # Setup preconditioner
  precond_dict = dict()

  if FLAGS.use_preconditioner:
    precond_dict['preconditioner'] = 'fixed'
    logging.info('Use fixed preconditioner.')
  else:
    logging.info('No preconditioner is used.')
  # Always append the preconditioner callback to compute ctemp statistics
  precond_estimator_cb = keras_utils.EstimatePreconditionerCallback(
      gradest_train_fn,
      iter(dataset_train),
      every_nth_epoch=1,
      batch_count=32,
      raw_second_moment=True,
      update_precond=FLAGS.use_preconditioner)
  callbacks.append(precond_estimator_cb)
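  # Once per epoch, this callback estimates a diagonal preconditioner from raw
  # gradient second moments accumulated over 32 minibatches (roughly
  # RMSprop-style). With --nouse_preconditioner it still runs, recording the
  # statistics needed for the temperature diagnostics, but leaves the
  # sampler's preconditioner untouched.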
  # Setup MCMC method
  if FLAGS.method == 'sgmcmc':
    # SG-MCMC optimizer, first-order symplectic Euler integrator
    optimizer = sgmcmc.NaiveSymplecticEulerMCMC(
        total_sample_size=dataset_size,
        learning_rate=FLAGS.init_learning_rate,
        momentum_decay=FLAGS.momentum_decay,
        temp=FLAGS.temperature,
        **precond_dict)
    logging.info('Use symplectic Euler integrator.')
  elif FLAGS.method == 'baoab':
    # SG-MCMC optimizer, second-order accurate BAOAB integrator
    optimizer = sgmcmc.BAOABMCMC(
        total_sample_size=dataset_size,
        learning_rate=FLAGS.init_learning_rate,
        momentum_decay=FLAGS.momentum_decay,
        temp=FLAGS.temperature,
        **precond_dict)
    logging.info('Use BAOAB integrator.')
  else:
    raise ValueError('Choose method from: sgmcmc, baoab.')
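  # Both optimizers simulate underdamped Langevin dynamics targeting the
  # tempered posterior exp(-U(theta) / T). Schematically (see sgmcmc.py for
  # the exact discretization, scaling, and preconditioning), one symplectic
  # Euler step is:
  #   m     <- momentum_decay * m - h * grad_U(theta) + temperature-scaled noise
  #   theta <- theta + h * m
  # BAOAB is a different splitting of the same dynamics that is second-order
  # accurate in the step size h.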
  # Statistics for evaluation of ensemble performance
  perf_stats = {
      'ens_gce': stats.MeanStatistic(stats.ClassificationGibbsCrossEntropy()),
      'ens_ce': stats.MeanStatistic(stats.ClassificationCrossEntropy()),
      'ens_ce_sem': stats.StandardError(stats.ClassificationCrossEntropy()),
      'ens_brier': stats.MeanStatistic(stats.BrierScore()),
      'ens_brier_unc': stats.BrierUncertainty(),
      'ens_brier_res': stats.BrierResolution(),
      'ens_brier_reliab': stats.BrierReliability(),
      'ens_ece': stats.ECE(10),
      'ens_gacc': stats.MeanStatistic(stats.GibbsAccuracy()),
      'ens_acc': stats.MeanStatistic(stats.Accuracy()),
      'ens_acc_sem': stats.StandardError(stats.Accuracy()),
  }
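  # The "Gibbs" statistics (ens_gce, ens_gacc) average the metric over the
  # individual ensemble members, while ens_ce and ens_acc score the averaged
  # ensemble predictive distribution; the gap between the two quantifies the
  # benefit of ensembling. The zip(*...) below splits this dict into parallel
  # tuples of labels (perf_stats_l) and statistic objects (perf_stats_s).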
  perf_stats_l, perf_stats_s = zip(*(perf_stats.items()))

  # Setup ensemble
  ens = ensemble.EmpiricalEnsemble(model, input_shape)
  last_ens_eval = {'size': 0}  # ensemble size from last evaluation
  def cycle_ens_eval_maybe():
    """Ensemble evaluation callback, only evaluate at end of cycle."""

    if len(ens) > last_ens_eval['size']:
      last_ens_eval['size'] = len(ens)
      logging.info('... evaluate ensemble on %d members', len(ens))
      return ens.evaluate_ensemble(
          dataset=dataset_test_single, statistics=perf_stats_s)
    else:
      return None

  ensemble_eval_cb = keras_utils.EvaluateEnsemblePartial(
      cycle_ens_eval_maybe, perf_stats_l)
  callbacks.append(ensemble_eval_cb)
  # Setup cyclical learning rate and temperature schedule for sgmcmc
  if FLAGS.method == 'sgmcmc' or FLAGS.method == 'baoab':
    # Setup cyclical learning rate schedule
    cyclic_sampler_cb = keras_utils.CyclicSamplerCallback(
        ens,
        FLAGS.cycle_length * steps_per_epoch,  # number of iterations per cycle
        FLAGS.cycle_start_sampling,  # sampling phase start epoch
        schedule=FLAGS.cycle_schedule,
        min_value=0.0)  # timestep_factor min value
    callbacks.append(cyclic_sampler_cb)

    # Setup temperature ramp-up schedule
    begin_ramp_epoch = FLAGS.cycle_start_sampling - FLAGS.cycle_length
    if begin_ramp_epoch < 0:
      raise ValueError(
          'cycle_start_sampling must be greater than or equal to '
          'cycle_length.')
    ramp_iterations = FLAGS.cycle_length
    tempramp_cb = keras_utils.TemperatureRampScheduler(
        0.0, FLAGS.temperature, begin_ramp_epoch * steps_per_epoch,
        ramp_iterations * steps_per_epoch)  # T0, Tf, begin_iter, ramp_iters
    callbacks.append(tempramp_cb)
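    # With the defaults (--cycle_start_sampling=10, --cycle_length=5), the
    # temperature is ramped from 0 to --temperature over epochs 5 to 10, and
    # from epoch 10 onwards the cyclic sampler adds a weight snapshot to the
    # ensemble once per 5-epoch cycle, at the low point of the step schedule.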
  # Additional callbacks
  # Plot additional logs
  def plot_logs(epoch, logs):
    del epoch  # unused
    logs['lr'] = optimizer.get_config()['learning_rate']
    if FLAGS.method == 'sgmcmc':
      logs['timestep_factor'] = optimizer.get_config()['timestep_factor']
    logs['ens_size'] = len(ens)
  plot_logs_cb = tf.keras.callbacks.LambdaCallback(on_epoch_end=plot_logs)

  # Write logs to TensorBoard
  tensorboard_cb = tf.keras.callbacks.TensorBoard(
      log_dir=FLAGS.output_dir, write_graph=False)

  # Output ktemp (kinetic temperature) diagnostics
  diag_cb = keras_utils.PrintDiagnosticsCallback(10)

  callbacks.extend([
      diag_cb,
      plot_logs_cb,
      keras_utils.TemperatureMetric(),
      keras_utils.SamplerTemperatureMetric(),
      tensorboard_cb,  # should come after all callbacks that write logs
      tf.keras.callbacks.CSVLogger(os.path.join(FLAGS.output_dir, 'logs.csv'))
  ])
  # Compile and train the model with Keras
  metrics = [
      tf.keras.metrics.SparseCategoricalCrossentropy(
          name='negative_log_likelihood',
          from_logits=True),
      tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
  model.compile(
      optimizer,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=metrics)
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  model.fit(
      dataset_train,
      steps_per_epoch=steps_per_epoch,
      epochs=FLAGS.train_epochs,
      validation_data=dataset_test,
      validation_steps=validation_steps,
      callbacks=callbacks)
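  # All of the MCMC machinery runs inside model.fit via the callbacks above:
  # preconditioner re-estimation, the cyclic step size schedule, the
  # temperature ramp, and per-cycle snapshotting of weights into the ensemble.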
  # Evaluate final ensemble performance
  logging.info('Ensemble has %d members, computing final performance metrics.',
               len(ens))

  if ens.weights_list:
    ens_perf_stats = ens.evaluate_ensemble(dataset_test_single, perf_stats_s)
    print('Test set metrics:')
    for label, stat_value in zip(perf_stats_l, ens_perf_stats):
      stat_value = float(stat_value)
      logging.info('%s: %.5f', label, stat_value)
      print('%s: %.5f' % (label, stat_value))
  # Add experiment info to the experiment metadata csv file in the
  # *parent folder*
  if FLAGS.write_experiment_metadata_to_csv:
    csv_path = pathlib.Path.joinpath(
        pathlib.Path(FLAGS.output_dir).parent, 'run_sweeps.csv')
    data = {
        'id': [FLAGS.experiment_id],
        'seed': [FLAGS.seed],
        'temperature': [FLAGS.temperature],
        'dir': ['run_{}'.format(FLAGS.experiment_id)]
    }
    if tf.io.gfile.exists(csv_path):
      sweeps_df = pd.read_csv(csv_path)
      sweeps_df = pd.concat(
          [sweeps_df, pd.DataFrame.from_dict(data)], ignore_index=True
      ).set_index('id')
    else:
      sweeps_df = pd.DataFrame.from_dict(data).set_index('id')
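    # With the default flags, the resulting run_sweeps.csv contains one row
    # per run, e.g.:
    #   id,seed,temperature,dir
    #   0,42,1.0,run_0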
    # Save the experiment metadata csv file
    sweeps_df.to_csv(csv_path)

if __name__ == '__main__':

  # Print logging.info directly in the shell
  def log_print(msg, *args):
    print(msg % args)
  logging.info = log_print

  tf.enable_v2_behavior()
  app.run(main)