google-research

Форк
0
90 строк · 2.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Library of tuned hparams and functions for converting to ModelOptions."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import collections
23

24
from uq_benchmark_2019.cifar import data_lib
25
from uq_benchmark_2019.cifar import models_lib
26

27

28
# Record of the hyperparameters stored for one training method.
# model_opts_from_hparams() forwards the three *prior_scale* fields only for
# the variational methods ('svi', 'll_svi'); the other methods leave them as
# None.
HParams = collections.namedtuple(
    'CifarHparams', ['batch_size', 'init_learning_rate', 'dropout_rate',
                     'init_prior_scale_mean',
                     'init_prior_scale_std', 'std_prior_scale'])

# Non-variational methods: positional order is
# (batch_size, init_learning_rate, dropout_rate, <three unused prior fields>).
_HPS_VANILLA = HParams(7, 0.000717, 0, None, None, None)
_HPS_DROPOUT = HParams(5, 0.000250, 0.054988, None, None, None)

# Variational methods: dropout_rate is 0 and the prior-scale fields are set.
_HPS_LL_SVI = HParams(
    16, 0.00115285, 0,
    init_prior_scale_mean=-2.73995,
    init_prior_scale_std=-3.61795,
    std_prior_scale=4.85503)

_HPS_SVI = HParams(
    107, 0.001189, 0,
    init_prior_scale_mean=-1.9994,
    init_prior_scale_std=-0.30840,
    std_prior_scale=3.4210)

_HPS_LL_DROPOUT = HParams(16, 0.000313, 0.319811, None, None, None)

# Maps method name -> HParams. 'dropout', 'dropout_nofirst' and 'wide_dropout'
# intentionally share the same (immutable) _HPS_DROPOUT record.
HPS_DICT = dict(
    vanilla=_HPS_VANILLA,
    dropout=_HPS_DROPOUT,
    dropout_nofirst=_HPS_DROPOUT,
    svi=_HPS_SVI,
    ll_dropout=_HPS_LL_DROPOUT,
    ll_svi=_HPS_LL_SVI,
    wide_dropout=_HPS_DROPOUT,
)
59

60

61
def model_opts_from_hparams(hps, method, fake_training=False):
  """Returns a ModelOptions instance using given hyperparameters.

  Args:
    hps: Hyperparameter record (e.g. an `HParams` tuple) providing
      `batch_size` and `init_learning_rate`. `dropout_rate` is optional and
      defaults to 0 when the attribute is missing. `std_prior_scale`,
      `init_prior_scale_mean` and `init_prior_scale_std` are read only when
      `method` is 'svi' or 'll_svi'.
    method: Name of the modeling method. 'wide_dropout' selects 32 ResNet
      filters instead of 16; 'svi' and 'll_svi' enable the variational
      parameters.
    fake_training: If True, override `batch_size`, `examples_per_epoch` and
      `train_epochs` on the result with small fixed values (32, 256, 1).

  Returns:
    A `models_lib.ModelOptions` populated from `hps` and fixed CIFAR settings.
  """
  # Tolerate hparam records that carry no dropout_rate attribute.
  dropout_rate = getattr(hps, 'dropout_rate', 0)
  variational = method in ('svi', 'll_svi')

  model_opts = models_lib.ModelOptions(
      # Modeling params
      method=method,
      resnet_depth=20,
      num_resnet_filters=32 if method == 'wide_dropout' else 16,
      # Data params.
      image_shape=data_lib.CIFAR_SHAPE,
      num_classes=data_lib.CIFAR_NUM_CLASSES,
      examples_per_epoch=data_lib.CIFAR_NUM_TRAIN_EXAMPLES,
      # SGD params
      train_epochs=200,
      batch_size=hps.batch_size,
      dropout_rate=dropout_rate,
      init_learning_rate=hps.init_learning_rate,
      # Variational params: only forwarded for the variational methods.
      std_prior_scale=hps.std_prior_scale if variational else None,
      init_prior_scale_mean=hps.init_prior_scale_mean if variational else None,
      init_prior_scale_std=hps.init_prior_scale_std if variational else None,
  )

  if fake_training:
    # Replace the tuned/real sizes with small fixed values after construction.
    model_opts.batch_size = 32
    model_opts.examples_per_epoch = 256
    model_opts.train_epochs = 1
  return model_opts
91

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.