# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Utils for working with slim argscopes."""
17
18import tensorflow.compat.v1 as tf19import tf_slim as slim20
21import network_params22
23
def get_conv_scope(params, is_training=True):
  """Constructs an argscope for configuring CNNs.

  Note that the scope returned captures any existing scope from within which
  this function is called. The returned scope, however, is absolute and
  overrides any outside scope -- this implies that using it within a new
  scope renders the new scope redundant. Example:

  with slim.arg_scope(...) as existing_sc:
    sc = get_conv_scope(...)  # `sc` captures `existing_sc`.

  with slim.arg_scope(...) as new_sc:
    with slim.arg_scope(sc):
      ...  # This context does NOT capture `new_sc`; `sc` is absolute.

    # Correct way to capture `new_sc`: call from within the scope.
    new_conv_sc = get_conv_scope(...)
    with slim.arg_scope(new_conv_sc):
      ...  # `new_conv_sc` captures `new_sc`.

  Args:
    params: `ParameterContainer` containing the model params.
    is_training: whether the model is meant to be trained or not.

  Returns:
    sc: a `slim.arg_scope` that sets the context for convolutional layers
      based on `params` and the context from which `get_conv_scope` is
      called. Note that using `sc` overrides any outside `arg_scope`; see the
      note above for details.
  """
  sc_gen = _get_base_scope(float(params.l2_weight_decay))
  with sc_gen:
    sc = slim.current_arg_scope()
  if params.batch_norm:
    batch_norm_sc_gen = _get_batch_norm_scope(
        is_training, decay=params.batch_norm_decay)
    sc = _update_arg_scope(sc, batch_norm_sc_gen)
  if params.dropout:
    dropout_sc_gen = _get_dropout_scope(
        is_training, keep_prob=params.dropout_keep_prob)
    sc = _update_arg_scope(sc, dropout_sc_gen)
  return sc

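# Usage sketch (illustrative; not part of the original module). Any object
# exposing the attributes read above (`l2_weight_decay`, `batch_norm`,
# `batch_norm_decay`, `dropout`, `dropout_keep_prob`) can stand in for the
# repo's `ParameterContainer`, e.g. a `types.SimpleNamespace`:
#
#   params = types.SimpleNamespace(
#       l2_weight_decay=4e-5, batch_norm=True, batch_norm_decay=0.9997,
#       dropout=False, dropout_keep_prob=0.8)
#   with slim.arg_scope(get_conv_scope(params, is_training=True)):
#     net = slim.conv2d(images, 32, [3, 3])  # Picks up init/reg/BN defaults.
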
def _update_arg_scope(base_sc, override_sc_gen):
  """Overrides kwargs for ops in `base_sc` with those from `override_sc_gen`.

  Args:
    base_sc: base `arg_scope` containing ops mapped to their kwargs.
    override_sc_gen: a `slim.arg_scope` generator whose `arg_scope` will
      override the base scope.

  Returns:
    A new `arg_scope` that overrides `base_sc` using overrides generated from
    `override_sc_gen`.
  """
  with slim.arg_scope(base_sc):
    with override_sc_gen:
      return slim.current_arg_scope()

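# Merge-semantics sketch (illustrative; not part of the original module):
# kwargs from the override scope win for overlapping keys, and the remaining
# kwargs from the base scope are kept.
#
#   with slim.arg_scope([slim.conv2d], padding='VALID', stride=2):
#     base_sc = slim.current_arg_scope()
#   override_gen = slim.arg_scope([slim.conv2d], stride=1)
#   merged = _update_arg_scope(base_sc, override_gen)
#   # `merged` configures conv2d with padding='VALID' (kept) and stride=1
#   # (overridden).
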
def _get_base_scope(weight_decay=0.00004):
  """Defines the default arg scope.

  Args:
    weight_decay: The weight decay to use for regularizing the conv weights.

  Returns:
    A `slim.arg_scope` generator.
  """
  base_scope_args = _get_base_scope_args(weight_decay)
  sc_gen = slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose],
      **base_scope_args)
  return sc_gen

def _get_batch_norm_scope(is_training,
                          decay=0.9997,
                          init_stddev=0.1,
                          batch_norm_var_collection='moving_vars'):
  """Defines an arg scope for configuring batch_norm in conv2d layers.

  Args:
    is_training: Whether or not we're training the model.
    decay: Decay factor for the moving averages used for eval.
    init_stddev: The standard deviation of the truncated normal weight init.
    batch_norm_var_collection: The name of the collection for the batch norm
      variables.

  Returns:
    An `arg_scope` generator that enables batch_norm in conv2d layers.
  """
  batch_norm_scope_args = _get_batch_norm_scope_args(is_training, decay,
                                                     init_stddev,
                                                     batch_norm_var_collection)
  sc_gen = slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose],
      **batch_norm_scope_args)
  return sc_gen

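# Effect sketch (illustrative; not part of the original module): with this
# scope active, `slim.conv2d` applies `slim.batch_norm` to its output before
# the activation (conv -> batch_norm -> relu), omits biases because a
# normalizer is set, and adds the moving mean/variance to the 'moving_vars'
# collection.
#
#   with _get_batch_norm_scope(is_training=True):
#     net = slim.conv2d(images, 16, [3, 3])
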
def _get_dropout_scope(is_training, keep_prob=0.8):
  """Defines an arg scope for configuring dropout after slim.conv2d layers.

  Args:
    is_training: Whether or not we're training the model.
    keep_prob: The probability that each element is kept.

  Returns:
    An `arg_scope` generator that applies dropout after slim.conv2d layers.
  """
  dropout_scope_args = _get_dropout_scope_args(is_training, keep_prob)
  sc_gen = slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                          **dropout_scope_args)
  return sc_gen

def _get_base_scope_args(weight_decay):
  """Returns arguments needed to initialize the base `arg_scope`."""
  regularizer = slim.l2_regularizer(weight_decay)
  conv_weights_init = slim.xavier_initializer_conv2d()
  base_scope_args = {
      'weights_initializer': conv_weights_init,
      'activation_fn': tf.nn.relu,
      'weights_regularizer': regularizer,
  }
  return base_scope_args

def _get_batch_norm_scope_args(is_training, decay, init_stddev,
                               batch_norm_var_collection):
  """Returns arguments needed to initialize the batch norm `arg_scope`."""
  batch_norm_params = {
      'is_training': is_training,
      # Decay for the moving averages.
      'decay': decay,
      # Epsilon to prevent 0s in variance.
      'epsilon': 0.001,
      # Collections containing the moving mean and moving variance.
      'variables_collections': {
          'beta': None,
          'gamma': None,
          'moving_mean': [batch_norm_var_collection],
          'moving_variance': [batch_norm_var_collection],
      },
      'zero_debias_moving_mean': False,
  }
  batch_norm_scope_args = {
      'normalizer_fn': slim.batch_norm,
      'normalizer_params': batch_norm_params,
      'weights_initializer':
          tf.truncated_normal_initializer(stddev=init_stddev),
  }
  return batch_norm_scope_args

def _get_dropout_scope_args(is_training, keep_prob):
  """Returns arguments needed to initialize the dropout `arg_scope`."""
  dropout_scope_args = {
      'activation_fn': _get_relu_then_dropout(is_training, keep_prob),
  }
  return dropout_scope_args

def _get_relu_then_dropout(is_training, keep_prob):
  """Returns an activation function that applies relu followed by dropout."""

  def relu_then_dropout(x):
    x = tf.nn.relu(x)
    return slim.dropout(x, is_training=is_training, keep_prob=keep_prob)

  return relu_then_dropout
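

# Behavior sketch (illustrative; not part of the original module): under the
# dropout scope, `relu_then_dropout` replaces conv2d's default activation, so
# each conv output passes through relu and then dropout. `slim.dropout` is the
# identity when `is_training` is False, so eval is unaffected.
#
#   with _get_dropout_scope(is_training=True, keep_prob=0.8):
#     net = slim.conv2d(images, 16, [3, 3])  # relu + dropout on the output.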