# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""ResNet-32 model.

Related papers:
https://arxiv.org/pdf/1603.05027v2.pdf
https://arxiv.org/pdf/1512.03385v1.pdf
https://arxiv.org/pdf/1605.07146v1.pdf
"""
from collections import namedtuple  # pylint: disable=g-importing-member

import numpy as np

import tensorflow as tf
from tensorflow.python.training import moving_averages


HParams = namedtuple('HParams',
                     'batch_size, num_classes, min_lrn_rate, lrn_rate, '
                     'num_residual_units, use_bottleneck, weight_decay_rate, '
                     'relu_leakiness, optimizer')
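# Field notes: batch_size and num_classes describe the inputs; lrn_rate and
# optimizer ('sgd' or 'mom') drive _build_train_op (min_lrn_rate is not
# referenced in this file); num_residual_units is the depth of each of the
# three unit groups; use_bottleneck selects _bottleneck_residual over
# _residual; weight_decay_rate scales decay(); relu_leakiness is the
# leaky-ReLU slope passed to _relu.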


class ResNet(object):
  """ResNet model."""

  def __init__(self, hps, images, labels, mode):
    """ResNet constructor.

    Args:
      hps: Hyperparameters.
      images: Batches of images. [batch_size, image_size, image_size, 3]
      labels: Batches of labels. [batch_size, num_classes]
      mode: One of 'train' and 'eval'.
    """
    self.hps = hps
    self._images = images
    self.labels = labels
    self.mode = mode

    self.extra_train_ops = []

  def build_graph_unused(self):
    """Build a whole graph for the model."""
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    self.build_model()
    if self.mode == 'train':
      self._build_train_op()
    self.summaries = tf.summary.merge_all()

  def _stride_arr(self, stride):
    """Map a stride scalar to the stride array for tf.nn.conv2d."""
    return [1, stride, stride, 1]

  def build_model(self):
    """Build the core model within the graph.

    Returns:
      Logits tensor of shape [batch_size, num_classes].
    """
    with tf.variable_scope('init'):
      x = self._images
      x = self._conv('init_conv', x, 3, 3, 64, self._stride_arr(1))

    strides = [1, 2, 2]
    activate_before_residual = [True, False, False]
    if self.hps.use_bottleneck:
      res_func = self._bottleneck_residual
      # filters[0] must equal init_conv's 64 output channels; the original
      # value of 16 would make the first unit's kernel mismatch x at build
      # time.
      filters = [64, 64, 128, 256]
    else:
      res_func = self._residual
      filters = [64, 128, 256, 512]

    with tf.variable_scope('unit_1_0'):
      x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]),
                   activate_before_residual[0])
    for i in range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_1_%d' % i):
        x = res_func(x, filters[1], filters[1], self._stride_arr(1), False)

    with tf.variable_scope('unit_2_0'):
      x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]),
                   activate_before_residual[1])
    for i in range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_2_%d' % i):
        x = res_func(x, filters[2], filters[2], self._stride_arr(1), False)

    with tf.variable_scope('unit_3_0'):
      x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]),
                   activate_before_residual[2])
    for i in range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_3_%d' % i):
        x = res_func(x, filters[3], filters[3], self._stride_arr(1), False)

    with tf.variable_scope('unit_last'):
      x = self._batch_norm('final_bn', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._global_avg_pool(x)

    with tf.variable_scope('logit'):
      logits = self._fully_connected(x, self.hps.num_classes)

    return logits

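  # NOTE: _build_train_op() below reads self.cost, which this file never
  # assigns; callers are expected to define it from build_model()'s logits.
  # A minimal sketch (an assumption, not necessarily the repository's loss):
  #   xent = tf.nn.softmax_cross_entropy_with_logits_v2(
  #       labels=self.labels, logits=logits)
  #   self.cost = tf.reduce_mean(xent) + self.decay()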
  def _build_train_op(self):
    """Build training specific ops for the graph."""
    self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
    tf.compat.v1.summary.scalar('learning rate', self.lrn_rate)

    trainable_variables = tf.trainable_variables()
    grads = tf.gradients(self.cost, trainable_variables)

    if self.hps.optimizer == 'sgd':
      optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate)
    elif self.hps.optimizer == 'mom':
      optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9)
    else:
      # The original fell through with `optimizer` unbound, raising a
      # NameError below; fail explicitly instead.
      raise ValueError('Unknown optimizer: %s' % self.hps.optimizer)

    apply_op = optimizer.apply_gradients(
        zip(grads, trainable_variables),
        global_step=self.global_step, name='train_step')

    # extra_train_ops carries the batch-norm moving-average updates created in
    # _batch_norm(); grouping them here makes them run with each train step.
    train_ops = [apply_op] + self.extra_train_ops
    self.train_op = tf.group(*train_ops)

  def _batch_norm(self, name, x):
    """Batch normalization."""
    with tf.variable_scope(name):
      params_shape = [x.get_shape()[-1]]

      beta = tf.get_variable(
          'beta', params_shape, tf.float32,
          initializer=tf.constant_initializer(0.0, tf.float32))
      gamma = tf.get_variable(
          'gamma', params_shape, tf.float32,
          initializer=tf.constant_initializer(1.0, tf.float32))

      if self.mode == 'train':
        mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')

        moving_mean = tf.get_variable(
            'moving_mean', params_shape, tf.float32,
            initializer=tf.constant_initializer(0.0, tf.float32),
            trainable=False)
        moving_variance = tf.get_variable(
            'moving_variance', params_shape, tf.float32,
            initializer=tf.constant_initializer(1.0, tf.float32),
            trainable=False)

        self.extra_train_ops.append(moving_averages.assign_moving_average(
            moving_mean, mean, 0.9))
        self.extra_train_ops.append(moving_averages.assign_moving_average(
            moving_variance, variance, 0.9, zero_debias=False))
      else:
        mean = tf.get_variable(
            'moving_mean', params_shape, tf.float32,
            initializer=tf.constant_initializer(0.0, tf.float32),
            trainable=False)
        variance = tf.get_variable(
            'moving_variance', params_shape, tf.float32,
            initializer=tf.constant_initializer(1.0, tf.float32),
            trainable=False)
      y = tf.nn.batch_normalization(
          x, mean, variance, beta, gamma, 0.001)
      y.set_shape(x.get_shape())
      return y

  def _residual(self, x, in_filter, out_filter, stride,
                activate_before_residual=False):
    """Residual unit with 2 sub layers."""
    if activate_before_residual:
      with tf.variable_scope('shared_activation'):
        # Note: ReLU precedes BN in this branch, the reverse of the ordering
        # used in _bottleneck_residual below.
        x = self._relu(x, self.hps.relu_leakiness)
        x = self._batch_norm('init_bn', x)
        orig_x = x
    else:
      with tf.variable_scope('residual_only_activation'):
        orig_x = x
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)

    with tf.variable_scope('sub1'):
      x = self._conv('conv1', x, 3, in_filter, out_filter, stride)

    with tf.variable_scope('sub2'):
      x = self._batch_norm('bn2', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1])

    with tf.variable_scope('sub_add'):
      if in_filter != out_filter:
        # Parameter-free shortcut: average-pool spatially, then zero-pad the
        # channel dimension (the "option A" shortcut of the ResNet paper).
        orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID')
        orig_x = tf.pad(
            orig_x, [[0, 0], [0, 0], [0, 0],
                     [(out_filter-in_filter)//2, (out_filter-in_filter)//2]])
      x += orig_x

    tf.logging.info('image after unit %s', x.get_shape())
    return x

  def _bottleneck_residual(self, x, in_filter, out_filter, stride,
                           activate_before_residual=False):
    """Bottleneck residual unit with 3 sub layers."""
    if activate_before_residual:
      with tf.variable_scope('common_bn_relu'):
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)
        orig_x = x
    else:
      with tf.variable_scope('residual_bn_relu'):
        orig_x = x
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)

    with tf.variable_scope('sub1'):
      # Integer division keeps the kernel shape integral under Python 3; the
      # original `/` yields a float there and breaks tf.get_variable.
      x = self._conv('conv1', x, 1, in_filter, out_filter // 4, stride)

    with tf.variable_scope('sub2'):
      x = self._batch_norm('bn2', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._conv('conv2', x, 3, out_filter // 4, out_filter // 4,
                     [1, 1, 1, 1])

    with tf.variable_scope('sub3'):
      x = self._batch_norm('bn3', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._conv('conv3', x, 1, out_filter // 4, out_filter, [1, 1, 1, 1])

    with tf.variable_scope('sub_add'):
      if in_filter != out_filter:
        # Projection shortcut: a 1x1 convolution matches the channel count.
        orig_x = self._conv('project', orig_x, 1, in_filter, out_filter,
                            stride)
      x += orig_x

    tf.logging.info('image after unit %s', x.get_shape())
    return x

  def decay(self):
    """L2 weight decay loss."""
    costs = []
    for var in tf.trainable_variables():
      # Only convolution and FC kernels (variables named 'DW') are decayed;
      # biases and batch-norm parameters are excluded.
      if var.op.name.find(r'DW') > 0:
        costs.append(tf.nn.l2_loss(var))

    return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))

  def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
    """Convolution."""
    with tf.variable_scope(name):
      # He initialization: stddev sqrt(2/n), with n the fan-out of the kernel.
      n = filter_size * filter_size * out_filters
      kernel = tf.get_variable(
          'DW', [filter_size, filter_size, in_filters, out_filters],
          tf.float32, initializer=tf.random_normal_initializer(
              stddev=np.sqrt(2.0/n)))
      return tf.nn.conv2d(x, kernel, strides, padding='SAME')

  def _relu(self, x, leakiness=0.0):
    """Relu, with optional leaky support."""
    return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')

  def _fully_connected(self, x, out_dim):
    """FullyConnected layer for final output."""
    x = tf.reshape(x, [self.hps.batch_size, -1])
    w = tf.get_variable(
        'DW', [x.get_shape()[1], out_dim],
        initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
    b = tf.get_variable('biases', [out_dim],
                        initializer=tf.constant_initializer())
    return tf.nn.xw_plus_b(x, w, b)

  def _global_avg_pool(self, x):
    """Average over the spatial dimensions, yielding [batch, channels]."""
    assert x.get_shape().ndims == 4
    return tf.reduce_mean(x, [1, 2])
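
# Minimal usage sketch (assumes a TF1 runtime and hypothetical CIFAR-shaped
# inputs; hyperparameter values are illustrative, not tuned settings):
if __name__ == '__main__':
  hps = HParams(batch_size=4, num_classes=10, min_lrn_rate=0.0001,
                lrn_rate=0.1, num_residual_units=5, use_bottleneck=False,
                weight_decay_rate=0.0002, relu_leakiness=0.1, optimizer='mom')
  images = tf.random_uniform([hps.batch_size, 32, 32, 3])
  labels = tf.one_hot(tf.zeros([hps.batch_size], tf.int32), hps.num_classes)
  model = ResNet(hps, images, labels, 'train')
  logits = model.build_model()  # [batch_size, num_classes]
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(logits).shape)  # expected: (4, 10)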