# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Train and eval yoto models.

The script works in one of two modes, passed as --schedule. In 'train' mode it
instantiates the models and calls TPUEstimator.train; in 'eval' mode it waits
for checkpoints to be produced, exports them as TF-Hub modules, computes
accuracies and logs them.
"""
import enum
import functools
import os.path

from absl import app
from absl import flags
from absl import logging

import gin
import gin.tf.external_configurables
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

# We need to import them so that gin can discover them.
from yoto import architectures  # pylint: disable=unused-import
from yoto import optimizers
from yoto import problems  # pylint: disable=unused-import
from yoto.optimizers import distributions
from yoto.utils import data
from yoto.utils import preprocessing
from tensorflow.python.training import checkpoint_utils  # pylint: disable=g-direct-tensorflow-import


FLAGS = flags.FLAGS

flags.DEFINE_string("model_dir", None, "Where to store files.")
flags.DEFINE_string(
    "schedule", "train", "Schedule to run. Options: 'train' and 'eval'.")
flags.DEFINE_multi_string(
    "gin_config", [],
    "List of paths to the config files.")
flags.DEFINE_multi_string(
    "gin_bindings", [],
    "Newline separated list of Gin parameter bindings.")
flags.DEFINE_bool("use_tpu", False, "Whether running on TPU or not.")
flags.DEFINE_integer("seed", 0, "The random seed.")
flags.DEFINE_integer("validation_percent", 20,
                     "Percent of training data to be used for validation.")
flags.DEFINE_string(
    "master", None, "Name of the TensorFlow master to use. Defaults to GPU.")


@gin.constants_from_enum
class Task(enum.Enum):
  VARIATIONAL_AUTOENCODER = 0


@gin.configurable("TrainingParams")
class TrainingParams(object):
  """Parameters for network training.

  Includes learning rate with schedule and weight decay.
  """

  def __init__(self, initial_lr=gin.REQUIRED, lr_decay_factor=gin.REQUIRED,
               lr_decay_steps_str=gin.REQUIRED, weight_decay=gin.REQUIRED):
    self.initial_lr = initial_lr
    self.lr_decay_factor = lr_decay_factor
    self.lr_decay_steps_str = lr_decay_steps_str
    self.weight_decay = weight_decay
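
# In a gin config, these parameters are typically bound as, for example
# (the values below are purely illustrative):
#
#   TrainingParams.initial_lr = 0.1
#   TrainingParams.lr_decay_factor = 0.1
#   TrainingParams.lr_decay_steps_str = "30000,40000"
#   TrainingParams.weight_decay = 1e-4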


def iterate_checkpoints_until_file_exists(checkpoints_dir,
                                          path_to_file,
                                          timeout_in_mins=60):
  """Yields new checkpoints until path_to_file exists or the timeout expires."""
  remaining_mins = timeout_in_mins
  last_checkpoint = None
  while remaining_mins > 0:
    checkpoint = checkpoint_utils.wait_for_new_checkpoint(
        checkpoints_dir, last_checkpoint=last_checkpoint, timeout=60)  # 1 min.
    if checkpoint:
      last_checkpoint = checkpoint
      remaining_mins = timeout_in_mins  # Reset the remaining time.
      yield checkpoint
    elif tf.gfile.Exists(path_to_file):
      logging.info("Found %s, exiting", path_to_file)
      return
    else:
      remaining_mins -= 1


def get_decay_op(weight_decay, learning_rate, opt_step, vars_to_decay=None):
  """Generates the weight decay op for the given variables."""
  with tf.control_dependencies([opt_step]):
    if vars_to_decay is None:
      vars_to_decay = tf.trainable_variables()
    decay_ops = []
    for v in vars_to_decay:
      decayed_val = v * (1. - learning_rate * weight_decay)
      decay_ops.append(v.assign(decayed_val))
    decay_op = tf.group(decay_ops)
  return decay_op


def get_learning_rate(training_params):
  """Produces a piecewise learning rate tensor that decays exponentially.

  Args:
    training_params: TrainingParams instance.
      training_params.initial_lr: initial learning rate.
      training_params.lr_decay_steps_str: a comma-separated string of step
        numbers at which the learning rate should be decayed.
      training_params.lr_decay_factor: learning rate decay factor.

  Returns:
    lr: Learning rate tensor that decays exponentially according to the given
      parameters.
  """

  initial_lr = training_params.initial_lr
  lr_decay_factor = training_params.lr_decay_factor
  lr_decay_steps_str = training_params.lr_decay_steps_str
  if lr_decay_steps_str:
    global_step = tf.train.get_or_create_global_step()
    lr_decay_steps = [int(s) for s in lr_decay_steps_str.split(",")]

    lr = tf.train.piecewise_constant(
        global_step,
        lr_decay_steps,
        [initial_lr * (lr_decay_factor ** i)
         for i in range(len(lr_decay_steps) + 1)]
    )
  else:
    lr = initial_lr
  return lr


def get_optimizer(optimizer_class, learning_rate, use_tpu):
  optimizer = optimizer_class(learning_rate=learning_rate)
  if use_tpu:
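    # CrossShardOptimizer aggregates the gradients across all TPU cores before
    # applying them, so the cores perform one synchronous update together.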
    optimizer = tf.tpu.CrossShardOptimizer(optimizer)
  return optimizer


def construct_model_fn(problem, optimizer_class, base_optimizer_class,
                       eval_weights=None, eval_num_samples=10,
                       training_params_class=None,
                       training_params_conditioning_class=None,
                       base_optimizer_conditioning_class=None):
  """Constructs a model_fn for the given problem and optimizer.

  Args:
    problem: An instance of the Problem class, defining the learning problem.
    optimizer_class: MultiLossOptimizer class (gin-injected), used to generate
      an instance used to optimize the problem. This optimizer handles
      problems with parametrized loss functions.
    base_optimizer_class: A tf.Optimizer class (gin-injected), used to create
      an optimizer instance which is actually used to minimize the objective.
    eval_weights: A specification of the evaluation weights, either as a random
      distribution or as a list of weight dictionaries (see
      distributions.get_samples_as_dicts for details).
    eval_num_samples: Int. If eval_weights are given as a distribution, this
      defines how many vectors to sample from it for evaluation.
    training_params_class: TrainingParams class (gin-injected). Stores training
      parameters (learning rate parameters as in get_learning_rate(...) and
      weight_decay).
    training_params_conditioning_class: TrainingParams class (gin-injected).
      Same as training_params_class but, if provided, used for the
      conditioning part of the network.
    base_optimizer_conditioning_class: A tf.Optimizer class (gin-injected).
      If provided, used to create an optimizer instance that minimizes the
      objective for the conditioning variables.

  Returns:
    model_fn: A function that creates a model, to be used by TPUEstimator.
  """
  def model_fn(features, mode, params):
    """Returns a TPU estimator spec for the task at hand."""
    problem.initialize_model()
    optimizer = optimizer_class(problem, batch_size=params["batch_size"])
    training_params = training_params_class()
    learning_rate_normal = get_learning_rate(training_params)
    separate_conditioning_optimizer = (
        training_params_conditioning_class and base_optimizer_conditioning_class
        and isinstance(optimizer,
                       optimizers.MultiLossOptimizerWithConditioning))
    if not separate_conditioning_optimizer and (
        training_params_conditioning_class
        or base_optimizer_conditioning_class):
      raise ValueError("training_params_conditioning_class and "
                       "base_optimizer_conditioning_class should be provided "
                       "together and only when the optimizer is "
                       "MultiLossOptimizerWithConditioning.")

    tf.logging.info("separate_conditioning_optimizer: %s",
                    separate_conditioning_optimizer)

    if separate_conditioning_optimizer:
      training_params_conditioning = training_params_conditioning_class()
      learning_rate_conditioning = get_learning_rate(
          training_params_conditioning)

    if mode == tf_estimator.ModeKeys.TRAIN:

      base_optimizer = get_optimizer(base_optimizer_class, learning_rate_normal,
                                     params["use_tpu"])
      if separate_conditioning_optimizer:
        base_optimizer_conditioning = get_optimizer(
            base_optimizer_conditioning_class, learning_rate_conditioning,
            params["use_tpu"])
        loss, opt_step = optimizer.compute_train_loss_and_update_op(
            features, base_optimizer, base_optimizer_conditioning)
        all_vars_str = "\n".join([str(v) for v in optimizer.all_vars])
        normal_vars_str = "\n".join([str(v) for v in optimizer.normal_vars])
        conditioning_vars_str = "\n".join([str(v) for
                                           v in optimizer.conditioning_vars])
        tf.logging.info("\n\nall_vars\n %s", all_vars_str)
        tf.logging.info("\n\nnormal_vars\n %s", normal_vars_str)
        tf.logging.info("\n\nconditioning_vars\n %s", conditioning_vars_str)
      else:
        loss, opt_step = optimizer.compute_train_loss_and_update_op(
            features, base_optimizer)

      # Weight decay op.
      decay_op = get_decay_op(training_params.weight_decay,
                              learning_rate_normal, opt_step,
                              vars_to_decay=optimizer.normal_vars)
      if separate_conditioning_optimizer:
        decay_op_conditioning = get_decay_op(
            training_params_conditioning.weight_decay,
            learning_rate_conditioning,
            opt_step, vars_to_decay=optimizer.conditioning_vars)
        decay_op = tf.group([decay_op, decay_op_conditioning])
      # Batch norm update ops.
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      train_op = tf.group([opt_step, decay_op] + update_ops)
      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode, loss=loss, train_op=train_op)
    elif mode == tf_estimator.ModeKeys.EVAL:
      def unstack_metrics(**metrics):
        """Unstack separate metrics from one big aggregate tensor.

        This is needed because otherwise evaluation on TPU with many metrics
        gets horribly slow. Concatenating all metrics into one tensor makes
        things much better.

        Args:
          **metrics: Dict[ Str: tf.Tensor ]. Dictionary with one element, for
            which the key is the concatenation of all metric names separated
            by "!" and the value is all metric values stacked along axis 1.

        Returns:
          metrics_dict: Dict[ Str: tf.Tensor ]. Dictionary mapping metric names
            to streaming means of their per-sample values.
        """
        if len(metrics) != 1:
          raise ValueError("Stacked metrics dict should have one element, got "
                           "{}".format(len(metrics)))
        names_stacked = list(metrics.keys())[0]
        values_stacked = metrics[names_stacked]
        names = names_stacked.split("!")
        values = tf.unstack(values_stacked, axis=1)
        return {name: tf.metrics.mean(value) for name, value in
                zip(names, values)}

      loss = optimizer.compute_eval_loss(features)

      if isinstance(optimizer, optimizers.MultiLossOptimizerWithConditioning):
        sampled_weights = distributions.get_samples_as_dicts(
            eval_weights, num_samples=eval_num_samples,
            names=problem.losses_keys, seed=17)
        all_metrics = {}
        for idx, weights in enumerate(sampled_weights):
          with tf.variable_scope("", reuse=tf.AUTO_REUSE):
            losses_id, metrics_id = \
                optimizer.compute_eval_losses_and_metrics_for_weights(features,
                                                                      weights)
          all_metrics.update({"{}/{}".format(key, idx): value
                              for key, value in losses_id.items()})
          all_metrics.update({"{}/{}".format(key, idx): value
                              for key, value in metrics_id.items()})
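          # The full loss is the weighted sum of the individual losses under
          # the sampled weight vector.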
          full_loss = 0.
          for loss_name in losses_id.keys():
            full_loss += weights[loss_name] * losses_id[loss_name]
          all_metrics.update({"full_loss/{}".format(idx): full_loss})
      else:
        with tf.variable_scope("", reuse=tf.AUTO_REUSE):
          losses, metrics = problem.losses_and_metrics(features, training=False)
        all_metrics = losses
        all_metrics.update(metrics)
      metrics_shape_out = all_metrics[list(all_metrics.keys())[0]].get_shape()
      # Need this broadcasting because on TPU all output tensors should have
      # the same shape
      all_metrics.update(
          {"learning_rate_normal": tf.broadcast_to(
              learning_rate_normal, metrics_shape_out)})
      if separate_conditioning_optimizer:
        all_metrics.update(
            {"learning_rate_conditioning": tf.broadcast_to(
                learning_rate_conditioning, metrics_shape_out)})
      # Stacking all metrics for efficiency (otherwise eval is horribly slow)
      sorted_keys = sorted(all_metrics.keys())
      sorted_values = [all_metrics[key] for key in sorted_keys]
      metrics_stacked = {"!".join(sorted_keys): tf.stack(sorted_values, axis=1)}
      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(unstack_metrics, metrics_stacked))
    else:
      raise ValueError("Unknown mode: {}".format(mode))

  return model_fn


@gin.configurable("experiment")
def run(model_dir,
        schedule,
        problem_class=gin.REQUIRED,
        optimizer_class=gin.REQUIRED,
        dataset_name=gin.REQUIRED,
        batch_size=gin.REQUIRED,
        eval_batch_size=64,
        train_steps=gin.REQUIRED,
        eval_steps=gin.REQUIRED,
        base_optimizer_class=gin.REQUIRED,
        base_optimizer_conditioning_class=None,
        iterations_per_loop=gin.REQUIRED,
        eval_weights=None,
        training_params_class=gin.REQUIRED,
        training_params_conditioning_class=None,
        preprocess="",
        preprocess_eval="",
        save_checkpoints_steps=None,
        keep_checkpoint_max=0,
        eval_on_test=False):
  """Main training function. Most of the parameters come from Gin."""
  assert schedule in ("train", "eval")

  if save_checkpoints_steps:
    kwargs = {"save_checkpoints_steps": save_checkpoints_steps}
  else:
    kwargs = {"save_checkpoints_secs": 60*10}  # Every 10 minutes.

  run_config = tf_estimator.tpu.RunConfig(
      keep_checkpoint_max=keep_checkpoint_max,
      master=FLAGS.master,
      evaluation_master=FLAGS.master,
      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=iterations_per_loop),
      **kwargs)
  # We use one estimator (potentially on TPU) for training and evaluation.
  problem = problem_class()
  model_fn = construct_model_fn(
      problem, optimizer_class, base_optimizer_class,
      eval_weights=eval_weights,
      base_optimizer_conditioning_class=base_optimizer_conditioning_class,
      training_params_class=training_params_class,
      training_params_conditioning_class=training_params_conditioning_class)
  tpu_estimator = tf_estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      model_dir=model_dir,
      train_batch_size=batch_size,
      eval_batch_size=eval_batch_size,
      config=run_config)


  def input_fn_train(params):
    preprocess_fn = preprocessing.get_preprocess_fn(preprocess)
    return data.get_dataset(dataset_name, data.DatasetSplit.TRAIN,
                            FLAGS.validation_percent, params["batch_size"],
                            preprocess_fn)

  def input_fn_eval(params, split):
    preprocess_fn = preprocessing.get_preprocess_fn(preprocess_eval)
    return data.get_dataset(dataset_name, split, FLAGS.validation_percent,
                            params["batch_size"], preprocess_fn).repeat()

  path_to_finished_file = os.path.join(model_dir, "FINISHED")
  if schedule == "train":
    gin_hook = gin.tf.GinConfigSaverHook(model_dir, summarize_config=True)
    tpu_estimator.train(input_fn=input_fn_train,
                        hooks=[gin_hook],
                        max_steps=train_steps)
    with tf.gfile.GFile(path_to_finished_file, "w") as finished_file:
      finished_file.write("1")
  else:
    for checkpoint in iterate_checkpoints_until_file_exists(
        model_dir, path_to_finished_file):
      if eval_on_test:
        train_split = data.DatasetSplit.TRAIN_FULL
        test_split = data.DatasetSplit.TEST
        test_summary_name = "test"
      else:
        train_split = data.DatasetSplit.TRAIN
        test_split = data.DatasetSplit.VALID
        test_summary_name = "valid"

      eval_train = tpu_estimator.evaluate(
          input_fn=functools.partial(input_fn_eval, split=train_split),
          checkpoint_path=checkpoint,
          steps=eval_steps,
          name="train")
      eval_test = tpu_estimator.evaluate(
          input_fn=functools.partial(input_fn_eval, split=test_split),
          checkpoint_path=checkpoint,
          steps=eval_steps,
          name="test")

      current_step = eval_train["global_step"]


      hub_modules_dir = os.path.join(model_dir, "hub_modules")
      if not tf.gfile.Exists(hub_modules_dir):
        tf.gfile.MkDir(hub_modules_dir)
      else:
        if not tf.gfile.IsDirectory(hub_modules_dir):
          raise ValueError("{0} exists and is not a directory".format(
              hub_modules_dir))

      hub_module_path = os.path.join(hub_modules_dir,
                                     "step-{:0>9}".format(current_step))
      if not tf.gfile.Exists(hub_module_path):
        problem.module_spec.export(hub_module_path,
                                   checkpoint_path=checkpoint)
      else:
        logging.info("Not saving the hub module, since the path"
                     " %s already exists", hub_module_path)


def main(argv):
  if len(argv) > 1:
    raise ValueError("Too many command-line arguments.")
  logging.info("Gin config: %s\nGin bindings: %s",
               FLAGS.gin_config, FLAGS.gin_bindings)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  run(model_dir=FLAGS.model_dir, schedule=FLAGS.schedule)


if __name__ == "__main__":
  tf.disable_v2_behavior()
  flags.mark_flag_as_required("model_dir")
  app.run(main)
