# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Train and eval yoto models.
17
18The script works in one of two modes (passed as --schedule). It can either
19train, which instantiates the models and calls TPUEstimator.train, or run in
20eval mode, which waits for checkpoints to be produced, saves them to tfhub
21modules, computes accuracies and logs them.
22"""
import enum
import functools
import os.path

from absl import app
from absl import flags
from absl import logging

import gin
import gin.tf.external_configurables
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

# We need to import them so that gin can discover them.
from yoto import architectures  # pylint: disable=unused-import
from yoto import optimizers
from yoto import problems  # pylint: disable=unused-import
from yoto.optimizers import distributions
from yoto.utils import data
from yoto.utils import preprocessing
from tensorflow.python.training import checkpoint_utils  # pylint: disable=g-direct-tensorflow-import


FLAGS = flags.FLAGS

flags.DEFINE_string("model_dir", None, "Where to store files.")
flags.DEFINE_string(
    "schedule", "train", "Schedule to run. Options: 'train' and 'eval'.")
flags.DEFINE_multi_string(
    "gin_config", [],
    "List of paths to the config files.")
flags.DEFINE_multi_string(
    "gin_bindings", [],
    "Newline-separated list of Gin parameter bindings.")
flags.DEFINE_bool("use_tpu", False, "Whether running on TPU or not.")
flags.DEFINE_integer("seed", 0, "The random seed.")
flags.DEFINE_integer("validation_percent", 20,
                     "Percent of training data to be used for validation.")
flags.DEFINE_string(
    "master", None, "Name of the TensorFlow master to use. Defaults to GPU.")


@gin.constants_from_enum
class Task(enum.Enum):
  VARIATIONAL_AUTOENCODER = 0


@gin.configurable("TrainingParams")
class TrainingParams(object):
  """Parameters for network training.

  Includes learning rate with schedule and weight decay.
  """

  def __init__(self, initial_lr=gin.REQUIRED, lr_decay_factor=gin.REQUIRED,
               lr_decay_steps_str=gin.REQUIRED, weight_decay=gin.REQUIRED):
    self.initial_lr = initial_lr
    self.lr_decay_factor = lr_decay_factor
    self.lr_decay_steps_str = lr_decay_steps_str
    self.weight_decay = weight_decay
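
# Example Gin bindings for TrainingParams (the values are illustrative, not
# taken from any shipped config):
#
#   TrainingParams.initial_lr = 0.1
#   TrainingParams.lr_decay_factor = 0.1
#   TrainingParams.lr_decay_steps_str = "30000,40000"
#   TrainingParams.weight_decay = 1e-4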


def iterate_checkpoints_until_file_exists(checkpoints_dir,
                                          path_to_file,
                                          timeout_in_mins=60):
  """Yields checkpoints as long as the file does not exist."""
  remaining_mins = timeout_in_mins
  last_checkpoint = None
  while remaining_mins > 0:
    checkpoint = checkpoint_utils.wait_for_new_checkpoint(
        checkpoints_dir, last_checkpoint=last_checkpoint, timeout=60)  # 1 min.
    if checkpoint:
      last_checkpoint = checkpoint
      remaining_mins = timeout_in_mins  # Reset the remaining time.
      yield checkpoint
    elif tf.gfile.Exists(path_to_file):
      logging.info("Found %s, exiting", path_to_file)
      return
    else:
      remaining_mins -= 1


def get_decay_op(weight_decay, learning_rate, opt_step, vars_to_decay=None):
  """Generates the weight decay op for the given variables."""
  with tf.control_dependencies([opt_step]):
    if vars_to_decay is None:
      vars_to_decay = tf.trainable_variables()
    decay_ops = []
    for v in vars_to_decay:
      decayed_val = v * (1. - learning_rate * weight_decay)
      decay_ops.append(v.assign(decayed_val))
    decay_op = tf.group(decay_ops)
  return decay_op


def get_learning_rate(training_params):
  """Produces a piecewise-constant learning rate that decays exponentially.

  Args:
    training_params: TrainingParams instance, of which the following fields
      are used:
      training_params.initial_lr: initial learning rate.
      training_params.lr_decay_steps_str: a comma-separated list of step
        numbers at which learning rate decay should be performed.
      training_params.lr_decay_factor: learning rate decay factor.

  Returns:
    lr: Learning rate tensor that decays exponentially according to the given
      parameters.
  """

  initial_lr = training_params.initial_lr
  lr_decay_factor = training_params.lr_decay_factor
  lr_decay_steps_str = training_params.lr_decay_steps_str
  if lr_decay_steps_str:
    global_step = tf.train.get_or_create_global_step()
    lr_decay_steps = [int(s) for s in lr_decay_steps_str.split(",")]

    lr = tf.train.piecewise_constant(
        global_step,
        lr_decay_steps,
        [initial_lr * (lr_decay_factor ** i)
         for i in range(len(lr_decay_steps) + 1)])
  else:
    lr = initial_lr
  return lr


def get_optimizer(optimizer_class, learning_rate, use_tpu):
  optimizer = optimizer_class(learning_rate=learning_rate)
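  # On TPU, wrap the optimizer in CrossShardOptimizer, which aggregates
  # gradients across the TPU cores before applying the update.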
  if use_tpu:
    optimizer = tf.tpu.CrossShardOptimizer(optimizer)
  return optimizer


def construct_model_fn(problem, optimizer_class, base_optimizer_class,
                       eval_weights=None, eval_num_samples=10,
                       training_params_class=None,
                       training_params_conditioning_class=None,
                       base_optimizer_conditioning_class=None):
  """Constructs a model_fn for the given problem and optimizer.

  Args:
    problem: An instance of the Problem class, defining the learning problem.
    optimizer_class: MultiLossOptimizer class (gin-injected), used to generate
      an instance used to optimize the problem. This optimizer handles
      problems with parametrized loss functions.
    base_optimizer_class: A tf.Optimizer class (gin-injected), used to create
      an optimizer instance which is actually used to minimize the objective.
    eval_weights: A specification of the eval weights, either as a random
      distribution or as a list of weight dictionaries (see
      distributions.get_samples_as_dicts for details).
    eval_num_samples: Int. If eval_weights are given as a distribution, this
      defines how many vectors to sample from it for evaluation.
    training_params_class: TrainingParams class (gin-injected). Stores
      training parameters (learning rate parameters as in
      get_learning_rate(...) and weight_decay).
    training_params_conditioning_class: TrainingParams class (gin-injected).
      Same as training_params_class, but, if provided, used for the
      conditioning part of the network.
    base_optimizer_conditioning_class: A tf.Optimizer class (gin-injected).
      If provided, used to create an optimizer instance that minimizes the
      objective for the conditioning variables.

  Returns:
    model_fn: A function that creates a model, to be used by TPUEstimator.
  """
  def model_fn(features, mode, params):
    """Returns a TPU estimator spec for the task at hand."""
    problem.initialize_model()
    optimizer = optimizer_class(problem, batch_size=params["batch_size"])
    training_params = training_params_class()
    learning_rate_normal = get_learning_rate(training_params)
    separate_conditioning_optimizer = (
        training_params_conditioning_class and base_optimizer_conditioning_class
        and isinstance(optimizer,
                       optimizers.MultiLossOptimizerWithConditioning))
    if not separate_conditioning_optimizer and (
        training_params_conditioning_class
        or base_optimizer_conditioning_class):
      raise ValueError("training_params_conditioning_class and "
                       "base_optimizer_conditioning_class should be provided "
                       "together and only when the optimizer is "
                       "MultiLossOptimizerWithConditioning.")

    tf.logging.info("separate_conditioning_optimizer: %s",
                    separate_conditioning_optimizer)

    if separate_conditioning_optimizer:
      training_params_conditioning = training_params_conditioning_class()
      learning_rate_conditioning = get_learning_rate(
          training_params_conditioning)

    if mode == tf_estimator.ModeKeys.TRAIN:
      base_optimizer = get_optimizer(base_optimizer_class,
                                     learning_rate_normal, params["use_tpu"])
      if separate_conditioning_optimizer:
        base_optimizer_conditioning = get_optimizer(
            base_optimizer_conditioning_class, learning_rate_conditioning,
            params["use_tpu"])
        loss, opt_step = optimizer.compute_train_loss_and_update_op(
            features, base_optimizer, base_optimizer_conditioning)
        all_vars_str = "\n".join([str(v) for v in optimizer.all_vars])
        normal_vars_str = "\n".join([str(v) for v in optimizer.normal_vars])
        conditioning_vars_str = "\n".join(
            [str(v) for v in optimizer.conditioning_vars])
        tf.logging.info("\n\nall_vars\n %s", all_vars_str)
        tf.logging.info("\n\nnormal_vars\n %s", normal_vars_str)
        tf.logging.info("\n\nconditioning_vars\n %s", conditioning_vars_str)
      else:
        loss, opt_step = optimizer.compute_train_loss_and_update_op(
            features, base_optimizer)

      # Weight decay op.
      decay_op = get_decay_op(training_params.weight_decay,
                              learning_rate_normal, opt_step,
                              vars_to_decay=optimizer.normal_vars)
      if separate_conditioning_optimizer:
        decay_op_conditioning = get_decay_op(
            training_params_conditioning.weight_decay,
            learning_rate_conditioning,
            opt_step, vars_to_decay=optimizer.conditioning_vars)
        decay_op = tf.group([decay_op, decay_op_conditioning])
      # Batch norm update ops.
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      train_op = tf.group([opt_step, decay_op] + update_ops)
      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode, loss=loss, train_op=train_op)
    elif mode == tf_estimator.ModeKeys.EVAL:
      def unstack_metrics(**metrics):
        """Unstacks separate metrics from one big aggregate tensor.

        This is needed because otherwise evaluation on TPU with many metrics
        gets horribly slow. Concatenating all metrics into one tensor makes
        things much better.

        Args:
          **metrics: Dict[Str, tf.Tensor]. Dictionary with one element, whose
            key is the concatenation of all metric names separated by "!" and
            whose value is all metric values stacked along axis 1.

        Returns:
          metrics_dict: Dict[Str, tf.Tensor]. Dictionary mapping metric names
            to tensors with their per-sample values.
        """
        if len(metrics) != 1:
          raise ValueError("Stacked metrics dict should have one element, got "
                           "{}".format(len(metrics)))
        names_stacked = list(metrics.keys())[0]
        values_stacked = metrics[names_stacked]
        names = names_stacked.split("!")
        values = tf.unstack(values_stacked, axis=1)
        return {name: tf.metrics.mean(value)
                for name, value in zip(names, values)}

      loss = optimizer.compute_eval_loss(features)

      if isinstance(optimizer, optimizers.MultiLossOptimizerWithConditioning):
        sampled_weights = distributions.get_samples_as_dicts(
            eval_weights, num_samples=eval_num_samples,
            names=problem.losses_keys, seed=17)
        all_metrics = {}
        for idx, weights in enumerate(sampled_weights):
          with tf.variable_scope("", reuse=tf.AUTO_REUSE):
            losses_id, metrics_id = (
                optimizer.compute_eval_losses_and_metrics_for_weights(
                    features, weights))
          all_metrics.update({"{}/{}".format(key, idx): value
                              for key, value in losses_id.items()})
          all_metrics.update({"{}/{}".format(key, idx): value
                              for key, value in metrics_id.items()})
          full_loss = 0.
          for loss_name in losses_id.keys():
            full_loss += weights[loss_name] * losses_id[loss_name]
          all_metrics.update({"full_loss/{}".format(idx): full_loss})
      else:
        with tf.variable_scope("", reuse=tf.AUTO_REUSE):
          losses, metrics = problem.losses_and_metrics(features,
                                                       training=False)
        all_metrics = losses
        all_metrics.update(metrics)
      metrics_shape_out = all_metrics[list(all_metrics.keys())[0]].get_shape()
      # Need this broadcasting because on TPU all output tensors should have
      # the same shape.
      all_metrics.update(
          {"learning_rate_normal": tf.broadcast_to(
              learning_rate_normal, metrics_shape_out)})
      if separate_conditioning_optimizer:
        all_metrics.update(
            {"learning_rate_conditioning": tf.broadcast_to(
                learning_rate_conditioning, metrics_shape_out)})
      # Stack all metrics for efficiency (otherwise eval is horribly slow).
      sorted_keys = sorted(all_metrics.keys())
      sorted_values = [all_metrics[key] for key in sorted_keys]
      metrics_stacked = {"!".join(sorted_keys): tf.stack(sorted_values, axis=1)}
      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(unstack_metrics, metrics_stacked))
    else:
      raise ValueError("Unknown mode: {}".format(mode))

  return model_fn


@gin.configurable("experiment")
def run(model_dir,
        schedule,
        problem_class=gin.REQUIRED,
        optimizer_class=gin.REQUIRED,
        dataset_name=gin.REQUIRED,
        batch_size=gin.REQUIRED,
        eval_batch_size=64,
        train_steps=gin.REQUIRED,
        eval_steps=gin.REQUIRED,
        base_optimizer_class=gin.REQUIRED,
        base_optimizer_conditioning_class=None,
        iterations_per_loop=gin.REQUIRED,
        eval_weights=None,
        training_params_class=gin.REQUIRED,
        training_params_conditioning_class=None,
        preprocess="",
        preprocess_eval="",
        save_checkpoints_steps=None,
        keep_checkpoint_max=0,
        eval_on_test=False):
  """Main training function. Most of the parameters come from Gin."""
  assert schedule in ("train", "eval")

  if save_checkpoints_steps:
    kwargs = {"save_checkpoints_steps": save_checkpoints_steps}
  else:
    kwargs = {"save_checkpoints_secs": 60 * 10}  # Every 10 minutes.

  run_config = tf_estimator.tpu.RunConfig(
      keep_checkpoint_max=keep_checkpoint_max,
      master=FLAGS.master,
      evaluation_master=FLAGS.master,
      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=iterations_per_loop),
      **kwargs)
  # We use one estimator (potentially on TPU) for training and evaluation.
  problem = problem_class()
  model_fn = construct_model_fn(
      problem, optimizer_class, base_optimizer_class,
      eval_weights=eval_weights,
      base_optimizer_conditioning_class=base_optimizer_conditioning_class,
      training_params_class=training_params_class,
      training_params_conditioning_class=training_params_conditioning_class)
  tpu_estimator = tf_estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      model_dir=model_dir,
      train_batch_size=batch_size,
      eval_batch_size=eval_batch_size,
      config=run_config)

  def input_fn_train(params):
    preprocess_fn = preprocessing.get_preprocess_fn(preprocess)
    return data.get_dataset(dataset_name, data.DatasetSplit.TRAIN,
                            FLAGS.validation_percent, params["batch_size"],
                            preprocess_fn)

  def input_fn_eval(params, split):
    preprocess_fn = preprocessing.get_preprocess_fn(preprocess_eval)
    return data.get_dataset(dataset_name, split, FLAGS.validation_percent,
                            params["batch_size"], preprocess_fn).repeat()

  path_to_finished_file = os.path.join(model_dir, "FINISHED")
  if schedule == "train":
    gin_hook = gin.tf.GinConfigSaverHook(model_dir, summarize_config=True)
    tpu_estimator.train(input_fn=input_fn_train,
                        hooks=[gin_hook],
                        max_steps=train_steps)
    with tf.gfile.GFile(path_to_finished_file, "w") as finished_file:
      finished_file.write("1")
  else:
    for checkpoint in iterate_checkpoints_until_file_exists(
        model_dir, path_to_finished_file):
      if eval_on_test:
        train_split = data.DatasetSplit.TRAIN_FULL
        test_split = data.DatasetSplit.TEST
        test_summary_name = "test"
      else:
        train_split = data.DatasetSplit.TRAIN
        test_split = data.DatasetSplit.VALID
        test_summary_name = "valid"

      eval_train = tpu_estimator.evaluate(
          input_fn=functools.partial(input_fn_eval, split=train_split),
          checkpoint_path=checkpoint,
          steps=eval_steps,
          name="train")
      eval_test = tpu_estimator.evaluate(
          input_fn=functools.partial(input_fn_eval, split=test_split),
          checkpoint_path=checkpoint,
          steps=eval_steps,
          name=test_summary_name)
      del eval_test  # Run for its side effect of writing eval summaries.

      current_step = eval_train["global_step"]

      hub_modules_dir = os.path.join(model_dir, "hub_modules")
      if not tf.gfile.Exists(hub_modules_dir):
        tf.gfile.MkDir(hub_modules_dir)
      elif not tf.gfile.IsDirectory(hub_modules_dir):
        raise ValueError("{0} exists and is not a directory".format(
            hub_modules_dir))

      hub_module_path = os.path.join(hub_modules_dir,
                                     "step-{:0>9}".format(current_step))
      if not tf.gfile.Exists(hub_module_path):
        problem.module_spec.export(hub_module_path,
                                   checkpoint_path=checkpoint)
      else:
        logging.info("Not saving the hub module, since the path"
                     " %s already exists", hub_module_path)


def main(argv):
  if len(argv) > 1:
    raise ValueError("Too many command-line arguments.")
  logging.info("Gin config: %s\nGin bindings: %s",
               FLAGS.gin_config, FLAGS.gin_bindings)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  run(model_dir=FLAGS.model_dir, schedule=FLAGS.schedule)


if __name__ == "__main__":
  tf.disable_v2_behavior()
  flags.mark_flag_as_required("model_dir")
  app.run(main)