google-research

Форк
0
63 строки · 2.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Main entry point for running the training pipeline."""
17

18
from typing import Sequence
19
from absl import app
20
from absl import flags
21
from absl import logging
22
from clu import platform
23
import jax
24
from ml_collections import config_flags
25
import tensorflow as tf
26

27
from action_angle_networks import train
28

29

30
_WORKDIR = flags.DEFINE_string('workdir', None,
31
                               'Directory to store model data.')
32
_CONFIG = config_flags.DEFINE_config_file(
33
    'config',
34
    None,
35
    'File path to the training hyperparameter configuration.',
36
    lock_config=True)
37

38

39
def main(argv):
40
  if len(argv) > 1:
41
    raise app.UsageError('Too many command-line arguments.')
42

43
  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
44
  # it unavailable to JAX.
45
  tf.config.experimental.set_visible_devices([], 'GPU')
46

47
  # This example only supports single-host training on a single device.
48
  logging.info('JAX host: %d / %d', jax.process_index(), jax.process_count())
49
  logging.info('JAX local devices: %r', jax.local_devices())
50

51
  # Add a note so that we can tell which task is which JAX host.
52
  # (Depending on the platform task 0 is not guaranteed to be host 0)
53
  platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, '
54
                                       f'process_count: {jax.process_count()}')
55
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
56
                                       _WORKDIR.value, 'workdir')
57

58
  train.train_and_evaluate(_CONFIG.value, _WORKDIR.value)
59

60

61
if __name__ == '__main__':
62
  flags.mark_flags_as_required(['config', 'workdir'])
63
  app.run(main)
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.