google-research

Форк
0
/
bert_layout_coco_config.py 
77 строк · 2.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Default Hyperparameter configuration."""
17

18
import ml_collections
19

20

21
def get_config():
22
  """Gets the default hyperparameter configuration."""
23

24
  config = ml_collections.ConfigDict()
25
  # Exp info
26
  config.dataset_path = "/path/to/coco"
27
  config.dataset = "COCO"
28
  config.vocab_size = 316
29
  config.experiment = "bert_layout"
30
  config.model_class = "bert_layout"
31
  config.image_size = 256
32

33
  # Training info
34
  config.layout_dim = 2
35
  config.seed = 0
36
  config.log_every_steps = 100
37
  config.eval_num_steps = 1000
38
  config.max_length = 128
39
  config.batch_size = 64
40
  config.train_shuffle = True
41
  config.eval_pad_last_batch = False
42
  config.eval_batch_size = 64
43
  config.num_train_steps = 100_000
44
  config.checkpoint_every_steps = 5000
45
  config.eval_every_steps = 1000
46
  config.num_eval_steps = 100
47

48
  # Model info
49
  config.num_dim = 2
50
  config.dtype = "float32"
51
  config.autoregressive = False
52
  config.shuffle_buffer_size = 10
53
  config.use_vae = True
54
  config.share_embeddings = True
55
  config.num_layers = 4
56
  config.qkv_dim = 512
57
  config.emb_dim = 512
58
  config.mlp_dim = 2048
59
  config.num_heads = 8
60
  config.dropout_rate = 0.1
61
  config.attention_dropout_rate = 0.1
62
  config.restore_checkpoints = True
63
  config.label_smoothing = 0.
64
  config.sampling_method = "top-p"
65
  config.use_vertical_info = False
66

67
  # Optimizer info
68
  config.optimizer = ml_collections.ConfigDict()
69
  config.optimizer.type = "adam"
70
  config.optimizer.warmup_steps = 4000
71
  config.optimizer.lr = 5e-4
72
  config.optimizer.beta1 = 0.9
73
  config.optimizer.beta2 = 0.98
74
  config.optimizer.weight_decay = 0.01
75
  config.beta_rate = 1 / 20_000
76

77
  return config
78

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.