google-research
63 строки · 1.8 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Configuration and hyperparameter sweeps."""
17# pylint: disable=line-too-long
18
19import ml_collections20
21
def get_config():
  """Get the default hyperparameter configuration.

  Returns:
    An `ml_collections.ConfigDict` holding the default dataset, preprocessing,
    model, optimizer, schedule and logging settings. Individual fields are
    overridden by the sweep returned from `get_sweep`.
  """
  config = ml_collections.ConfigDict()
  config.seed = 42
  config.trial = 0  # Dummy for repeated runs.

  config.dataset = "trafficsigns"
  # Training-time augmentation pipeline: stages are separated by "|".
  config.train_preprocess_str = "to_float_0_1|pad(ensure_small=(1160, 1480))|random_crop(resolution=(960, 1280))|random_linear_transform((.8,1.2),(-.1,.1),.8)"
  config.eval_preprocess_str = "to_float_0_1"

  # Top-k extraction.
  config.model = "ats-traffic"

  # Optimizer settings.
  config.optimizer = "adam"
  config.learning_rate = 1e-4
  config.gradient_value_clip = 1.
  # NOTE(review): presumably ignored when optimizer is "adam" — confirm
  # against the trainer that consumes this config.
  config.momentum = .9

  # Regularization and learning-rate schedule.
  config.weight_decay = 1e-4
  config.cosine_decay = True
  config.warmup_ratio = 0.
  config.batch_size = 32
  config.num_train_steps = 70_000

  # Logging / evaluation / checkpoint cadence, in training steps.
  config.log_loss_every_steps = 50
  config.eval_every_steps = 1000
  config.checkpoint_every_steps = 5000

  return config
55
def get_sweep(h):
  """Get the hyperparameter sweep.

  Args:
    h: Sweep helper object. It must provide `sweep(name, values)` to build a
      one-dimensional sweep over a config field, and `product(sweeps)` to
      combine several sweeps.

  Returns:
    The result of `h.product` over a sweep of 5 seeds crossed with 4 learning
    rates.
  """
  sweeps = [
      h.sweep("config.seed", range(5)),
      h.sweep("config.learning_rate", [1e-3, 5e-4, 1e-4, 5e-5]),
  ]
  return h.product(sweeps)
65