google-research
55 строк · 1.9 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16
17"""Trainer binary."""
18from absl import app
19from absl import flags
20import gin
21import jax
22import tensorflow.compat.v2 as tf
23
24from moe_mtl.main import trainer_utils
# Command-line flags for this binary. Each DEFINE_* call returns a flag holder
# whose parsed value is read through `.value` once `app.run` has parsed argv.
_CONFIG_PATH = flags.DEFINE_string(
    name='config_path',
    default=None,
    help='paths to gin config.',
)
# Multi-string flag: may be passed several times; `.value` is None when unset.
_CONFIG_PARAM = flags.DEFINE_multi_string(
    name='gin_bindings',
    default=None,
    help='Newline separated list of Gin parameter bindings.',
)
_OUTPUT_DIR = flags.DEFINE_string(
    name='output_dir',
    default=None,
    help='Path to model checkpoints/summaries.',
)
# Selects which `trainer_utils` entry point `main` invokes.
_AES = flags.DEFINE_bool(
    name='aes',
    default=False,
    help='launch with AES',
)
33
34
def main(_):
  """Parses the gin configuration and runs the selected entry point.

  Loads the gin config file given by `--config_path` together with any
  `--gin_bindings`, then calls `trainer_utils.evaluate_vmoe_mtl_with_aes` when
  `--aes` is set and `trainer_utils.evaluate_vmoe_mtl` otherwise, passing
  `--output_dir` to it.

  Args:
    _: Positional command-line arguments forwarded by `app.run`; unused.
  """
  tf.enable_v2_behavior()
  # Hide every GPU from TensorFlow so that it does not allocate GPU memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  # Enable relative paths within t5x configs.
  gin.add_config_file_search_path('third_party/py/t5x/configs')
  # The multi-string flag is None when no binding was supplied.
  bindings = _CONFIG_PARAM.value or []
  gin.parse_config_files_and_bindings([_CONFIG_PATH.value], bindings)

  if _AES.value:
    evaluate_fn = trainer_utils.evaluate_vmoe_mtl_with_aes
  else:
    evaluate_fn = trainer_utils.evaluate_vmoe_mtl
  evaluate_fn(_OUTPUT_DIR.value)

  # Wait until computations are done before exiting.
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
51
52
if __name__ == '__main__':
  # Both flags default to None, so require them explicitly; absl then rejects
  # the command line during flag parsing if either one is missing.
  for required_flag in ('output_dir', 'config_path'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)
56