# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training loop for object discovery with Slot Attention."""
import datetime
import time

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

import slot_attention.data as data_utils
import slot_attention.model as model_utils
import slot_attention.utils as utils


FLAGS = flags.FLAGS
flags.DEFINE_string("model_dir", "/tmp/object_discovery/",
                    "Where to save the checkpoints.")
flags.DEFINE_integer("seed", 0, "Random seed.")
flags.DEFINE_integer("batch_size", 64, "Batch size for the model.")
flags.DEFINE_integer("num_slots", 7, "Number of slots in Slot Attention.")
flags.DEFINE_integer("num_iterations", 3, "Number of attention iterations.")
flags.DEFINE_float("learning_rate", 0.0004, "Learning rate.")
flags.DEFINE_integer("num_train_steps", 500000, "Number of training steps.")
flags.DEFINE_integer("warmup_steps", 10000,
                     "Number of warmup steps for the learning rate.")
flags.DEFINE_float("decay_rate", 0.5, "Rate for the learning rate decay.")
flags.DEFINE_integer("decay_steps", 100000,
                     "Number of steps for the learning rate decay.")


# We use `tf.function` compilation to speed up execution. For debugging,
# consider commenting out the `@tf.function` decorator.
@tf.function
def train_step(batch, model, optimizer):
  """Perform a single training step."""

  # Get the model's predictions and compute the reconstruction loss.
  with tf.GradientTape() as tape:
    preds = model(batch["image"], training=True)
    recon_combined, recons, masks, slots = preds
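    # `recon_combined` is the full image reconstruction; `recons` and `masks`
    # are the per-slot reconstructions and alpha masks, and `slots` are the
    # slot representations. Only the combined reconstruction enters the loss.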
    loss_value = utils.l2_loss(batch["image"], recon_combined)
    del recons, masks, slots  # Unused.

  # Get and apply gradients.
  gradients = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))

  return loss_value


def main(argv):
  del argv
  # Hyperparameters of the model.
  batch_size = FLAGS.batch_size
  num_slots = FLAGS.num_slots
  num_iterations = FLAGS.num_iterations
  base_learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  warmup_steps = FLAGS.warmup_steps
  decay_rate = FLAGS.decay_rate
  decay_steps = FLAGS.decay_steps
  tf.random.set_seed(FLAGS.seed)
  resolution = (128, 128)

  # Build the dataset iterator, optimizer and model.
  data_iterator = data_utils.build_clevr_iterator(
      batch_size, split="train", resolution=resolution, shuffle=True,
      max_n_objects=6, get_properties=False, apply_crop=True)
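  # The iterator yields dicts with (at least) an "image" entry, which is all
  # that the reconstruction objective in `train_step` consumes.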

  optimizer = tf.keras.optimizers.Adam(base_learning_rate, epsilon=1e-08)

  model = model_utils.build_model(resolution, batch_size, num_slots,
                                  num_iterations, model_type="object_discovery")

  # Prepare checkpoint manager.
  global_step = tf.Variable(
      0, trainable=False, name="global_step", dtype=tf.int64)
  ckpt = tf.train.Checkpoint(
      network=model, optimizer=optimizer, global_step=global_step)
  ckpt_manager = tf.train.CheckpointManager(
      checkpoint=ckpt, directory=FLAGS.model_dir, max_to_keep=5)
  ckpt.restore(ckpt_manager.latest_checkpoint)
  if ckpt_manager.latest_checkpoint:
    logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
  else:
    logging.info("Initializing from scratch.")

  start = time.time()
  for _ in range(num_train_steps):
    batch = next(data_iterator)

    # Learning rate warm-up.
    if global_step < warmup_steps:
      learning_rate = base_learning_rate * tf.cast(
          global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)
    else:
      learning_rate = base_learning_rate
    learning_rate = learning_rate * (decay_rate ** (
        tf.cast(global_step, tf.float32) / tf.cast(decay_steps, tf.float32)))
    optimizer.lr = learning_rate.numpy()
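    # Effective schedule: lr(step) = base_lr * min(step / warmup_steps, 1)
    #                                * decay_rate ** (step / decay_steps).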

    loss_value = train_step(batch, model, optimizer)

    # Update the global step. We update it before logging the loss and saving
    # the model so that the last checkpoint is saved at the last iteration.
    global_step.assign_add(1)

    # Log the training loss.
    if not global_step % 100:
      logging.info("Step: %s, Loss: %.6f, Time: %s",
                   global_step.numpy(), loss_value,
                   datetime.timedelta(seconds=time.time() - start))

    # We save the checkpoints every 1000 iterations.
    if not global_step % 1000:
      # Save the checkpoint of the model.
      saved_ckpt = ckpt_manager.save()
      logging.info("Saved checkpoint: %s", saved_ckpt)


if __name__ == "__main__":
  app.run(main)