google-research
63 строки · 2.2 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Main entry point for running the training pipeline."""
17
18from typing import Sequence
19from absl import app
20from absl import flags
21from absl import logging
22from clu import platform
23import jax
24from ml_collections import config_flags
25import tensorflow as tf
26
27from action_angle_networks import train
28
29
# Command-line flags. Both default to None and are marked as required in the
# `__main__` guard at the bottom of this file.

# Directory that is handed to the training routine and registered as the
# work unit's 'workdir' artifact.
_WORKDIR = flags.DEFINE_string(
    name='workdir',
    default=None,
    help='Directory to store model data.',
)

# Path to the hyperparameter configuration file; the loaded config is locked.
_CONFIG = config_flags.DEFINE_config_file(
    name='config',
    default=None,
    help_string='File path to the training hyperparameter configuration.',
    lock_config=True,
)
37
38
def main(argv):
  """Sets up the runtime environment and launches training and evaluation.

  Args:
    argv: Command-line arguments passed in by `app.run`. Anything beyond the
      first element is rejected.

  Raises:
    app.UsageError: If `argv` contains more than one element.
  """
  # Only the leading element is accepted; any extra positional argument is
  # treated as a usage error.
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')

  # Make TensorFlow see no GPUs, so that it cannot reserve GPU memory that
  # JAX would then be unable to use.
  tf.config.experimental.set_visible_devices([], 'GPU')

  process_index = jax.process_index()
  process_count = jax.process_count()

  # This example only supports single-host training on a single device.
  logging.info('JAX host: %d / %d', process_index, process_count)
  logging.info('JAX local devices: %r', jax.local_devices())

  # Record which JAX host this task is, since (depending on the platform)
  # task 0 is not guaranteed to be host 0.
  task_status = (
      f'process_index: {process_index}, '
      f'process_count: {process_count}'
  )
  platform.work_unit().set_task_status(task_status)

  # Register the working directory as an artifact of this work unit.
  platform.work_unit().create_artifact(
      platform.ArtifactType.DIRECTORY, _WORKDIR.value, 'workdir'
  )

  train.train_and_evaluate(_CONFIG.value, _WORKDIR.value)
59
60
if __name__ == '__main__':
  # Both flags default to None, so require them explicitly before handing
  # control to absl, which then calls `main`.
  for required_flag in ('config', 'workdir'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)
64