google-research

Форк
0
149 строк · 4.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Main file for inference evaluater."""
17

18
from absl import app
19
from absl import flags
20
from absl import logging
21
from clu import platform
22
import jax
23
import ml_collections
24
import tensorflow as tf
25

26
from sudoku_gpt.inference import inference_evaluater
27

28
# Emit INFO-level messages (process/device info, the config) from this script.
logging.set_verbosity(logging.INFO)

FLAGS = flags.FLAGS

# Command-line flags. Both are declared required below, so absl refuses to
# start main() unless the caller supplies them.
_WORKDIR = flags.DEFINE_string(
    name='workdir',
    default=None,
    help='Directory to store model data.',
)
_CKPT_LOC = flags.DEFINE_string(
    name='ckpt_loc',
    default=None,
    help='Checkpoint location.',
)
flags.mark_flags_as_required(['workdir', 'ckpt_loc'])
36

37

38
def get_config(dataset='ordered-sudoku'):
  """Get the default hyperparameter configuration.

  Dataset choices:
    othello: Othello game.
    sudoku: Sudoku games in a fixed cell order (row-wise, left to right).
    ordered-sudoku: Sudoku games in the order the solver filled the cells.
    ordered-sudoku-wo-random-guessing-w-candidates-train-test: Sudoku games
      that can be solved with 7 human logics, with a train/test split. Does
      not contain examples with random guessing. Includes pencil-mark
      candidates for 10 locations and the strategy used for each move.

  Args:
    dataset: Name of the dataset to configure; one of the choices above.
      Defaults to 'ordered-sudoku', the value this function always used
      before it became a parameter.

  Returns:
    A ConfigDict with all the experiment related settings.

  Raises:
    ValueError: If `dataset` is not one of the supported choices. Without
      this check an unknown name would silently produce a config missing
      `block_size`, `seq_len`, `vocab_size` and the dataset path fields.
  """
  known_datasets = (
      'othello',
      'sudoku',
      'ordered-sudoku',
      'ordered-sudoku-wo-random-guessing-w-candidates-train-test',
  )
  if dataset not in known_datasets:
    raise ValueError(
        f'Unknown dataset {dataset!r}; expected one of {known_datasets}.'
    )

  config = ml_collections.ConfigDict()
  config.dataset = dataset
  config.sampling_method = 'greedy-row-col'

  config.restore_checkpoint = False  # Not implemented yet

  ### Training related parameters (run length, precision, batch size)
  config.max_steps = 2**22
  config.dtype = jax.numpy.bfloat16
  config.minibatch_size = 64

  # The substring test deliberately matches every Sudoku variant, so they
  # all share these sequence settings.
  if 'sudoku' in config.dataset:
    config.block_size = 81
    config.seq_len = 3 * config.block_size
    config.vocab_size = 11
    config.start_index = 32
    config.set_accuracy = 'top-k'  # Choice = "top-k", "all"
    config.set_accuracy_top_k = 20
  elif config.dataset == 'othello':
    config.block_size = 60
    config.seq_len = config.block_size
    config.vocab_size = 65
    config.start_index = 0  # Does not get used

  ### Model related parameters
  config.num_heads = 8
  config.num_layers = 8
  config.emb_dim = 576
  config.qkv_dim = 576
  config.mlp_dim = 6 * config.qkv_dim
  config.dropout_rate = 0.2
  config.attention_dropout_rate = 0.1
  config.optimizer = 'adamw'

  ### Optimizer, seeding and checkpointing parameters
  config.learning_rate = 1e-4  # Base learning rate.
  config.end_lr_factor = 0.2
  config.warmup_tokens = 2**10
  config.weight_decay = 5e-3
  config.seed = 9
  config.save_checkpoint = True
  config.save_every_steps = 2**13
  config.num_examples = 10

  ### Evaluation related parameters
  config.eval_every_steps = 500
  config.eval_epochs = 10
  config.beam_search_n = 3  ## Always use < 9

  # NOTE(review): every path below is left as None; presumably it must be
  # filled in before a real run — confirm where callers set it.
  if config.dataset in ('othello', 'sudoku', 'ordered-sudoku'):
    config.dataset_path = None
  else:
    # Only the train/test-split variant reaches here (validated above); it
    # uses separate puzzle and candidate files instead of `dataset_path`.
    config.test_puzzle_path = None
    config.test_cands_path = None

  return config
118

119

120
def main(argv):
  """Entry point: runs inference evaluation for the checkpoint in --ckpt_loc.

  Args:
    argv: Positional command-line arguments remaining after absl parsed the
      flags; only the program name itself is accepted.

  Raises:
    app.UsageError: If any extra positional argument is supplied.
  """
  extra_args = argv[1:]
  if extra_args:
    raise app.UsageError('Too many command-line arguments.')

  # Hide every GPU from TensorFlow; otherwise TF may reserve GPU memory and
  # make it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  proc_idx = jax.process_index()
  proc_cnt = jax.process_count()
  logging.info('JAX process: %d / %d', proc_idx, proc_cnt)
  logging.info('JAX local devices: %r', jax.local_devices())

  # Label the work unit with this host's JAX process index so tasks can be
  # matched to hosts (task 0 is not guaranteed to be host 0 everywhere).
  platform.work_unit().set_task_status(
      f'process_index: {proc_idx}, process_count: {proc_cnt}'
  )
  platform.work_unit().create_artifact(
      platform.ArtifactType.DIRECTORY, _WORKDIR.value, 'workdir'
  )

  config = get_config()
  logging.info(config)

  inference_evaluater.inference_evaluate(
      config, _WORKDIR.value, _CKPT_LOC.value
  )


if __name__ == '__main__':
  # Expose JAX's configuration options as absl flags before app.run parses
  # the command line and calls main().
  jax.config.config_with_absl()
  app.run(main)
150

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.