# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MobileNetV1 model and builder functions. Only for inference."""
import functools
import tensorflow.compat.v1 as tf

from sgk.mbv1 import layers

MOVING_AVERAGE_DECAY = 0.9
EPSILON = 1e-5

26def _make_divisible(v, divisor=8, min_value=None):
27if min_value is None:
28min_value = divisor
29new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
30# Make sure that round down does not go down by more than 10%.
31if new_v < 0.9 * v:
32new_v += divisor
33return new_v
34
35
def batch_norm_relu(x, fuse_batch_norm=False):
  """Applies inference-mode batch normalization followed by ReLU.

  When `fuse_batch_norm` is True, the input is returned unchanged — the
  caller is expected to have folded the normalization (and activation)
  into the preceding convolution.

  Args:
    x: Activation tensor; normalization is applied over axis 1 (channels
      in [batch, channels, height, width] layout).
    fuse_batch_norm: If True, skip the separate batch-norm + ReLU.

  Returns:
    The normalized-and-activated tensor, or `x` itself when fused.
  """
  if fuse_batch_norm:
    return x
  normalized = tf.layers.batch_normalization(
      inputs=x,
      axis=1,
      momentum=MOVING_AVERAGE_DECAY,
      epsilon=EPSILON,
      center=True,
      scale=True,
      training=False,  # Inference only: use the moving statistics.
      fused=True)
  return tf.nn.relu(normalized)


def mbv1_block_(inputs, filters, stride, block_id, cfg):
  """Standard building block for MobileNetV1 networks.

  Args:
    inputs: Input tensor, float32 of size [batch, channels, height, width].
    filters: Int specifying number of filters for the 1x1 contraction
      (before width scaling).
    stride: Int specifying the stride. If stride > 1, the input is
      downsampled.
    block_id: Index of this block within the network.
    cfg: Configuration for the model.

  Returns:
    The output activation tensor.
  """
  # 3x3 depthwise convolution. When batch norm is fused (cfg.fuse_bnbr),
  # the bias and ReLU live inside the convolution itself and
  # batch_norm_relu becomes a pass-through.
  depthwise_layer = layers.DepthwiseConv2D(
      kernel_size=3,
      strides=[1, 1, stride, stride],
      padding=[0, 0, 1, 1],
      activation=tf.nn.relu if cfg.fuse_bnbr else None,
      use_bias=cfg.fuse_bnbr,
      name='depthwise_nxn_%s' % block_id)
  depthwise_out = batch_norm_relu(depthwise_layer(inputs), cfg.fuse_bnbr)

  # Width-scaled channel count for the 1x1 contraction. Block 0 keeps the
  # exact scaled value (divisor=1); later blocks round to a multiple of 8.
  scaled_filters = _make_divisible(
      int(cfg.width * filters), divisor=1 if block_id == 0 else 8)

  # Dense or sparse 1x1 convolution, selected per-block by the config.
  if cfg.block_config[block_id] == 'sparse':
    conv_fn = functools.partial(
        layers.SparseConv2D, nonzeros=cfg.block_nonzeros[block_id])
  else:
    conv_fn = layers.Conv2D

  contraction_layer = conv_fn(
      scaled_filters,
      activation=tf.nn.relu if cfg.fuse_bnbr else None,
      use_bias=cfg.fuse_bnbr,
      name='contraction_1x1_%s' % block_id)

  # 1x1 contraction followed by batch norm and ReLU (unless fused).
  return batch_norm_relu(contraction_layer(depthwise_out), cfg.fuse_bnbr)


def mobilenet_generator(cfg):
  """Generator for MobileNetV1 models.

  Args:
    cfg: Configuration for the model.

  Returns:
    Model `function` that takes in `inputs` and returns the output `Tensor`
    of the model.
  """

  def model(inputs):
    """Creation of the model graph.

    Args:
      inputs: Input tensor, float32 of size [batch, channels, height, width].

    Returns:
      Logits tensor of shape [batch, cfg.num_classes].
    """
    with tf.variable_scope('mobilenet_model', reuse=tf.AUTO_REUSE):
      # Initial convolutional layer.
      initial_conv_filters = _make_divisible(32 * cfg.width)
      # NOTE(review): unlike the blocks below, bias and ReLU here are enabled
      # unconditionally rather than gated on cfg.fuse_bnbr — when fusion is
      # off this applies ReLU both inside the conv and after batch norm.
      # Confirm this is intentional.
      initial_conv = layers.Conv2D(
          filters=initial_conv_filters,
          kernel_size=3,
          stride=2,
          padding=[[0, 0], [0, 0], [1, 1], [1, 1]],
          activation=tf.nn.relu,
          use_bias=True,
          name='initial_conv')
      inputs = batch_norm_relu(initial_conv(inputs), cfg.fuse_bnbr)

      # Bind the shared config so each block call only varies in
      # filters/stride/block_id.
      mb_block = functools.partial(mbv1_block_, cfg=cfg)

      # Core MobileNetV1 blocks (standard 13-block layout).
      inputs = mb_block(inputs, filters=64, stride=1, block_id=0)
      inputs = mb_block(inputs, filters=128, stride=2, block_id=1)
      inputs = mb_block(inputs, filters=128, stride=1, block_id=2)
      inputs = mb_block(inputs, filters=256, stride=2, block_id=3)
      inputs = mb_block(inputs, filters=256, stride=1, block_id=4)
      inputs = mb_block(inputs, filters=512, stride=2, block_id=5)
      inputs = mb_block(inputs, filters=512, stride=1, block_id=6)
      inputs = mb_block(inputs, filters=512, stride=1, block_id=7)
      inputs = mb_block(inputs, filters=512, stride=1, block_id=8)
      inputs = mb_block(inputs, filters=512, stride=1, block_id=9)
      inputs = mb_block(inputs, filters=512, stride=1, block_id=10)
      inputs = mb_block(inputs, filters=1024, stride=2, block_id=11)
      inputs = mb_block(inputs, filters=1024, stride=1, block_id=12)

      # Global average pooling: pool over the full spatial extent
      # (data is channels-first, so dims 2 and 3 are height/width).
      inputs = tf.layers.average_pooling2d(
          inputs=inputs,
          pool_size=(inputs.shape[2], inputs.shape[3]),
          strides=1,
          padding='VALID',
          data_format='channels_first',
          name='final_avg_pool')

      # Reshape the output of the pooling layer to 2D for the
      # final fully-connected layer.
      last_block_filters = _make_divisible(int(1024 * cfg.width), 8)
      inputs = tf.reshape(inputs, [-1, last_block_filters])

      # Final fully-connected layer producing the logits.
      inputs = tf.layers.dense(
          inputs=inputs,
          units=cfg.num_classes,
          activation=None,
          use_bias=True,
          name='final_dense')
      return inputs

  return model


def build_model(features, cfg):
  """Builds the MobileNetV1 model and returns the output logits.

  Args:
    features: Input features tensor for the model.
    cfg: Configuration for the model.

  Returns:
    Computed logits from the model.
  """
  # Construct the model function for this config and apply it immediately.
  return mobilenet_generator(cfg)(features)