google-research
86 строк · 2.6 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Configuration for trafficsigns_topk."""
# pylint: disable=line-too-long
18
19import ml_collections
20
21
def get_config():
  """Get the default hyperparameter configuration for trafficsigns_topk.

  Builds a fresh `ml_collections.ConfigDict` on every call, so callers may
  mutate the result without affecting other callers.

  Returns:
    An `ml_collections.ConfigDict` with the dataset/preprocessing, top-k
    patch-extraction model, optimizer/schedule, and logging/checkpointing
    settings for this experiment.
  """
  config = ml_collections.ConfigDict()
  config.seed = 42
  # Set exactly once here; an earlier revision re-assigned it just before
  # `return`, which would silently override any edit made on this line.
  config.trial = 0  # Dummy for repeated runs.

  # Dataset and preprocessing pipelines ('|'-separated op strings).
  config.dataset = "trafficsigns"
  config.train_preprocess_str = "to_float_0_1|pad(ensure_small=(1160, 1480))|random_crop(resolution=(960, 1280))|random_linear_transform((.8,1.2),(-.1,.1),.8)"
  config.eval_preprocess_str = "to_float_0_1"

  # Top-k extraction.
  config.model = "patchnet"
  config.k = 5
  config.patch_size = 100
  config.downscale = 3

  config.append_position_to_input = False
  config.feature_network = "ats-traffic"

  config.aggregation_method = "meanpooling"

  config.scorer_has_se = False

  # NOTE(review): judging by the key names, one selection method is used in
  # training and the other at inference — confirm in the model code.
  config.selection_method = "perturbed-topk"
  config.selection_method_inference = "hard-topk"
  config.entropy_regularizer = -0.01
  config.entropy_before_normalization = True

  # Parameters for perturbed_topk.
  config.perturbed_topk_kwargs = ml_collections.ConfigDict()
  config.perturbed_topk_kwargs.num_samples = 500
  config.perturbed_topk_kwargs.sigma = 0.05
  config.linear_decrease_perturbed_sigma = True

  # Parameters for sinkhorn topk.
  # NOTE(review): presumably only read when a Sinkhorn selection method is
  # configured (not the case above) — confirm before tuning.
  config.sinkhorn_topk_kwargs = ml_collections.ConfigDict()
  config.sinkhorn_topk_kwargs.epsilon = 1e-4
  config.sinkhorn_topk_kwargs.num_iterations = 2000

  config.normalization_str = "zerooneeps(1e-5)"
  config.use_iterative_extraction = True

  # Optimization.
  config.optimizer = "adam"
  config.learning_rate = 1e-4
  config.gradient_value_clip = 0.1
  # NOTE(review): optimizer is "adam"; confirm whether `momentum` is read.
  config.momentum = .9

  config.weight_decay = 1e-4
  config.cosine_decay = True
  config.warmup_ratio = 0.1
  config.batch_size = 32
  config.num_train_steps = 70_000

  # Logging, evaluation and checkpoint cadence (in training steps).
  config.log_loss_every_steps = 50
  config.eval_every_steps = 1000
  config.checkpoint_every_steps = 5000

  config.log_images = True
  config.log_histograms = False
  config.skip_nan_updates = True
  config.do_eval_only = False

  return config
87
88
89