# google-research
# 66 lines · 2.1 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Main file for running the DP-GNN training pipeline."""
17
18from typing import Sequence
19from absl import app
20from absl import flags
21from absl import logging
22from clu import platform
23import jax
24from ml_collections import config_flags
25import tensorflow as tf
26
27from differentially_private_gnns import train
28
29
FLAGS = flags.FLAGS

# Directory where model data is written. It has no default value.
_WORKDIR = flags.DEFINE_string(
    name='workdir',
    default=None,
    help='Directory to store model data.',
)

# Path to the file holding the training hyperparameters. It has no default
# value either; the parsed configuration is read through `_CONFIG.value`.
_CONFIG = config_flags.DEFINE_config_file(
    name='config',
    default=None,
    help_string='File path to the training hyperparameter configuration.',
    lock_config=True,
)
41
42
def main(argv: Sequence[str]) -> None:
  """Prepares the runtime and launches DP-GNN training and evaluation.

  Args:
    argv: Positional command-line arguments as handed over by `app.run`. Only
      the first element is tolerated; any additional element is an error.

  Raises:
    app.UsageError: If `argv` contains more than one element.
  """
  unexpected_args = argv[1:]
  if unexpected_args:
    raise app.UsageError('Too many command-line arguments.')

  # TensorFlow must not see any GPU: it might otherwise reserve device memory
  # and make it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  process_index = jax.process_index()
  process_count = jax.process_count()
  logging.info('JAX process: %d / %d', process_index, process_count)
  logging.info('JAX local devices: %r', jax.local_devices())

  # Record which JAX host this task is, because (depending on the platform)
  # task 0 is not guaranteed to be host 0.
  work_unit = platform.work_unit()
  status = f'process_index: {process_index}, process_count: {process_count}'
  work_unit.set_task_status(status)

  # Register the working directory as an artifact of this work unit.
  workdir = _WORKDIR.value
  work_unit.create_artifact(
      platform.ArtifactType.DIRECTORY, workdir, 'workdir')

  train.train_and_evaluate(_CONFIG.value, workdir)
62
63
if __name__ == '__main__':
  # Both flags are declared with a default of None, so require them
  # explicitly before handing control to absl.
  for required_flag in ('config', 'workdir'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)
67