google-research
61 lines · 2.0 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataclasses for network parameter containers."""

import dataclasses

@dataclasses.dataclass
class ConvScopeParams:
  """Parameters for tf.slim arg_scope for conv layers."""

  # Whether to use dropout for conv layers.
  dropout: bool = False

  # Dropout keep probability for conv layers: the fraction of activations
  # kept (so a lower value means stronger regularization).
  dropout_keep_prob: float = 0.8

  # Whether to use batch_norm for conv layers.
  batch_norm: bool = True

  # Decay factor for batch_norm in conv layers.
  batch_norm_decay: float = 0.99

  # L2 regularization strength on conv and clf weights.
  l2_weight_decay: float = 0.00004


@dataclasses.dataclass
class InceptionV3FCNParams:
  """Parameters for configuring an InceptionV3FCN network."""

  # The receptive field size used by the network. Currently only two sizes are
  # supported: 911 and 129.
  receptive_field_size: int = 911

  # Prelogit dropout keep probability (fraction of activations kept).
  prelogit_dropout_keep_prob: float = 0.8

  # Scale number of filters in each Inception(V3) layer by this factor.
  # Minimum number of filters defaults to 16.
  depth_multiplier: float = 0.1

  # Minimum depth for the conv layers. Relevant only when depth_multiplier < 1.
  min_depth: int = 16

  # Stride used in inference. This stride should be a multiple of 16 as
  # InceptionV3 downsamples by 16 by the time it reaches its logits layer. If
  # set to 0, non-FCN mode is assumed and the output is squeezed from
  # (?, 1, 1, classes) to (?, classes).
  inception_fcn_stride: int = 0