# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs MCMC methods for ResNet and LSTM models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
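
# Example invocation (illustrative; the script filename and the chosen flag
# values are assumptions, not prescribed by this file):
#
#   python train.py --dataset=cifar10 --model=resnet --method=sgmcmc \
#       --temperature=1.0 --init_learning_rate=0.1 --output_dir=/tmp/bnn/run_0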

import os
import pathlib

from absl import app
from absl import flags
from absl import logging
import pandas as pd
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from cold_posterior_bnn import datasets
from cold_posterior_bnn import models
from cold_posterior_bnn.core import ensemble
from cold_posterior_bnn.core import keras_utils
from cold_posterior_bnn.core import priorfactory
from cold_posterior_bnn.core import sgmcmc
from cold_posterior_bnn.core import statistics as stats


# FLAGS experiment
flags.DEFINE_string('output_dir', '/tmp/bnn/experiment/',
                    'Output directory.')
flags.DEFINE_integer('experiment_id', 0, 'ID of this run.')
flags.DEFINE_bool(
    'write_experiment_metadata_to_csv', False,
    'Write hyperparameters to csv file, useful for hyperparameter sweeps.')

# FLAGS train
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('train_epochs', 1000, 'Number of training epochs.')
flags.DEFINE_integer('batch_size', 128, 'Batch size.')
flags.DEFINE_integer('pretend_batch_size', -1,
                     'Batch size used for cycle/epoch computations.')

flags.DEFINE_float('init_learning_rate', 0.1, 'Learning rate.')

# FLAGS dataset
flags.DEFINE_string('dataset', 'cifar10', 'Dataset from: cifar10, imdb.')
flags.DEFINE_bool('cifar_data_augmentation', True,
                  'Whether to use basic data augmentation for CIFAR-10 data.')
flags.DEFINE_integer('subsample_train_size', 0,
                     'Sub-sample training set to given number of samples.')

# FLAGS model
flags.DEFINE_string('model', 'resnet',
                    'Model to train, one of: cnnlstm, resnet.')
flags.DEFINE_bool('resnet_use_frn', False,
                  'Use filter response normalization instead of batchnorm.')
flags.DEFINE_bool('resnet_bias', True,
                  'Use biases in ResNet Conv2D layers.')

# Priors
flags.DEFINE_string('pfac', 'default',
                    'Use "default" or "gaussian" prior factory.')

# FLAGS optimizer
flags.DEFINE_string('method', 'sgmcmc', 'MCMC method, one of: sgmcmc, baoab.')
flags.DEFINE_float('momentum_decay', 0.9,
                   'Momentum decay (used for sgmcmc).')

# FLAGS preconditioning
flags.DEFINE_bool('use_preconditioner', True,
                  'Use preconditioning of gradients (updated every epoch).')

# FLAGS cyclical learning rate ensemble
flags.DEFINE_integer('cycle_start_sampling', 10,
                     'Start sampling phase after x epochs.')
flags.DEFINE_integer('cycle_length', 5, 'Length of one cycle (in epochs).')
flags.DEFINE_string('cycle_schedule', 'cosine',
                    'Time stepping schedule ("cosine", "glide", or "flat").')

# FLAGS MCMC
flags.DEFINE_float('temperature', 1.,
                   'Temperature used in MCMC scheme '
                   '(used for sgmcmc and baoab).')

FLAGS = flags.FLAGS
DATASET_SEED = 124
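# Note (our reading): the input pipeline is seeded with this fixed
# DATASET_SEED, separate from FLAGS.seed, so all runs see identical data
# shuffling/subsampling while weight initialization still varies with
# FLAGS.seed (set later in main()).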


# Custom gradient function for SG-MCMC methods
def gradest_train_fn():
  """Function providing a step function for gradient estimation."""

  @tf.function
  def gest_step(grad_est, model, images, labels):
    """Custom gradient step on the negative log likelihood plus prior terms."""
    with tf.GradientTape() as tape:
      labels = tf.squeeze(labels)
      logits = model(images)
      ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                          labels=labels)
      ce = tf.reduce_mean(ce)
      prior = sum(model.losses)
      obj = ce + prior

    gradients = tape.gradient(obj, model.trainable_variables)
    grad_est.apply_gradients(zip(gradients, model.trainable_variables))

  def train_step(grad_est, model, data):
    images, labels = data
    gest_step(grad_est, model, images, labels)

  return train_step
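
# In gest_step above, `prior = sum(model.losses)` collects the prior terms the
# prior factory registered as Keras losses with weight 1/dataset_size (see
# reg_weight in main()), so `obj` is, up to a constant, the negative log
# posterior scaled by 1/N, the usual per-example objective for SG-MCMC.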


def main(argv):
  del argv  # unused arg

  tf.io.gfile.makedirs(FLAGS.output_dir)

  # Load data
  tf.random.set_seed(DATASET_SEED)

  if FLAGS.dataset == 'cifar10':
    dataset_train, ds_info = datasets.load_cifar10(
        tfds.Split.TRAIN, with_info=True,
        data_augmentation=FLAGS.cifar_data_augmentation,
        subsample_n=FLAGS.subsample_train_size)
    dataset_test = datasets.load_cifar10(tfds.Split.TEST)
    logging.info('CIFAR10 dataset loaded.')

  elif FLAGS.dataset == 'imdb':
    dataset, ds_info = datasets.load_imdb(
        with_info=True, subsample_n=FLAGS.subsample_train_size)
    dataset_train, dataset_test = datasets.get_generators_from_ds(dataset)
    logging.info('IMDB dataset loaded.')

  else:
    raise ValueError('Unknown dataset {}.'.format(FLAGS.dataset))

  # Prepare data for SG-MCMC methods
  dataset_size = ds_info['train_num_examples']
  dataset_size_orig = ds_info.get('train_num_examples_orig', dataset_size)
  dataset_train = dataset_train.repeat().shuffle(10 * FLAGS.batch_size).batch(
      FLAGS.batch_size)
  test_batch_size = 100
  validation_steps = ds_info['test_num_examples'] // test_batch_size
  dataset_test_single = dataset_test.batch(FLAGS.batch_size)
  dataset_test = dataset_test.repeat().batch(test_batch_size)

  # If the --pretend_batch_size flag is provided, all cycle/epoch-length
  # computations are done using this pretend batch size.  Real batches still
  # all have length FLAGS.batch_size.  This feature is used in the batch-size
  # ablation study.
  #
  # Also, always determine the number of iterations from the original dataset
  # size.
  if FLAGS.pretend_batch_size >= 1:
    steps_per_epoch = dataset_size_orig // FLAGS.pretend_batch_size
  else:
    steps_per_epoch = dataset_size_orig // FLAGS.batch_size
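
  # Example: CIFAR-10 has 50000 training examples, so the default
  # --batch_size=128 gives 50000 // 128 = 390 steps per epoch; with
  # --pretend_batch_size=512 the schedules would assume 50000 // 512 = 97
  # steps per epoch while real batches stay at size 128.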

  # Set seed for the experiment
  tf.random.set_seed(FLAGS.seed)

  # Build model using pfac for proper priors
  reg_weight = 1.0 / float(dataset_size)
  if FLAGS.pfac == 'default':
    pfac = priorfactory.DefaultPriorFactory(weight=reg_weight)
  elif FLAGS.pfac == 'gaussian':
    pfac = priorfactory.GaussianPriorFactory(prior_stddev=1.0,
                                             weight=reg_weight)
  else:
    raise ValueError('Choose pfac from: default, gaussian.')

  input_shape = ds_info['input_shape']

  if FLAGS.model == 'cnnlstm':
    assert FLAGS.dataset == 'imdb'
    model = models.build_cnnlstm(ds_info['num_words'],
                                 ds_info['sequence_length'],
                                 pfac)

  elif FLAGS.model == 'resnet':
    assert FLAGS.dataset == 'cifar10'
    model = models.build_resnet_v1(
        input_shape=input_shape,
        depth=20,
        num_classes=ds_info['num_classes'],
        pfac=pfac,
        use_frn=FLAGS.resnet_use_frn,
        use_internal_bias=FLAGS.resnet_bias)
  else:
    raise ValueError('Choose model from: cnnlstm, resnet.')

  model.summary()

  # Setup callbacks executed in the keras.compile loop
  callbacks = []

  # Setup preconditioner
  precond_dict = dict()

  if FLAGS.use_preconditioner:
    precond_dict['preconditioner'] = 'fixed'
    logging.info('Use fixed preconditioner.')
  else:
    logging.info('No preconditioner is used.')

  # Always append the preconditioner callback to compute ctemp statistics
  precond_estimator_cb = keras_utils.EstimatePreconditionerCallback(
      gradest_train_fn,
      iter(dataset_train),
      every_nth_epoch=1,
      batch_count=32,
      raw_second_moment=True,
      update_precond=FLAGS.use_preconditioner)
  callbacks.append(precond_estimator_cb)
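
  # The estimator callback runs even with --nouse_preconditioner so that the
  # ctemp diagnostics are still computed each epoch (raw second moments over
  # batch_count=32 minibatches); update_precond alone decides whether the
  # optimizer's preconditioner is actually updated.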

  # Setup MCMC method
  if FLAGS.method == 'sgmcmc':
    # SG-MCMC optimizer, first-order symplectic Euler integrator
    optimizer = sgmcmc.NaiveSymplecticEulerMCMC(
        total_sample_size=dataset_size,
        learning_rate=FLAGS.init_learning_rate,
        momentum_decay=FLAGS.momentum_decay,
        temp=FLAGS.temperature,
        **precond_dict)
    logging.info('Use symplectic Euler integrator.')

  elif FLAGS.method == 'baoab':
    # SG-MCMC optimizer, second-order accurate BAOAB integrator
    optimizer = sgmcmc.BAOABMCMC(
        total_sample_size=dataset_size,
        learning_rate=FLAGS.init_learning_rate,
        momentum_decay=FLAGS.momentum_decay,
        temp=FLAGS.temperature,
        **precond_dict)
    logging.info('Use BAOAB integrator.')
  else:
    raise ValueError('Choose method from: sgmcmc, baoab.')
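
  # Both methods discretize underdamped Langevin dynamics (our reading of
  # Wenzel et al., 2020, "How Good is the Bayes Posterior in Deep Neural
  # Networks Really?", whose experiments this code accompanies):
  #   d theta = M^{-1} m dt
  #   d m     = -grad U(theta) dt - gamma m dt + sqrt(2 gamma T) M^{1/2} dW
  # with U the scaled negative log posterior, T = FLAGS.temperature, and M
  # the preconditioner.  Symplectic Euler is first-order accurate in the step
  # size, BAOAB second-order, which is why both are offered here.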

  # Statistics for evaluation of ensemble performance
  perf_stats = {
      'ens_gce': stats.MeanStatistic(stats.ClassificationGibbsCrossEntropy()),
      'ens_ce': stats.MeanStatistic(stats.ClassificationCrossEntropy()),
      'ens_ce_sem': stats.StandardError(stats.ClassificationCrossEntropy()),
      'ens_brier': stats.MeanStatistic(stats.BrierScore()),
      'ens_brier_unc': stats.BrierUncertainty(),
      'ens_brier_res': stats.BrierResolution(),
      'ens_brier_reliab': stats.BrierReliability(),
      'ens_ece': stats.ECE(10),
      'ens_gacc': stats.MeanStatistic(stats.GibbsAccuracy()),
      'ens_acc': stats.MeanStatistic(stats.Accuracy()),
      'ens_acc_sem': stats.StandardError(stats.Accuracy()),
  }

  perf_stats_l, perf_stats_s = zip(*perf_stats.items())
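  # perf_stats_l (labels) and perf_stats_s (statistic objects) are parallel
  # tuples; keeping them separate gives the evaluation callback and the final
  # report below a stable, shared ordering.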

  # Setup ensemble
  ens = ensemble.EmpiricalEnsemble(model, input_shape)
  last_ens_eval = {'size': 0}  # ensemble size from last evaluation

  def cycle_ens_eval_maybe():
    """Ensemble evaluation callback; only evaluates at the end of a cycle."""

    if len(ens) > last_ens_eval['size']:
      last_ens_eval['size'] = len(ens)
      logging.info('... evaluate ensemble on %d members', len(ens))
      return ens.evaluate_ensemble(
          dataset=dataset_test_single, statistics=perf_stats_s)
    else:
      return None

  ensemble_eval_cb = keras_utils.EvaluateEnsemblePartial(
      cycle_ens_eval_maybe, perf_stats_l)
  callbacks.append(ensemble_eval_cb)

  # Setup cyclical learning rate and temperature schedule for sgmcmc
  if FLAGS.method == 'sgmcmc' or FLAGS.method == 'baoab':
    # Setup cyclical learning rate schedule
    cyclic_sampler_cb = keras_utils.CyclicSamplerCallback(
        ens,
        FLAGS.cycle_length * steps_per_epoch,  # number of iterations per cycle
        FLAGS.cycle_start_sampling,  # sampling phase start epoch
        schedule=FLAGS.cycle_schedule,
        min_value=0.0)  # timestep_factor min value
    callbacks.append(cyclic_sampler_cb)

    # Setup temperature ramp-up schedule
    begin_ramp_epoch = FLAGS.cycle_start_sampling - FLAGS.cycle_length
    if begin_ramp_epoch < 0:
      raise ValueError(
          'cycle_start_sampling must be greater than or equal to '
          'cycle_length.')
    ramp_iterations = FLAGS.cycle_length
    # Positional args: T0, Tf, begin_iter, ramp_iterations.
    tempramp_cb = keras_utils.TemperatureRampScheduler(
        0.0, FLAGS.temperature, begin_ramp_epoch * steps_per_epoch,
        ramp_iterations * steps_per_epoch)
    callbacks.append(tempramp_cb)
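
    # Timing example with the defaults (--cycle_start_sampling=10,
    # --cycle_length=5): the temperature ramp starts at epoch 10 - 5 = 5 and
    # lasts 5 epochs, so FLAGS.temperature is reached exactly when sampling
    # begins at epoch 10.  (The ramp shape itself is defined in
    # keras_utils.TemperatureRampScheduler, not restated here.)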

  # Additional callbacks
  # Plot additional logs
  def plot_logs(epoch, logs):
    del epoch  # unused
    logs['lr'] = optimizer.get_config()['learning_rate']
    if FLAGS.method == 'sgmcmc':
      logs['timestep_factor'] = optimizer.get_config()['timestep_factor']
    logs['ens_size'] = len(ens)
  plot_logs_cb = tf.keras.callbacks.LambdaCallback(on_epoch_end=plot_logs)

  # Write logs to tensorboard
  tensorboard_cb = tf.keras.callbacks.TensorBoard(
      log_dir=FLAGS.output_dir, write_graph=False)

  # Output ktemp
  diag_cb = keras_utils.PrintDiagnosticsCallback(10)

  callbacks.extend([
      diag_cb,
      plot_logs_cb,
      keras_utils.TemperatureMetric(),
      keras_utils.SamplerTemperatureMetric(),
      tensorboard_cb,  # Should be after all callbacks that write logs
      tf.keras.callbacks.CSVLogger(os.path.join(FLAGS.output_dir, 'logs.csv'))
  ])

  # Keras train model
  metrics = [
      tf.keras.metrics.SparseCategoricalCrossentropy(
          name='negative_log_likelihood',
          from_logits=True),
      tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
  model.compile(
      optimizer,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=metrics)
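  # Note: Keras adds the regularization terms in model.losses (here the prior
  # terms attached by the prior factory) to the compiled loss, so training
  # optimizes cross-entropy + prior, matching gest_step above.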
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  model.fit(
      dataset_train,
      steps_per_epoch=steps_per_epoch,
      epochs=FLAGS.train_epochs,
      validation_data=dataset_test,
      validation_steps=validation_steps,
      callbacks=callbacks)

  # Evaluate final ensemble performance
  logging.info('Ensemble has %d members, computing final performance metrics.',
               len(ens))

  if ens.weights_list:
    ens_perf_stats = ens.evaluate_ensemble(dataset_test_single, perf_stats_s)
    print('Test set metrics:')
    for label, stat_value in zip(perf_stats_l, ens_perf_stats):
      stat_value = float(stat_value)
      logging.info('%s: %.5f', label, stat_value)
      print('%s: %.5f' % (label, stat_value))

  # Add experiment info to the experiment metadata csv file in *parent folder*
  if FLAGS.write_experiment_metadata_to_csv:
    csv_path = pathlib.Path.joinpath(
        pathlib.Path(FLAGS.output_dir).parent, 'run_sweeps.csv')
    data = {
        'id': [FLAGS.experiment_id],
        'seed': [FLAGS.seed],
        'temperature': [FLAGS.temperature],
        'dir': ['run_{}'.format(FLAGS.experiment_id)]
    }
    # str() to be safe: some TF versions' gfile does not accept pathlib.Path.
    if tf.io.gfile.exists(str(csv_path)):
      sweeps_df = pd.read_csv(csv_path)
      sweeps_df = pd.concat(
          [sweeps_df, pd.DataFrame.from_dict(data)], ignore_index=True
      ).set_index('id')
    else:
      sweeps_df = pd.DataFrame.from_dict(data).set_index('id')

    # Save the experiment metadata csv file
    sweeps_df.to_csv(csv_path)
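
    # The resulting run_sweeps.csv accumulates one row per run, e.g.
    # (illustrative values):
    #   id,seed,temperature,dir
    #   0,42,1.0,run_0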


if __name__ == '__main__':

  # Print logging.info directly in shell
  def log_print(msg, *args):
    print(msg % args)
  logging.info = log_print

  tf.enable_v2_behavior()
  app.run(main)