google-research
90 строк · 2.9 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Configuration and hyperparameter sweeps."""
17# pylint: disable=line-too-long
18
19import ml_collections20
21
def get_config():
  """Get the default hyperparameter configuration.

  Builds a fresh `ml_collections.ConfigDict` on every call, so callers may
  mutate the result freely.

  Returns:
    A `ml_collections.ConfigDict` with the dataset/preprocessing, model,
    patch-selection, optimizer and logging settings for one training run.
  """
  config = ml_collections.ConfigDict()
  config.seed = 42
  # Dummy for repeated runs. Set exactly once; a second assignment before
  # `return` used to shadow this one.
  config.trial = 0

  config.dataset = "billiard-max-left-right-test"
  config.train_preprocess_str = "to_float_0_1|pad(ensure_small=(1100, 1100))|random_crop(resolution=(1000, 1000))"
  config.eval_preprocess_str = "to_float_0_1"

  config.append_position_to_input = False
  config.downsample_input_factor = 1

  # Top-k extraction.
  config.model = "patchnet"
  config.k = 10
  config.patch_size = 100
  config.downscale = 4
  config.feature_network = "ResNet18"

  config.aggregation_method = "transformer"
  config.aggregation_method_kwargs = ml_collections.ConfigDict()
  config.aggregation_method_kwargs.num_layers = 3
  config.aggregation_method_kwargs.num_heads = 8
  config.aggregation_method_kwargs.dim_hidden = 256
  config.aggregation_method_kwargs.pooling = "sum"

  # Training uses the perturbed (differentiable) top-k; inference uses a hard
  # top-k, as named by these two settings.
  config.selection_method = "perturbed-topk"
  config.selection_method_inference = "hard-topk"
  config.entropy_regularizer = -0.01
  config.entropy_before_normalization = True

  config.scorer_has_se = False

  # Parameters for sinkhorn topk. NOTE(review): presumably only read when a
  # sinkhorn selection method is chosen; the defaults above select perturbed
  # top-k.
  config.sinkhorn_topk_kwargs = ml_collections.ConfigDict()
  config.sinkhorn_topk_kwargs.epsilon = 1e-4
  config.sinkhorn_topk_kwargs.num_iterations = 200

  # Parameters for perturbed topk.
  config.perturbed_topk_kwargs = ml_collections.ConfigDict()
  config.perturbed_topk_kwargs.num_samples = 500
  config.perturbed_topk_kwargs.sigma = 0.05
  config.linear_decrease_perturbed_sigma = True

  config.normalization_str = "zerooneeps(1e-5)"
  config.use_iterative_extraction = True

  # Same setup as usual.
  config.optimizer = "adam"
  config.learning_rate = 1e-4
  config.gradient_value_clip = 1.
  config.momentum = .9
  config.weight_decay = 1e-4
  config.cosine_decay = True
  config.warmup_ratio = 0.05
  config.batch_size = 64
  config.num_train_steps = 30_000

  config.log_loss_every_steps = 100
  config.eval_every_steps = 1000
  config.checkpoint_every_steps = 5000
  config.log_images = True
  config.log_histograms = False
  config.do_eval_only = False

  return config
92
93
94