# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for working with slim argscopes."""

import tensorflow.compat.v1 as tf
import tf_slim as slim

import network_params


def get_conv_scope(params, is_training=True):
  """Constructs an argscope for configuring CNNs.

  Note that the scope returned captures any existing scope from within which
  this function is called. The returned scope, however, is absolute and
  overrides any outside scope -- this implies that using it within a new scope
  renders the new scope redundant. Example:

    with slim.arg_scope(...) as existing_sc:
      sc = get_conv_scope(...)  # `sc` captures `existing_sc`.

    with slim.arg_scope(...) as new_sc:
      with slim.arg_scope(sc):
        ...  # This context does NOT capture `new_sc`; `sc` is absolute.

      # Correct way to capture `new_sc`: call from within the scope.
      new_conv_sc = get_conv_scope(...)
      with slim.arg_scope(new_conv_sc):
        ...  # `new_conv_sc` captures `new_sc`.

  Args:
    params: `ParameterContainer` containing the model params.
    is_training: whether the model is meant to be trained or not.

  Returns:
    sc: a `slim.arg_scope` that sets the context for convolutional layers based
      on `params` and the context from which `get_conv_scope` is called. Note
      that using `sc` overrides any outside `arg_scope`; see the docstring
      above for details.
  """
  sc_gen = _get_base_scope(float(params.l2_weight_decay))
  with sc_gen:
    sc = slim.current_arg_scope()
  if params.batch_norm:
    batch_norm_sc_gen = _get_batch_norm_scope(
        is_training, decay=params.batch_norm_decay)
    sc = _update_arg_scope(sc, batch_norm_sc_gen)
  if params.dropout:
    dropout_sc_gen = _get_dropout_scope(
        is_training, keep_prob=params.dropout_keep_prob)
    sc = _update_arg_scope(sc, dropout_sc_gen)
  return sc


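# A minimal usage sketch (hypothetical values; assumes `params` is a
# `network_params.ParameterContainer` exposing the `l2_weight_decay`,
# `batch_norm`, `batch_norm_decay`, `dropout`, and `dropout_keep_prob` fields
# read by `get_conv_scope` above):
#
#   params = network_params.ParameterContainer(...)
#   with slim.arg_scope(get_conv_scope(params, is_training=True)):
#     net = slim.conv2d(images, 64, [3, 3], scope='conv1')
#     net = slim.conv2d(net, 128, [3, 3], scope='conv2')

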
def _update_arg_scope(base_sc, override_sc_gen):
  """Overrides kwargs for ops in `base_sc` with those from `override_sc_gen`.

  Args:
    base_sc: base `arg_scope` containing ops mapped to their kwargs.
    override_sc_gen: a `slim.arg_scope` generator whose `arg_scope` will
      override the base scope.

  Returns:
    A new `arg_scope` that overrides `base_sc` using overrides generated from
      `override_sc_gen`.
  """
  with slim.arg_scope(base_sc):
    with override_sc_gen:
      return slim.current_arg_scope()


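# A minimal sketch of the override semantics used above: nesting `arg_scope`s
# merges the per-op kwargs, with the inner scope winning on conflicts
# (hypothetical values):
#
#   with slim.arg_scope([slim.conv2d], padding='SAME', stride=1):
#     with slim.arg_scope([slim.conv2d], stride=2):
#       sc = slim.current_arg_scope()
#   # `sc` now maps slim.conv2d to {'padding': 'SAME', 'stride': 2}.

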
def _get_base_scope(weight_decay=0.00004):
  """Defines the default arg scope.

  Args:
    weight_decay: The weight decay to use for regularizing the conv weights.

  Returns:
    A `slim.arg_scope` generator.
  """
  base_scope_args = _get_base_scope_args(weight_decay)
  sc_gen = slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose],
      **base_scope_args)
  return sc_gen


def _get_batch_norm_scope(is_training,
                          decay=0.9997,
                          init_stddev=0.1,
                          batch_norm_var_collection='moving_vars'):
  """Defines an arg scope for configuring batch_norm in conv2d layers.

  Args:
    is_training: Whether or not we're training the model.
    decay: Decay factor for the moving averages used for eval.
    init_stddev: The standard deviation of the truncated normal weight init.
    batch_norm_var_collection: The name of the collection for the batch norm
      variables.

  Returns:
    An `arg_scope` generator that enables batch_norm normalization in conv2d
      layers.
  """
  batch_norm_scope_args = _get_batch_norm_scope_args(is_training, decay,
                                                     init_stddev,
                                                     batch_norm_var_collection)
  sc_gen = slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose],
      **batch_norm_scope_args)
  return sc_gen


def _get_dropout_scope(is_training, keep_prob=0.8):
  """Defines an arg scope for configuring dropout after slim.conv2d layers.

  Args:
    is_training: Whether or not we're training the model.
    keep_prob: The probability that each element is kept.

  Returns:
    An `arg_scope` generator that applies dropout after slim.conv2d and
      slim.separable_conv2d layers.
  """
  dropout_scope_args = _get_dropout_scope_args(is_training, keep_prob)
  sc_gen = slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                          **dropout_scope_args)
  return sc_gen


def _get_base_scope_args(weight_decay):
  """Returns arguments needed to initialize the base `arg_scope`."""
  regularizer = slim.l2_regularizer(weight_decay)
  conv_weights_init = slim.xavier_initializer_conv2d()
  base_scope_args = {
      'weights_initializer': conv_weights_init,
      'activation_fn': tf.nn.relu,
      'weights_regularizer': regularizer,
  }
  return base_scope_args


def _get_batch_norm_scope_args(is_training, decay, init_stddev,
                               batch_norm_var_collection):
  """Returns arguments needed to initialize the batch norm `arg_scope`."""
  batch_norm_params = {
      'is_training': is_training,
      # Decay for the moving averages.
      'decay': decay,
      # Epsilon to prevent 0s in variance.
      'epsilon': 0.001,
      # Collection containing the moving mean and moving variance.
      'variables_collections': {
          'beta': None,
          'gamma': None,
          'moving_mean': [batch_norm_var_collection],
          'moving_variance': [batch_norm_var_collection],
      },
      'zero_debias_moving_mean': False,
  }
  batch_norm_scope_args = {
      'normalizer_fn': slim.batch_norm,
      'normalizer_params': batch_norm_params,
      'weights_initializer':
          tf.truncated_normal_initializer(stddev=init_stddev),
  }
  return batch_norm_scope_args


def _get_dropout_scope_args(is_training, keep_prob):
  """Returns arguments needed to initialize the dropout `arg_scope`."""
  dropout_scope_args = {
      'activation_fn': _get_relu_then_dropout(is_training, keep_prob),
  }
  return dropout_scope_args


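# Note: the `activation_fn` entry above overrides the `tf.nn.relu` installed
# by `_get_base_scope_args`, so `relu_then_dropout` below re-applies the ReLU
# itself before adding dropout.

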
def _get_relu_then_dropout(is_training, keep_prob):
  """Returns an activation_fn that applies ReLU followed by dropout."""

  def relu_then_dropout(x):
    x = tf.nn.relu(x)
    return slim.dropout(x, is_training=is_training, keep_prob=keep_prob)

  return relu_then_dropout