google-research

Форк
0
63 строки · 1.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Configuration and hyperparameter sweeps."""
17
# pylint: disable=line-too-long
18

19
import ml_collections
20

21

22
def get_config():
  """Get the default hyperparameter configuration.

  Returns:
    An `ml_collections.ConfigDict` holding the default settings for training
    the "ats-traffic" model on the "trafficsigns" dataset:
    - the train/eval preprocessing pipelines, as "|"-separated op strings;
    - optimizer and learning-rate schedule settings;
    - run length;
    - logging, evaluation and checkpoint intervals.
  """
  config = ml_collections.ConfigDict()
  config.seed = 42
  # Dummy for repeated runs. Set exactly once; an earlier version re-assigned
  # the same value again just before returning.
  config.trial = 0

  config.dataset = "trafficsigns"
  # Training pipeline: to_float_0_1 -> pad -> random_crop to (960, 1280) ->
  # random_linear_transform. Eval pipeline applies only to_float_0_1.
  # NOTE(review): the exact semantics of each op (e.g. whether "ensure_small"
  # is a minimum padded size) are defined by the preprocessing parser, which
  # is not visible here — confirm there.
  config.train_preprocess_str = "to_float_0_1|pad(ensure_small=(1160, 1480))|random_crop(resolution=(960, 1280))|random_linear_transform((.8,1.2),(-.1,.1),.8)"
  config.eval_preprocess_str = "to_float_0_1"

  # Top-k extraction.
  config.model = "ats-traffic"

  # Same setup as usual.
  config.optimizer = "adam"
  config.learning_rate = 1e-4
  config.gradient_value_clip = 1.
  # NOTE(review): presumably only read by momentum-based optimizers and
  # ignored when optimizer == "adam" — confirm in the optimizer factory.
  config.momentum = .9

  config.weight_decay = 1e-4
  config.cosine_decay = True
  config.warmup_ratio = 0.
  config.batch_size = 32
  config.num_train_steps = 70_000

  config.log_loss_every_steps = 50
  config.eval_every_steps = 1000
  config.checkpoint_every_steps = 5000

  return config
54

55

56
def get_sweep(h):
  """Get the hyperparameter sweep.

  Sweeps `config.seed` over range(5) and `config.learning_rate` over
  [1e-3, 5e-4, 1e-4, 5e-5], then combines the two with `h.product`. This is
  presumably the Cartesian product, i.e. 5 x 4 = 20 runs; confirm against
  the sweep helper.

  Args:
    h: Hyperparameter-sweep helper exposing `sweep(name, values)` and
      `product(sweeps)`. It is supplied by the caller; its concrete type is
      not visible here.

  Returns:
    The result of `h.product` over the seed and learning-rate sweeps.
  """
  # Order matters for how the helper enumerates combinations: seed first,
  # then learning rate.
  sweeps = [
      h.sweep("config.seed", range(5)),
      h.sweep("config.learning_rate", [1e-3, 5e-4, 1e-4, 5e-5]),
  ]
  return h.product(sweeps)
64

65

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.