google-research
284 строки · 10.7 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Main train and eval loop for SimCLR+linear layer experiments.
17
18Given an existing trained SimCLR model, trains a linear layer on top to predict
19the original latents from dsprites dataset.
20"""
21
22from absl import app23from absl import flags24from absl import logging25
26from simclr.tf2 import lars_optimizer as lars27import tensorflow.compat.v2 as tf28
29import graph_compression.contrastive_learning.data_utils.learning_latents as data_lib # pylint: disable=unused-import30import graph_compression.contrastive_learning.datasets.learning_latents as datasets_lib31import graph_compression.contrastive_learning.metrics_utils.learning_latents as metrics_lib32import graph_compression.contrastive_learning.models.learning_latents as models_lib33
34
FLAGS = flags.FLAGS

# --- Hardware / distribution -------------------------------------------------

USE_TPU = flags.DEFINE_boolean('use_tpu', False, 'For TPU training.')

# When unset, main() falls back to --master for the TPU address.
TPU_ADDRESS = flags.DEFINE_string('tpu_address', None,
                                  'Manually specify a TPU address.')

MASTER = flags.DEFINE_string('master', '',
                             'Required for compatibility, leave blank.')

# --- Optimization ------------------------------------------------------------

LR = flags.DEFINE_float('learning_rate', 1e-1,
                        'Learning rate for linear layer.')

L2 = flags.DEFINE_float('l2_penalty', 1e-4, 'Penalty for L2 regularization.')

# Integer flags carrying lower_bound=1 are used as batch sizes or as divisors
# in main(); rejecting zero/negative values at flag-parse time avoids a later
# crash (e.g. ZeroDivisionError in `eval_frequency // train_steps_per_loop`).
T_BATCHSIZE = flags.DEFINE_integer('train_batch_size', 512,
                                   'Batch size for training.',
                                   lower_bound=1)

T_STEPS_PER_LOOP = flags.DEFINE_integer(
    'train_steps_per_loop', 5,
    'How many train steps to run between metrics summaries updates.',
    lower_bound=1)

TOTAL_STEPS = flags.DEFINE_integer('total_steps', 1,
                                   'Number of steps to train for.')

# --- Data --------------------------------------------------------------------

# When unset, no metrics or summaries are written.
DATA_DIR = flags.DEFINE_string('data_dir', None, 'Directory to log data to.')

# Parsed as a list of strings; main() converts each entry with int().
IMG_SIZE = flags.DEFINE_list(
    'img_size', None,
    'Optional image rescaling (comma separated list representing '
    '[new_height, new_width]).')

NUM_CHANNELS = flags.DEFINE_integer(
    'num_channels', None, 'Optional image tiling to multiple channels.')

PRETRAINED_MODEL_PATH = flags.DEFINE_string('pretrained_model_path', None,
                                            'Path to saved pretrained model.')

# A fraction of the dataset, so it must lie in [0, 1].
EVAL_SPLIT = flags.DEFINE_float('eval_split', 0.1,
                                'Fraction of dataset to use for eval.',
                                lower_bound=0.0, upper_bound=1.0)

E_BATCHSIZE = flags.DEFINE_integer('eval_batch_size', 2048,
                                   'Batch size for eval.',
                                   lower_bound=1)

# Measured in train steps; main() divides it by --train_steps_per_loop to get
# the number of train loops to run between evals.
E_FREQ = flags.DEFINE_integer('eval_frequency', 5,
                              'How often to run eval loop.',
                              lower_bound=1)

SEED = flags.DEFINE_integer('seed', None, 'Specify a random seed.')
84
def main(argv):
  """Trains a linear layer on top of a pretrained SimCLR model on dSprites.

  Builds the dSprites train/eval splits and wraps the pretrained model at
  --pretrained_model_path in a linear layer (`model.dense_layer`). That layer
  is regressed onto each example's `values` latents with an MSE loss plus an
  L2 penalty on its weights. Only the linear layer is updated, using a LARS
  optimizer. Training runs in chunks of --train_steps_per_loop steps. After
  every `eval_frequency // train_steps_per_loop` chunks (at least one), the
  whole eval split is evaluated. The run stops once the optimizer's step
  counter reaches --total_steps; this is checked after each eval. When
  --data_dir is set, train and eval metrics are written there as summaries.

  Args:
    argv: Positional command-line arguments left after absl flag parsing.
      Only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If the training split is empty, or if --train_steps_per_loop
      is not between 1 and the number of full training batches per epoch.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Metrics objects and summary writers are only created when --data_dir is
  # set; every use of them below is gated on this flag.
  write_metrics = DATA_DIR.value is not None

  # set up tpu strategy
  if USE_TPU.value:
    tpu_address = TPU_ADDRESS.value or MASTER.value

    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=tpu_address)
    tf.config.experimental_connect_to_cluster(cluster_resolver)
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    strategy = tf.distribute.TPUStrategy(cluster_resolver)
  else:
    # no-op strategy: for debugging / running on one machine
    strategy = tf.distribute.get_strategy()

  # actual run
  with strategy.scope():
    # --img_size is optional. Bind img_size up front so omitting the flag
    # no longer raises NameError. NOTE(review): this assumes
    # get_standard_dataset treats img_size=None as "no rescaling", as it
    # already receives None for the optional num_channels/seed flags.
    img_size = None
    if IMG_SIZE.value is not None:
      img_size = tf.convert_to_tensor([int(v) for v in IMG_SIZE.value])

    # set up dataset
    dataset = datasets_lib.get_standard_dataset(
        name='dsprites',
        img_size=img_size,
        num_channels=NUM_CHANNELS.value,
        eval_split=EVAL_SPLIT.value,
        seed=SEED.value)
    train_df, train_ds, num_train_examples = dataset['train']
    eval_df, eval_ds, num_eval_examples = dataset['eval']

    del train_df, eval_df  ## not used here

    logging.info('Train, eval sets contain %s, %s elements respectively',
                 num_train_examples, num_eval_examples)
    # The model's output width equals the length of one example's `values`.
    num_classes = None
    for x in train_ds.take(1):
      num_classes = x['values'].shape[0]
    if num_classes is None:
      raise ValueError('Training split is empty; cannot infer the number of '
                       'latents to predict.')
    logging.info('Num classes is %s', num_classes)

    logging.info('Setting up datasets...')
    t_batchsize, e_batchsize = T_BATCHSIZE.value, E_BATCHSIZE.value
    train_ds = train_ds.shuffle(
        buffer_size=t_batchsize * 10, reshuffle_each_iteration=True)
    # drop the final partial batch for tpu reasons
    train_ds_batched = train_ds.batch(t_batchsize, drop_remainder=True)
    eval_ds_batched = eval_ds.batch(e_batchsize, drop_remainder=True)
    # so now we need to update the number of examples to match
    num_train_examples = (num_train_examples // t_batchsize) * t_batchsize
    num_eval_examples = (num_eval_examples // e_batchsize) * e_batchsize

    train_ds_dist = strategy.experimental_distribute_dataset(train_ds_batched)
    eval_ds_dist = strategy.experimental_distribute_dataset(eval_ds_batched)
    logging.info('Datasets set up, setting up model...')

    # instantiate optimizer and regularizer
    optimizer = lars.LARSOptimizer(LR.value)

    # loss has to be handled carefully when distributed on multiple cores;
    # specify no reduction for now and handle reduction manually in train step.
    loss_fn = tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.NONE)
    regularizer = tf.keras.regularizers.L2(L2.value)

    # set up model
    model = models_lib.LinearLayerOverPretrainedSimclrModel(
        PRETRAINED_MODEL_PATH.value, optimizer, num_classes)

    logging.info('Optimizer, loss fn, model set up')

  # set up metrics and summary writers for train and eval, if required
  if write_metrics:
    # metrics need to be within a strategy scope
    with strategy.scope():
      logging.info('Starting on computing y_bar and tss...')
      # TODO(zeef): implement a load from cache option to replace this
      # For testing, just hardcode the values because the computation is slow.
      # These are the values for the dsprites dataset.
      # NOTE(review): the eval-set statistics presumably depend on
      # --eval_split / --seed / --img_size. These constants are only valid
      # for the configuration they were computed with; confirm before
      # trusting eval metrics from other settings.

      # y_bar, tss = metrics_lib.get_tss_for_r2(strategy, eval_ds_dist,
      #                                         num_classes, num_eval_examples,
      #                                         e_batchsize)

      y_bar = tf.constant([
          0.33251953, 0.33551705, 0.33196342, 0.74878615, 0.50025487,
          0.49955714, 0.5002258
      ],
                          dtype=tf.float32)
      tss = tf.constant([
          16363.938, 16437.312, 16350.195, 2150.051, 6469.37, 6572.2305,
          6550.672
      ],
                        dtype=tf.float32)
      logging.info('CAUTION! Hardcoded values for y_bar and tss!')

      # y_bar is only logged here; tss is handed to the eval metrics.
      logging.info('y_bar, tss are %s, %s', y_bar, tss)

      train_metrics = metrics_lib.DspritesTrainMetrics(DATA_DIR.value)
      eval_metrics = metrics_lib.DspritesEvalMetrics(DATA_DIR.value, tss)

      logging.info('Metrics set up')

  # define functions for train step, eval step, metrics update step
  @tf.function
  def train_step_loop(iterator, steps_per_loop):
    """Runs `steps_per_loop` distributed training steps drawn from `iterator`."""

    def step_fn(x):
      with tf.GradientTape() as tape:
        preds = model(x['image'])
        # loss is a tensor of per_example losses of size batch_size/num_replicas
        loss = loss_fn(x['values'], preds)
        loss += tf.reduce_sum(
            [regularizer(w) for w in model.dense_layer.trainable_weights])
        # pass this to metrics first so it gets an accurate count of examples
        if write_metrics:
          train_metrics.update_metrics(loss, x['values'], preds)
        # now average the loss and then also divide by number of replicas
        # since the gradients from each replica are added together
        loss = tf.reduce_mean(loss) / strategy.num_replicas_in_sync
      # Only the linear layer is trained; the pretrained model gets no grads.
      dense_layer_weights = model.dense_layer.trainable_weights
      grads = tape.gradient(loss, dense_layer_weights)
      model.optimizer.apply_gradients(zip(grads, dense_layer_weights))

    for _ in tf.range(steps_per_loop):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def eval_step_loop(iterator, steps_per_loop):
    """Runs `steps_per_loop` distributed eval steps drawn from `iterator`."""

    def step_fn(x):
      preds = model(x['image'])
      loss = loss_fn(x['values'], preds)
      # update eval metics
      if write_metrics:
        eval_metrics.update_metrics(loss, x['values'], preds)
      # no need to worry about scaling the loss here, since no gradients

    for _ in tf.range(steps_per_loop):
      strategy.run(step_fn, args=(next(iterator),))

  def metrics_update_loop(metrics_obj, global_step):
    """Writes every metric group of `metrics_obj` at `global_step`, then resets it."""
    for k in metrics_obj.writer_names:
      logging.info('Writing metric: %s', k)
      with metrics_obj.summary_writers[k].as_default():
        metrics_obj.write_metrics_to_summary(
            metrics_obj.metrics_dict[k], global_step=global_step)
      metrics_obj.summary_writers[k].flush()
      for metric in metrics_obj.metrics_dict[k]:
        metric.reset_state()

  # training loop
  num_eval_steps = num_eval_examples // e_batchsize
  num_train_steps = num_train_examples // t_batchsize
  num_train_steps_per_loop = T_STEPS_PER_LOOP.value
  # train_ds is not .repeat()-ed, so a single train loop must fit in one
  # epoch. Otherwise every loop would run off the end of a fresh iterator.
  if not 1 <= num_train_steps_per_loop <= num_train_steps:
    raise ValueError(
        'train_steps_per_loop must be between 1 and the number of full '
        'training batches per epoch (%d), got %d.' %
        (num_train_steps, num_train_steps_per_loop))
  # Run at least one train loop between evals. With eval_frequency <
  # train_steps_per_loop this used to be 0, so no training ever happened and
  # the while loop below never terminated.
  num_train_loops_per_eval = max(1, E_FREQ.value // num_train_steps_per_loop)

  train_iterator_step = 0
  current_step = 0

  logging.info('starting main training loop')
  train_iterator = iter(train_ds_dist)
  while current_step < TOTAL_STEPS.value:
    logging.info('current step %s, global step %s', current_step,
                 optimizer.iterations.numpy())
    for _ in range(num_train_loops_per_eval):
      # check there's enough examples left in the iterator and remake if needed
      # (`>`: exactly num_train_steps_per_loop batches left is still enough).
      # TODO(zeef): rewrite dataset creation to repeat forever?
      if train_iterator_step + num_train_steps_per_loop > num_train_steps:
        train_iterator = iter(train_ds_dist)
        train_iterator_step = 0

      train_step_loop(train_iterator, num_train_steps_per_loop)

      if write_metrics:
        metrics_update_loop(train_metrics, optimizer.iterations.numpy())

      # keep track of how far through train_iterator we are
      train_iterator_step += num_train_steps_per_loop

    # now run through the entire eval dataset and update eval metrics
    logging.info('Updating eval metrics for step %s',
                 optimizer.iterations.numpy())
    eval_iterator = iter(eval_ds_dist)
    eval_step_loop(eval_iterator, num_eval_steps)
    if write_metrics:
      metrics_update_loop(eval_metrics, optimizer.iterations.numpy())

    # finally update current_step for the while loop to check
    current_step = optimizer.iterations.numpy()
281
if __name__ == '__main__':
  # TF2 behaviour is switched on first, before control passes to absl's
  # app.run, which then invokes main().
  tf.compat.v1.enable_v2_behavior()
  app.run(main=main)