google-research

Форк
0
150 строк · 4.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Main file for Sudoku GPT experiments."""
17

18
from absl import app
19
from absl import flags
20
from absl import logging
21
from clu import platform
22
import jax
23
import ml_collections
24
from ml_collections import config_flags
25
import tensorflow as tf
26

27
from sudoku_gpt import train_and_evaluate
28

29

30
# Set absl logging verbosity to INFO for the whole program.
logging.set_verbosity(logging.INFO)

FLAGS = flags.FLAGS

# --workdir: where model data is stored.
_WORKDIR = flags.DEFINE_string(
    name='workdir',
    default=None,
    help='Directory to store model data.',
)

# --config: path to a hyperparameter configuration file; the config is locked.
config_flags.DEFINE_config_file(
    name='config',
    default=None,
    help_string='File path to the training hyperparameter configuration.',
    lock_config=True,
)

# Neither flag has a usable default, so both must be given on the command line.
flags.mark_flags_as_required(['config', 'workdir'])
44

45

46
def get_config():
  """Build the hyperparameter configuration used by this launcher.

  The dataset name is fixed inside this function. The sequence-shape fields
  and the dataset path placeholders are then chosen according to that name.

  Recognised dataset names:
    othello: Othello game data.
    sudoku: Sudoku with a fixed cell order (row-wise, left to right).
    ordered-sudoku: Sudoku data in the order used by the solver.
    ordered-sudoku-wo-random-guessing-w-candidates-train-test: Sudoku games
      that can be solved with 7 human logics, with a train-test split. It has
      no examples with random guessing, and it carries penciled candidates
      for 10 locations plus the strategy used for each move.

  Returns:
    A ConfigDict object. Its dataset path fields are None placeholders that
    still need to be set.
  """
  dataset = 'ordered-sudoku-wo-random-guessing-w-candidates-train-test'

  cfg = ml_collections.ConfigDict()
  cfg.dataset = dataset

  # --- Training: run length, precision and batch size ---
  cfg.max_steps = 1 << 22  # Same value as 2**22.
  cfg.dtype = jax.numpy.bfloat16
  cfg.minibatch_size = 64

  # --- Sequence shape, chosen by dataset family ---
  # Any name containing 'sudoku' takes the sudoku shape; an unrecognised name
  # leaves these four fields unset.
  if 'sudoku' in dataset:
    cells = 81
    shape_fields = {
        'block_size': cells,
        'seq_len': 3 * cells,
        'vocab_size': 11,
        'start_index': 31,
    }
  elif dataset == 'othello':
    moves = 60
    shape_fields = {
        'block_size': moves,
        'seq_len': moves,
        'vocab_size': 65,
        'start_index': 0,  # Does not get used
    }
  else:
    shape_fields = {}
  for field_name, field_value in shape_fields.items():
    setattr(cfg, field_name, field_value)

  # --- Model architecture ---
  cfg.num_heads = 8
  cfg.num_layers = 8
  cfg.emb_dim = 576
  cfg.qkv_dim = 576
  cfg.mlp_dim = 6 * cfg.emb_dim
  cfg.dropout_rate = 0.2
  cfg.attention_dropout_rate = 0.2

  # --- Optimisation ---
  cfg.learning_rate = 2e-4  # Base learning rate.
  cfg.end_lr_factor = 0.2
  cfg.warmup_tokens = 10000
  cfg.weight_decay = 5e-3
  cfg.optimizer = 'adamw'

  # --- Reproducibility and checkpointing ---
  cfg.seed = 7
  cfg.save_checkpoint = True
  cfg.save_every_steps = 8000

  # --- Evaluation ---
  cfg.eval_every_steps = 1000
  cfg.eval_epochs = 5

  # --- Dataset locations (placeholders; need to be set) ---
  if dataset in ('othello', 'sudoku', 'ordered-sudoku'):
    path_fields = ('dataset_path',)
  elif dataset == 'ordered-sudoku-wo-random-guessing-w-candidates-train-test':
    path_fields = (
        'train_puzzle_path',
        'test_puzzle_path',
        'train_candidate_path',
        'test_candidate_path',
    )
  else:
    path_fields = ()
  for field_name in path_fields:
    setattr(cfg, field_name, None)

  return cfg
121

122

123
def main(argv):
  """Prepare the host process and launch training and evaluation.

  Args:
    argv: Positional command-line arguments. Anything beyond the first entry
      is rejected.

  Raises:
    app.UsageError: If argv holds more than one entry.
  """
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  proc_index = jax.process_index()
  proc_count = jax.process_count()
  logging.info('JAX process: %d / %d', proc_index, proc_count)
  logging.info('JAX local devices: %r', jax.local_devices())

  workdir = _WORKDIR.value

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  status = f'process_index: {proc_index}, process_count: {proc_count}'
  platform.work_unit().set_task_status(status)
  platform.work_unit().create_artifact(
      platform.ArtifactType.DIRECTORY, workdir, 'workdir'
  )

  # NOTE(review): hyperparameters come from get_config() in this file; the
  # value of the required --config flag is not read in this function.
  hparams = get_config()
  logging.info(hparams)

  # The working directory is stored on the config after it has been logged,
  # and is also passed to the training entry point as a separate argument.
  hparams.workdir = workdir
  train_and_evaluate.train_and_evaluate(hparams, workdir)
146

147

148
if __name__ == '__main__':
  # Run only when executed as a script. JAX's absl integration is set up
  # first (presumably so its options are available as command-line flags --
  # confirm against the JAX docs), then absl parses flags and calls main.
  jax.config.config_with_absl()
  app.run(main)
151

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.