google-research
90 строк · 2.9 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Library of tuned hparams and functions for converting to ModelOptions."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import collections23
24from uq_benchmark_2019.cifar import data_lib25from uq_benchmark_2019.cifar import models_lib26
27
# Tuned hyperparameters for one CIFAR training method. The three prior-scale
# fields (`init_prior_scale_mean`, `init_prior_scale_std`, `std_prior_scale`)
# are set only for the SVI entries below; the other entries leave them None.
HParams = collections.namedtuple(
    'CifarHparams', ['batch_size', 'init_learning_rate', 'dropout_rate',
                     'init_prior_scale_mean',
                     'init_prior_scale_std', 'std_prior_scale'])

# pickle locates a namedtuple class by its typename ('CifarHparams') inside the
# defining module. The class is bound here as `HParams`, so without this alias
# pickling any hparams instance fails with PicklingError. The alias keeps the
# original typename/repr while making instances picklable.
CifarHparams = HParams

_HPS_VANILLA = HParams(7, 0.000717, 0, None, None, None)
_HPS_DROPOUT = HParams(5, 0.000250, 0.054988, None, None, None)

_HPS_LL_SVI = HParams(16, 0.00115285, 0,
                      init_prior_scale_mean=-2.73995,
                      init_prior_scale_std=-3.61795,
                      std_prior_scale=4.85503)

_HPS_SVI = HParams(107, 0.001189, 0,
                   init_prior_scale_mean=-1.9994,
                   init_prior_scale_std=-0.30840,
                   std_prior_scale=3.4210)

_HPS_LL_DROPOUT = HParams(16, 0.000313, 0.319811, None, None, None)

# Method name -> tuned hyperparameters. 'dropout_nofirst' and 'wide_dropout'
# share the plain dropout settings (the same tuple object).
HPS_DICT = dict(
    vanilla=_HPS_VANILLA,
    dropout=_HPS_DROPOUT,
    dropout_nofirst=_HPS_DROPOUT,
    svi=_HPS_SVI,
    ll_dropout=_HPS_LL_DROPOUT,
    ll_svi=_HPS_LL_SVI,
    wide_dropout=_HPS_DROPOUT,
)
59
60
def model_opts_from_hparams(hps, method, fake_training=False):
  """Returns a ModelOptions instance using given hyperparameters.

  Args:
    hps: Hyperparameter record (e.g. an entry of `HPS_DICT`). Must provide
      `batch_size` and `init_learning_rate`; `dropout_rate` is optional and
      falls back to 0. The prior-scale fields are read only when `method` is
      'svi' or 'll_svi'.
    method: Method name, forwarded to ModelOptions as `method`. 'wide_dropout'
      additionally doubles the ResNet filter count (32 instead of 16).
    fake_training: If True, overrides the options with a tiny configuration
      (batch size 32, 256 examples per epoch, a single epoch).

  Returns:
    The `models_lib.ModelOptions` built from `hps` and `method`.
  """
  dropout = getattr(hps, 'dropout_rate', 0)
  batch_size = hps.batch_size
  learning_rate = hps.init_learning_rate

  if method in ('svi', 'll_svi'):
    std_prior = hps.std_prior_scale
    prior_mean = hps.init_prior_scale_mean
    prior_std = hps.init_prior_scale_std
  else:
    # Non-variational methods never look at the prior-scale hparams.
    std_prior = prior_mean = prior_std = None

  filters = 16
  if method == 'wide_dropout':
    filters = 32

  opts = models_lib.ModelOptions(
      # Model architecture.
      method=method,
      resnet_depth=20,
      num_resnet_filters=filters,
      # Dataset description.
      image_shape=data_lib.CIFAR_SHAPE,
      num_classes=data_lib.CIFAR_NUM_CLASSES,
      examples_per_epoch=data_lib.CIFAR_NUM_TRAIN_EXAMPLES,
      # Optimization.
      train_epochs=200,
      batch_size=batch_size,
      dropout_rate=dropout,
      init_learning_rate=learning_rate,
      # Variational priors (None unless SVI).
      std_prior_scale=std_prior,
      init_prior_scale_mean=prior_mean,
      init_prior_scale_std=prior_std,
  )

  if not fake_training:
    return opts

  # Shrink the run to one epoch of 256 examples in batches of 32.
  opts.batch_size = 32
  opts.examples_per_epoch = 256
  opts.train_epochs = 1
  return opts