# google-research
# 149 lines · 4.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Main file for inference evaluater."""
17
18from absl import app
19from absl import flags
20from absl import logging
21from clu import platform
22import jax
23import ml_collections
24import tensorflow as tf
25
26from sudoku_gpt.inference import inference_evaluater
27
logging.set_verbosity(logging.INFO)

FLAGS = flags.FLAGS

# Command-line flags. Both are marked as required below, so neither has a
# usable default.
_WORKDIR = flags.DEFINE_string(
    name='workdir', default=None, help='Directory to store model data.'
)
_CKPT_LOC = flags.DEFINE_string(
    name='ckpt_loc', default=None, help='Checkpoint location.'
)
flags.mark_flags_as_required(['workdir', 'ckpt_loc'])
36
37
def get_config():
  """Build the default hyperparameter configuration.

  The dataset is fixed to 'ordered-sudoku' here; the remaining branches show
  which fields the other dataset names would populate.

  Dataset choices:
    othello: Othello game data.
    sudoku: Sudoku data in a fixed order (row-wise, left to right).
    ordered-sudoku: Sudoku data following the order of the solver.
    ordered-sudoku-wo-random-guessing-w-candidates-train-test: Sudoku games
      that can be solved with 7 human logics, with a train-test split and no
      examples that require random guessing. Has penciling candidates for 10
      locations and the strategy used for each move.

  Returns:
    A ConfigDict with all the experiment related settings.
  """
  cfg = ml_collections.ConfigDict()
  cfg.dataset = 'ordered-sudoku'
  cfg.sampling_method = 'greedy-row-col'

  cfg.restore_checkpoint = False  # Not implemented yet

  # --- Training related parameters ---
  cfg.max_steps = 2**22
  cfg.dtype = jax.numpy.bfloat16
  cfg.minibatch_size = 64

  # --- Dataset-dependent sequence layout ---
  # The two branches are mutually exclusive: 'othello' does not contain
  # the substring 'sudoku'.
  name = cfg.dataset
  if name == 'othello':
    cfg.block_size = 60
    cfg.seq_len = cfg.block_size
    cfg.vocab_size = 65
    cfg.start_index = 0  # Does not get used
  elif 'sudoku' in name:
    cfg.block_size = 81
    cfg.seq_len = 3 * cfg.block_size
    cfg.vocab_size = 11
    cfg.start_index = 32
    cfg.set_accuracy = 'top-k'  # Choice = "top-k", "all"
    cfg.set_accuracy_top_k = 20

  # --- Model related parameters ---
  cfg.num_heads = 8
  cfg.num_layers = 8
  cfg.emb_dim = 576
  cfg.qkv_dim = 576
  cfg.mlp_dim = 6 * cfg.qkv_dim
  cfg.dropout_rate = 0.2
  cfg.attention_dropout_rate = 0.1
  cfg.optimizer = 'adamw'

  # --- Optimization and checkpointing parameters ---
  cfg.learning_rate = 1e-4  # Base learning rate.
  cfg.end_lr_factor = 0.2
  cfg.warmup_tokens = 2**10
  cfg.weight_decay = 5e-3
  cfg.seed = 9
  cfg.save_checkpoint = True
  cfg.save_every_steps = 2**13
  cfg.num_examples = 10

  # --- Evaluation related parameters ---
  cfg.eval_every_steps = 500
  cfg.eval_epochs = 10
  cfg.beam_search_n = 3  # Always use < 9

  # --- Data locations (left as None placeholders) ---
  if name in ('othello', 'sudoku', 'ordered-sudoku'):
    cfg.dataset_path = None
  elif name == 'ordered-sudoku-wo-random-guessing-w-candidates-train-test':
    cfg.test_puzzle_path = None
    cfg.test_cands_path = None

  return cfg
118
119
def main(argv):
  """Run inference evaluation using the default config and the parsed flags.

  Args:
    argv: Command-line arguments handed over by `app.run`; more than one
      entry is treated as a usage error.

  Raises:
    app.UsageError: If `argv` has more than one element.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  task_status = (
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}'
  )
  platform.work_unit().set_task_status(task_status)

  workdir = _WORKDIR.value
  platform.work_unit().create_artifact(
      platform.ArtifactType.DIRECTORY, workdir, 'workdir'
  )

  cfgs = get_config()
  logging.info(cfgs)

  inference_evaluater.inference_evaluate(cfgs, workdir, _CKPT_LOC.value)
145
146
if __name__ == '__main__':
  # Hook JAX's configuration into absl before the app parses flags and
  # dispatches to main().
  jax.config.config_with_absl()
  app.run(main)
150