# google-research
# 51 lines · 1.4 KB
# coding=utf-8
# Copyright 2021 DeepMind Technologies Limited and the Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

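"""Main file for running the VDVAE Flax example.

This file is intentionally kept short: it defines the command-line flags and
delegates training and evaluation to `vdvae_flax.experiment.Experiment`.
"""
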
from absl import app
from absl import flags

import jax
from ml_collections import config_flags
import tensorflow as tf

from vdvae_flax import experiment

FLAGS = flags.FLAGS

# Command-line flags. Both are marked required below, so absl rejects a
# command line that omits either of them.
config_flags.DEFINE_config_file(
    "config",
    None,
    "Training configuration.",
    # Lock the loaded config so that no new fields can be added to it
    # (ml_collections ConfigDict locking).
    lock_config=True,
)
flags.DEFINE_string(
    name="workdir",
    default=None,
    help="Work unit directory.",
)
flags.mark_flags_as_required(["config", "workdir"])


def main(argv):
  """Trains and evaluates the VDVAE experiment described by `--config`.

  Args:
    argv: Positional command-line arguments remaining after absl flag
      parsing. Unused; all configuration comes from `--config` and
      `--workdir`.
  """
  del argv  # Unused.

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX. `tf.config.set_visible_devices` is the stable
  # endpoint; the `tf.config.experimental` alias is deprecated.
  tf.config.set_visible_devices([], "GPU")

  exp = experiment.Experiment("train", FLAGS.config)
  exp.train_and_evaluate(FLAGS.workdir)


if __name__ == "__main__":
  # Hook JAX's configuration into absl flag handling. This is done before
  # app.run() so that JAX's options are in place when absl parses argv.
  jax.config.config_with_absl()
  app.run(main)