google-research

Форк
0
/
transformer_coco_config.py 
76 строк · 2.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Default Hyperparameter configuration."""
17

18
import ml_collections
19

20

21
def get_config():
22
  """Gets the default hyperparameter configuration."""
23

24
  config = ml_collections.ConfigDict()
25
  # Exp info
26
  config.dataset_path = "/path/to/coco"
27
  config.dataset = "COCO"
28
  config.vocab_size = 316
29
  config.experiment = "transformer"
30
  config.model_class = "transformer"
31
  config.image_size = 256
32

33
  # Training info
34
  config.seed = 0
35
  config.log_every_steps = 100
36
  config.eval_num_steps = 1000
37
  config.max_length = 128
38
  config.batch_size = 64
39
  config.train_shuffle = True
40
  config.eval_pad_last_batch = False
41
  config.eval_batch_size = 64
42
  config.num_train_steps = 100_000
43
  config.checkpoint_every_steps = 5000
44
  config.eval_every_steps = 1000
45
  config.num_eval_steps = 100
46

47
  # Model info
48
  config.layout_dim = 2
49
  config.autoregressive = True
50
  config.dtype = "float32"
51
  config.shuffle_buffer_size = 10
52
  config.use_vae = False
53
  config.share_embeddings = True
54
  config.num_layers = 4
55
  config.qkv_dim = 512
56
  config.emb_dim = 512
57
  config.mlp_dim = 2048
58
  config.num_heads = 8
59
  config.dropout_rate = 0.3
60
  config.attention_dropout_rate = 0.1
61
  config.restore_checkpoints = True
62
  config.label_smoothing = 0.
63
  config.sampling_method = "top-p"
64
  config.use_vertical_info = False
65

66
  # Optimizer info
67
  config.optimizer = ml_collections.ConfigDict()
68
  config.optimizer.type = "adam"
69
  config.optimizer.warmup_steps = 4000
70
  config.optimizer.lr = 1e-3
71
  config.optimizer.beta1 = 0.9
72
  config.optimizer.beta2 = 0.98
73
  config.optimizer.weight_decay = 0.01
74
  config.beta_rate = 1 / 20_000
75

76
  return config
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.