google-research

Форк
0
90 строк · 2.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Configuration and hyperparameter sweeps."""
17
# pylint: disable=line-too-long
18

19
import ml_collections
20

21

22
def get_config():
23
  """Get the default hyperparameter configuration."""
24
  config = ml_collections.ConfigDict()
25
  config.seed = 42
26
  config.trial = 0  # Dummy for repeated runs.
27

28
  config.dataset = "billiard-max-left-right-test"
29
  config.train_preprocess_str = "to_float_0_1|pad(ensure_small=(1100, 1100))|random_crop(resolution=(1000, 1000))"
30
  config.eval_preprocess_str = "to_float_0_1"
31

32
  config.append_position_to_input = False
33
  config.downsample_input_factor = 1
34

35
  # Top-k extraction.
36
  config.model = "patchnet"
37
  config.k = 10
38
  config.patch_size = 100
39
  config.downscale = 4
40
  config.feature_network = "ResNet18"
41

42
  config.aggregation_method = "transformer"
43
  config.aggregation_method_kwargs = ml_collections.ConfigDict()
44
  config.aggregation_method_kwargs.num_layers = 3
45
  config.aggregation_method_kwargs.num_heads = 8
46
  config.aggregation_method_kwargs.dim_hidden = 256
47
  config.aggregation_method_kwargs.pooling = "sum"
48

49
  config.selection_method = "perturbed-topk"
50
  config.selection_method_inference = "hard-topk"
51
  config.entropy_regularizer = -0.01
52
  config.entropy_before_normalization = True
53

54
  config.scorer_has_se = False
55

56
  # parameters for sinkhorn topk
57
  config.sinkhorn_topk_kwargs = ml_collections.ConfigDict()
58
  config.sinkhorn_topk_kwargs.epsilon = 1e-4
59
  config.sinkhorn_topk_kwargs.num_iterations = 2000
60

61
  # parameters for perturbed_topk
62
  config.perturbed_topk_kwargs = ml_collections.ConfigDict()
63
  config.perturbed_topk_kwargs.num_samples = 500
64
  config.perturbed_topk_kwargs.sigma = 0.05
65
  config.linear_decrease_perturbed_sigma = True
66

67
  config.normalization_str = "zerooneeps(1e-5)"
68
  config.use_iterative_extraction = True
69

70
  # Same set up as usual.
71
  config.optimizer = "adam"
72
  config.learning_rate = 1e-4
73
  config.gradient_value_clip = 1.
74
  config.momentum = .9
75
  config.weight_decay = 1e-4
76
  config.cosine_decay = True
77
  config.warmup_ratio = 0.05
78
  config.batch_size = 64
79
  config.num_train_steps = 30_000
80

81
  config.log_loss_every_steps = 100
82
  config.eval_every_steps = 1000
83
  config.checkpoint_every_steps = 5000
84
  config.log_images = True
85
  config.log_histograms = False
86
  config.do_eval_only = False
87

88
  config.trial = 0  # Dummy for repeated runs.
89

90
  return config
91

92

93

94

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.