# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training loop for object discovery with Slot Attention."""
import datetime
import time

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

import slot_attention.data as data_utils
import slot_attention.model as model_utils
import slot_attention.utils as utils

FLAGS = flags.FLAGS
flags.DEFINE_string("model_dir", "/tmp/object_discovery/",
                    "Where to save the checkpoints.")
flags.DEFINE_integer("seed", 0, "Random seed.")
flags.DEFINE_integer("batch_size", 64, "Batch size for the model.")
flags.DEFINE_integer("num_slots", 7, "Number of slots in Slot Attention.")
flags.DEFINE_integer("num_iterations", 3, "Number of attention iterations.")
flags.DEFINE_float("learning_rate", 0.0004, "Learning rate.")
flags.DEFINE_integer("num_train_steps", 500000, "Number of training steps.")
flags.DEFINE_integer("warmup_steps", 10000,
                     "Number of warmup steps for the learning rate.")
flags.DEFINE_float("decay_rate", 0.5, "Rate for the learning rate decay.")
flags.DEFINE_integer("decay_steps", 100000,
                     "Number of steps for the learning rate decay.")


# We use `tf.function` compilation to speed up execution. For debugging,
# consider commenting out the `@tf.function` decorator.
@tf.function
def train_step(batch, model, optimizer):
  """Perform a single training step."""

  # Get the prediction of the models and compute the loss.
  with tf.GradientTape() as tape:
    preds = model(batch["image"], training=True)
    recon_combined, recons, masks, slots = preds
    loss_value = utils.l2_loss(batch["image"], recon_combined)
    del recons, masks, slots  # Unused.

  # Get and apply gradients.
  gradients = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))

  return loss_value


def main(argv):
  del argv
  # Hyperparameters of the model.
  batch_size = FLAGS.batch_size
  num_slots = FLAGS.num_slots
  num_iterations = FLAGS.num_iterations
  base_learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  warmup_steps = FLAGS.warmup_steps
  decay_rate = FLAGS.decay_rate
  decay_steps = FLAGS.decay_steps
  tf.random.set_seed(FLAGS.seed)
  resolution = (128, 128)

  # Build dataset iterators, optimizers and model.
  data_iterator = data_utils.build_clevr_iterator(
      batch_size, split="train", resolution=resolution, shuffle=True,
      max_n_objects=6, get_properties=False, apply_crop=True)

  optimizer = tf.keras.optimizers.Adam(base_learning_rate, epsilon=1e-08)

  model = model_utils.build_model(resolution, batch_size, num_slots,
                                  num_iterations,
                                  model_type="object_discovery")

  # Prepare checkpoint manager.
  global_step = tf.Variable(
      0, trainable=False, name="global_step", dtype=tf.int64)
  ckpt = tf.train.Checkpoint(
      network=model, optimizer=optimizer, global_step=global_step)
  ckpt_manager = tf.train.CheckpointManager(
      checkpoint=ckpt, directory=FLAGS.model_dir, max_to_keep=5)
  ckpt.restore(ckpt_manager.latest_checkpoint)
  if ckpt_manager.latest_checkpoint:
    logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
  else:
    logging.info("Initializing from scratch.")
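
  # Training loop: each step draws a batch from the CLEVR iterator, updates
  # the learning rate (linear warm-up followed by exponential decay), and
  # applies one gradient update via `train_step`.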
  start = time.time()
  for _ in range(num_train_steps):
    batch = next(data_iterator)

    # Learning rate warm-up.
    if global_step < warmup_steps:
      learning_rate = base_learning_rate * tf.cast(
          global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)
    else:
      learning_rate = base_learning_rate
    learning_rate = learning_rate * (decay_rate ** (
        tf.cast(global_step, tf.float32) / tf.cast(decay_steps, tf.float32)))
    optimizer.lr = learning_rate.numpy()
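    # Note: `optimizer.lr` is the Keras 2.x attribute; on newer Keras releases
    # this may need to be `optimizer.learning_rate` instead.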

    loss_value = train_step(batch, model, optimizer)

    # Update the global step. We update it before logging the loss and saving
    # the model so that the last checkpoint is saved at the last iteration.
    global_step.assign_add(1)

    # Log the training loss.
    if not global_step % 100:
      logging.info("Step: %s, Loss: %.6f, Time: %s",
                   global_step.numpy(), loss_value,
                   datetime.timedelta(seconds=time.time() - start))

    # We save the checkpoints every 1000 iterations.
    if not global_step % 1000:
      # Save the checkpoint of the model.
      saved_ckpt = ckpt_manager.save()
      logging.info("Saved checkpoint: %s", saved_ckpt)


if __name__ == "__main__":
  app.run(main)