# Source: google-research repository.
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trainer binary."""

from absl import app
from absl import flags
import gin
import jax
import tensorflow.compat.v2 as tf

from moe_mtl.main import trainer_utils

# Single gin config file. main() wraps it in a one-element list before parsing.
# Marked required in the __main__ guard.
_CONFIG_PATH = flags.DEFINE_string(
    name='config_path',
    default=None,
    help='paths to gin config.',
)

# Extra gin bindings applied on top of the config file. Each --gin_bindings
# occurrence contributes one entry; the value is None when the flag is unset.
_CONFIG_PARAM = flags.DEFINE_multi_string(
    name='gin_bindings',
    default=None,
    help='Newline separated list of Gin parameter bindings.',
)

# Directory handed to the trainer_utils training entry point.
# Marked required in the __main__ guard.
_OUTPUT_DIR = flags.DEFINE_string(
    name='output_dir',
    default=None,
    help='Path to model checkpoints/summaries.',
)

# Selects train_and_validate_vmoe_mtl (True) instead of train_vmoe_mtl (False).
# NOTE(review): the expansion of "AES" is not documented here — confirm.
_AES = flags.DEFINE_bool(
    name='aes',
    default=False,
    help='launch with AES',
)
def main(_):
  """Parses the gin configuration and runs the selected training entry point.

  Runs `trainer_utils.train_vmoe_mtl` by default, or
  `trainer_utils.train_and_validate_vmoe_mtl` when `--aes` is set. Both
  receive the value of `--output_dir`.

  Args:
    _: Positional command-line arguments forwarded by `app.run`; unused.
  """
  tf.enable_v2_behavior()
  # Hide all GPUs from TensorFlow so it does not allocate GPU memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  # --gin_bindings is a multi-string flag whose value is None when unset.
  gin_bindings = _CONFIG_PARAM.value or []
  # Enable relative include paths within t5x configs.
  gin.add_config_file_search_path('third_party/py/t5x/configs')
  gin.parse_config_files_and_bindings([_CONFIG_PATH.value], gin_bindings)

  output_dir = _OUTPUT_DIR.value
  if _AES.value:
    trainer_utils.train_and_validate_vmoe_mtl(output_dir)
  else:
    trainer_utils.train_vmoe_mtl(output_dir)

  # JAX dispatches work asynchronously. Block on a small computation so the
  # process does not exit while device work is still queued.
  # NOTE(review): relies on in-order execution on the default device.
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
if __name__ == '__main__':
  # Both flags default to None. Registering them as required makes absl
  # reject the invocation during flag parsing inside app.run if either
  # flag is missing.
  flags.mark_flags_as_required(
      flag_names=['output_dir', 'config_path'],
  )
  app.run(main)