# google-research
# 150 lines · 4.5 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Main file for Sudoku GPT experiments."""
17
18from absl import app
19from absl import flags
20from absl import logging
21from clu import platform
22import jax
23import ml_collections
24from ml_collections import config_flags
25import tensorflow as tf
26
27from sudoku_gpt import train_and_evaluate
28
29
# Log at INFO level and above for the whole process.
logging.set_verbosity(logging.INFO)

FLAGS = flags.FLAGS

# --workdir: directory where model data is stored. The returned flag holder
# is kept so the parsed value can be read later via `_WORKDIR.value`.
_WORKDIR = flags.DEFINE_string(
    name='workdir',
    default=None,
    help='Directory to store model data.',
)

# --config: path to a hyperparameter configuration file.
# NOTE(review): in the visible code main() builds its configuration with
# get_config() and never reads this flag's value -- confirm that is intended.
config_flags.DEFINE_config_file(
    name='config',
    default=None,
    help_string='File path to the training hyperparameter configuration.',
    lock_config=True,
)

# Both flags must be supplied on the command line.
flags.mark_flags_as_required(('config', 'workdir'))
44
45
def get_config(
    dataset='ordered-sudoku-wo-random-guessing-w-candidates-train-test',
):
    """Build the hyperparameter configuration for an experiment.

    Dataset choices:
      othello: For the Othello game.
      sudoku: Sudoku game with a fixed cell order (row-wise, left to right).
      ordered-sudoku: Sudoku game data with cells in the order of the solver.
      ordered-sudoku-wo-random-guessing-w-candidates-train-test: Uses Sudoku
        games that can be solved with 7 human logics. It has a train-test
        split. Does not contain examples with random guessing. Has penciling
        candidates for 10 locations and the strategies used for each move.

    Args:
      dataset: Name of the dataset to configure; must be one of the choices
        listed above. The default is the value this function used to
        hard-code, so `get_config()` yields the same configuration as before.

    Returns:
      A ConfigDict object. All dataset path fields are initialised to None
      and have to be filled in before the paths can be used.

    Raises:
      ValueError: If `dataset` is not one of the supported choices.
    """
    # Path fields (all initialised to None) that each dataset expects.
    path_fields_by_dataset = {
        'othello': ('dataset_path',),
        'sudoku': ('dataset_path',),
        'ordered-sudoku': ('dataset_path',),
        'ordered-sudoku-wo-random-guessing-w-candidates-train-test': (
            'train_puzzle_path',
            'test_puzzle_path',
            'train_candidate_path',
            'test_candidate_path',
        ),
    }
    # Fail fast: an unknown name would otherwise leave block_size, seq_len,
    # vocab_size, start_index and the path fields silently undefined.
    if dataset not in path_fields_by_dataset:
        raise ValueError(
            f'Unsupported dataset {dataset!r}; expected one of '
            f'{sorted(path_fields_by_dataset)}.'
        )

    config = ml_collections.ConfigDict()
    config.dataset = dataset

    ### Training related parameters
    config.max_steps = 2**22
    config.dtype = jax.numpy.bfloat16
    config.minibatch_size = 64
    if 'sudoku' in config.dataset:
        config.block_size = 81
        config.seq_len = 3 * config.block_size
        config.vocab_size = 11
        config.start_index = 31
    elif config.dataset == 'othello':
        config.block_size = 60
        config.seq_len = config.block_size
        config.vocab_size = 65
        config.start_index = 0  # Does not get used

    ### Model related parameters
    config.num_heads = 8
    config.num_layers = 8
    config.emb_dim = 576
    config.qkv_dim = 576
    config.mlp_dim = 6 * config.emb_dim
    config.dropout_rate = 0.2
    config.attention_dropout_rate = 0.2

    ### Optimizer related parameters
    config.learning_rate = 2e-4  # Base learning rate.
    config.end_lr_factor = 0.2
    config.warmup_tokens = 10000
    config.weight_decay = 5e-3
    config.optimizer = 'adamw'

    config.seed = 7
    config.save_checkpoint = True
    config.save_every_steps = 8000

    ### Evaluation related parameters
    config.eval_every_steps = 1000
    config.eval_epochs = 5

    # Dataset paths are placeholders; they need to be set before use.
    for path_field in path_fields_by_dataset[config.dataset]:
        setattr(config, path_field, None)

    return config
121
122
def main(argv):
    """Set up the JAX/TensorFlow environment and launch train-and-evaluate.

    Args:
      argv: Positional command-line arguments as handed over by `app.run`.
        Anything beyond the first entry is rejected.

    Raises:
      app.UsageError: If `argv` holds more than one entry.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and
    # make it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    process_index = jax.process_index()
    process_count = jax.process_count()
    logging.info('JAX process: %d / %d', process_index, process_count)
    logging.info('JAX local devices: %r', jax.local_devices())

    workdir = _WORKDIR.value

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    task_status = (
        f'process_index: {process_index}, process_count: {process_count}'
    )
    platform.work_unit().set_task_status(task_status)
    platform.work_unit().create_artifact(
        platform.ArtifactType.DIRECTORY, workdir, 'workdir'
    )

    # NOTE(review): hyperparameters come from get_config() in this file; the
    # `config` command-line flag is not read here -- confirm this is intended.
    experiment_config = get_config()
    logging.info(experiment_config)

    # The working directory is both stored on the config and passed explicitly.
    experiment_config.workdir = workdir
    train_and_evaluate.train_and_evaluate(experiment_config, workdir)
146
147
if __name__ == '__main__':
    # Script entry point: hook JAX's configuration options into absl flags
    # first, then let absl parse the command line and invoke main().
    jax.config.config_with_absl()
    app.run(main)
151