google-research

Форк
0
/
mobilenet_builder.py 
177 строк · 5.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""MobileNetV1 model and builder functions. Only for inference."""
17
import functools
18
import tensorflow.compat.v1 as tf
19

20
from sgk.mbv1 import layers
21

22
# Momentum passed to batch normalization for its moving mean/variance.
MOVING_AVERAGE_DECAY = 0.9
# Epsilon passed to batch normalization (added to the variance for stability).
EPSILON = 1e-5
24

25

26
def _make_divisible(v, divisor=8, min_value=None):
27
  if min_value is None:
28
    min_value = divisor
29
  new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
30
  # Make sure that round down does not go down by more than 10%.
31
  if new_v < 0.9 * v:
32
    new_v += divisor
33
  return new_v
34

35

36
def batch_norm_relu(x, fuse_batch_norm=False):
  """Applies inference-mode batch normalization followed by ReLU.

  Args:
    x: Input tensor. Normalization runs over axis 1, the channel axis of
      the channels-first layout used elsewhere in this module.
    fuse_batch_norm: If True, `x` is returned unchanged. Callers in this
      module then give the preceding conv its own bias and ReLU activation
      instead of a separate batch-norm op.

  Returns:
    `relu(batch_norm(x))`, or `x` itself when `fuse_batch_norm` is True.
  """
  if fuse_batch_norm:
    return x
  normalized = tf.layers.batch_normalization(
      inputs=x,
      axis=1,
      momentum=MOVING_AVERAGE_DECAY,
      epsilon=EPSILON,
      center=True,
      scale=True,
      # Always use the stored moving statistics; this module is inference-only.
      training=False,
      fused=True)
  return tf.nn.relu(normalized)
50

51

52
def mbv1_block_(inputs, filters, stride, block_id, cfg):
  """Standard MobileNetV1 block: depthwise conv, then pointwise 1x1 conv.

  Each convolution is followed by batch norm + ReLU. The exception is when
  `cfg.fuse_bnbr` is set. Each conv then carries its own bias and ReLU
  activation, and `batch_norm_relu` passes its input through unchanged.

  Args:
    inputs: Input tensor, float32 of size [batch, channels, height, width].
    filters: Int number of output filters of the 1x1 convolution, before
      scaling by `cfg.width`.
    stride: Int stride of the depthwise convolution. If stride > 1, the
      input is spatially downsampled.
    block_id: Int index of this block. It is used in the layer names and to
      index `cfg.block_config` / `cfg.block_nonzeros`.
    cfg: Model configuration. Fields read here: `fuse_bnbr`, `width`,
      `block_config` and, for sparse blocks, `block_nonzeros`.

  Returns:
    The output activation tensor.
  """
  fused = cfg.fuse_bnbr
  activation = tf.nn.relu if fused else None

  # Depthwise convolution (kernel_size=3). Stride and padding are applied
  # only to the last two axes, i.e. height/width of the [N, C, H, W] input.
  depthwise_conv = layers.DepthwiseConv2D(
      kernel_size=3,
      strides=[1, 1, stride, stride],
      padding=[0, 0, 1, 1],
      activation=activation,
      use_bias=fused,
      name='depthwise_nxn_%s' % block_id)
  depthwise_out = batch_norm_relu(depthwise_conv(inputs), fused)

  # Pointwise 1x1 convolution. Block 0 keeps the exact width-scaled filter
  # count (divisor=1); later blocks round it to a multiple of 8.
  out_filters = _make_divisible(
      int(cfg.width * filters), divisor=1 if block_id == 0 else 8)

  # Blocks configured as 'sparse' use SparseConv2D with the configured
  # number of nonzeros; all other blocks use a dense Conv2D.
  conv_fn = layers.Conv2D
  if cfg.block_config[block_id] == 'sparse':
    conv_fn = functools.partial(
        layers.SparseConv2D, nonzeros=cfg.block_nonzeros[block_id])

  contraction = conv_fn(
      out_filters,
      activation=activation,
      use_bias=fused,
      name='contraction_1x1_%s' % block_id)

  # Run the 1x1 convolution followed by batch norm and relu.
  return batch_norm_relu(contraction(depthwise_out), fused)
95

96

97
def mobilenet_generator(cfg):
  """Generator for MobileNetV1 models.

  Args:
    cfg: Configuration for the model. Fields read here: `width`,
      `fuse_bnbr` and `num_classes`. `mbv1_block_` reads further fields.

  Returns:
    Model `function` that takes in `inputs` and returns the output `Tensor`
    of the model. The output is the result of a dense layer with no
    activation, of shape [batch, cfg.num_classes].
  """
  # (filters, stride) for each of the 13 MobileNetV1 blocks. The position
  # in this tuple is the block_id.
  block_specs = (
      (64, 1),
      (128, 2),
      (128, 1),
      (256, 2),
      (256, 1),
      (512, 2),
      (512, 1),
      (512, 1),
      (512, 1),
      (512, 1),
      (512, 1),
      (1024, 2),
      (1024, 1),
  )

  def model(inputs):
    """Creation of the model graph."""
    with tf.variable_scope('mobilenet_model', reuse=tf.AUTO_REUSE):
      # Initial convolutional layer. Like the convs inside every block, it
      # only carries its own bias + ReLU in fused mode. Otherwise the
      # following batch norm + ReLU provides them, giving
      # conv -> BN -> ReLU rather than conv -> ReLU -> BN -> ReLU.
      initial_conv_filters = _make_divisible(32 * cfg.width)
      initial_conv = layers.Conv2D(
          filters=initial_conv_filters,
          kernel_size=3,
          stride=2,
          padding=[[0, 0], [0, 0], [1, 1], [1, 1]],
          activation=tf.nn.relu if cfg.fuse_bnbr else None,
          use_bias=cfg.fuse_bnbr,
          name='initial_conv')
      inputs = batch_norm_relu(initial_conv(inputs), cfg.fuse_bnbr)

      mb_block = functools.partial(mbv1_block_, cfg=cfg)

      # Core MobileNetV1 blocks.
      for block_id, (filters, stride) in enumerate(block_specs):
        inputs = mb_block(
            inputs, filters=filters, stride=stride, block_id=block_id)

      # Average pooling over the full remaining spatial extent
      # (axes 2 and 3 in the channels-first layout).
      inputs = tf.layers.average_pooling2d(
          inputs=inputs,
          pool_size=(inputs.shape[2], inputs.shape[3]),
          strides=1,
          padding='VALID',
          data_format='channels_first',
          name='final_avg_pool')

      # Reshape the output of the pooling layer to 2D for the final
      # fully-connected layer. The channel count is the width-scaled filter
      # count of the last block, rounded with divisor 8 like every block
      # after block 0.
      last_block_filters = _make_divisible(
          int(block_specs[-1][0] * cfg.width), 8)
      inputs = tf.reshape(inputs, [-1, last_block_filters])

      # Final fully-connected layer.
      inputs = tf.layers.dense(
          inputs=inputs,
          units=cfg.num_classes,
          activation=None,
          use_bias=True,
          name='final_dense')
    return inputs

  return model
164

165

166
def build_model(features, cfg):
  """Builds the MobileNetV1 model and returns the output logits.

  Args:
    features: Input features tensor for the model. It is expected to be
      channels-first ([batch, channels, height, width]), matching the
      layout the blocks and the final pooling assume.
    cfg: Configuration for the model, forwarded to `mobilenet_generator`.

  Returns:
    Computed logits from the model.
  """
  model = mobilenet_generator(cfg)
  return model(features)
178

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.