1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# Copyright 2018 Google Inc. All Rights Reserved.
17#
18# Licensed under the Apache License, Version 2.0 (the "License");
19# you may not use this file except in compliance with the License.
20# You may obtain a copy of the License at
21#
22# http://www.apache.org/licenses/LICENSE-2.0
23#
24# Unless required by applicable law or agreed to in writing, software
25# distributed under the License is distributed on an "AS IS" BASIS,
26# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27# See the License for the specific language governing permissions and
28# limitations under the License.
29# ==============================================================================
30
31"""ResNet-32 model.
32
33Related papers:
34https://arxiv.org/pdf/1603.05027v2.pdf
35https://arxiv.org/pdf/1512.03385v1.pdf
36https://arxiv.org/pdf/1605.07146v1.pdf
37"""
38from collections import namedtuple # pylint: disable=g-importing-member
39
40import numpy as np
41
42import tensorflow as tf
43from tensorflow.python.training import moving_averages
44
45
# Immutable bundle of hyperparameters consumed by ResNet.
HParams = namedtuple('HParams', [
    'batch_size',
    'num_classes',
    'min_lrn_rate',
    'lrn_rate',
    'num_residual_units',
    'use_bottleneck',
    'weight_decay_rate',
    'relu_leakiness',
    'optimizer',
])
50
51
class ResNet(object):
  """ResNet model."""

  def __init__(self, hps, images, labels, mode):
    """ResNet constructor.

    Args:
      hps: Hyperparameters.
      images: Batches of images. [batch_size, image_size, image_size, 3]
      labels: Batches of labels. [batch_size, num_classes]
      mode: One of 'train' and 'eval'.
    """
    self.mode = mode
    self.hps = hps
    self.labels = labels
    self._images = images
    # Extra ops (e.g. batch-norm moving-average updates) collected while the
    # graph is built; grouped into the train op by _build_train_op.
    self.extra_train_ops = []
70
  def build_graph_unused(self):
    """Build a whole graph for the model.

    Creates the global step, builds the inference graph, adds the train op
    in 'train' mode, and merges all summaries.

    NOTE(review): _build_train_op reads self.cost, which build_model does
    not set — confirm a caller or subclass defines it before training.
    """
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    # Return value (logits) is intentionally discarded here.
    self.build_model()
    if self.mode == 'train':
      self._build_train_op()
    self.summaries = tf.summary.merge_all()
78
79def _stride_arr(self, stride):
80"""Map a stride scalar to the stride array for tf.nn.conv2d."""
81return [1, stride, stride, 1]
82
  def build_model(self):
    """Build the core model within the graph.

    Three residual stages of num_residual_units units each, preceded by a
    3x3 stem conv and followed by BN, ReLU, global average pooling and a
    fully connected classifier.

    Returns:
      Pre-softmax logits of shape [batch_size, num_classes].
    """
    with tf.variable_scope('init'):
      x = self._images
      # Stem: 3x3 conv, 3 input channels -> 64 filters, stride 1.
      x = self._conv('init_conv', x, 3, 3, 64, self._stride_arr(1))

    # Stride for the first unit of each stage; stages 2 and 3 downsample.
    strides = [1, 2, 2]
    activate_before_residual = [True, False, False]
    if self.hps.use_bottleneck:
      res_func = self._bottleneck_residual
      # NOTE(review): filters[0]=16 disagrees with the 64-channel stem conv
      # above, so unit_1_0 would receive a 64-channel input declared as 16 —
      # confirm the bottleneck path is actually exercised.
      filters = [16, 64, 128, 256]
    else:
      res_func = self._residual
      filters = [64, 128, 256, 512]

    with tf.variable_scope('unit_1_0'):
      x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]),
                   activate_before_residual[0])
    for i in range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_1_%d' % i):
        x = res_func(x, filters[1], filters[1], self._stride_arr(1), False)

    with tf.variable_scope('unit_2_0'):
      x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]),
                   activate_before_residual[1])
    for i in range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_2_%d' % i):
        x = res_func(x, filters[2], filters[2], self._stride_arr(1), False)

    with tf.variable_scope('unit_3_0'):
      x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]),
                   activate_before_residual[2])
    for i in range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_3_%d' % i):
        x = res_func(x, filters[3], filters[3], self._stride_arr(1), False)

    with tf.variable_scope('unit_last'):
      # Final pre-activation before pooling.
      x = self._batch_norm('final_bn', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._global_avg_pool(x)

    with tf.variable_scope('logit'):
      logits = self._fully_connected(x, self.hps.num_classes)

    return logits
128
  def _build_train_op(self):
    """Build training specific ops for the graph.

    Reads self.cost and self.global_step and produces self.train_op, which
    also runs the ops accumulated in self.extra_train_ops.

    NOTE(review): self.cost is never assigned in this file — confirm a
    caller or subclass sets it before this method runs.
    """
    self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
    tf.compat.v1.summary.scalar('learning rate', self.lrn_rate)

    trainable_variables = tf.trainable_variables()
    grads = tf.gradients(self.cost, trainable_variables)

    # NOTE(review): an optimizer value other than 'sgd'/'mom' leaves
    # `optimizer` unbound and raises NameError below — confirm hps are
    # validated upstream.
    if self.hps.optimizer == 'sgd':
      optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate)
    elif self.hps.optimizer == 'mom':
      optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9)

    apply_op = optimizer.apply_gradients(
        zip(grads, trainable_variables),
        global_step=self.global_step, name='train_step')

    # Group the gradient step with the batch-norm moving-average updates.
    train_ops = [apply_op] + self.extra_train_ops
    self.train_op = tf.group(*train_ops)
148
  def _batch_norm(self, name, x):
    """Batch normalization.

    In 'train' mode normalizes with the current batch moments and appends
    moving-average update ops (decay 0.9) to self.extra_train_ops; in any
    other mode normalizes with the stored moving averages.

    Args:
      name: Variable scope name for the BN parameters.
      x: 4-D input tensor [batch, height, width, channels].

    Returns:
      Normalized tensor with the same static shape as x.
    """
    with tf.variable_scope(name):
      # One scale/offset per channel.
      params_shape = [x.get_shape()[-1]]

      beta = tf.get_variable(
          'beta', params_shape, tf.float32,
          initializer=tf.constant_initializer(0.0, tf.float32))
      gamma = tf.get_variable(
          'gamma', params_shape, tf.float32,
          initializer=tf.constant_initializer(1.0, tf.float32))

      if self.mode == 'train':
        # Moments over batch, height and width; one value per channel.
        mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')

        moving_mean = tf.get_variable(
            'moving_mean', params_shape, tf.float32,
            initializer=tf.constant_initializer(0.0, tf.float32),
            trainable=False)
        moving_variance = tf.get_variable(
            'moving_variance', params_shape, tf.float32,
            initializer=tf.constant_initializer(1.0, tf.float32),
            trainable=False)

        # These updates only run when extra_train_ops is grouped into the
        # train op (see _build_train_op).
        self.extra_train_ops.append(moving_averages.assign_moving_average(
            moving_mean, mean, 0.9))
        self.extra_train_ops.append(moving_averages.assign_moving_average(
            moving_variance, variance, 0.9, zero_debias=False))
      else:
        # Eval: reuse the same variable names so checkpoints line up.
        mean = tf.get_variable(
            'moving_mean', params_shape, tf.float32,
            initializer=tf.constant_initializer(0.0, tf.float32),
            trainable=False)
        variance = tf.get_variable(
            'moving_variance', params_shape, tf.float32,
            initializer=tf.constant_initializer(1.0, tf.float32),
            trainable=False)
      # epsilon = 0.001 guards against division by a tiny variance.
      y = tf.nn.batch_normalization(
          x, mean, variance, beta, gamma, 0.001)
      y.set_shape(x.get_shape())
      return y
190
  def _residual(self, x, in_filter, out_filter, stride,
                activate_before_residual=False):
    """Residual unit with 2 sub layers.

    Args:
      x: Input tensor [batch, height, width, in_filter].
      in_filter: Number of input channels.
      out_filter: Number of output channels.
      stride: 4-D stride array for the first conv (and the shortcut pool).
      activate_before_residual: If True, the pre-activation is shared by
        both the residual and the shortcut path.

    Returns:
      Output tensor with out_filter channels.
    """
    if activate_before_residual:
      with tf.variable_scope('shared_activation'):
        # NOTE(review): relu precedes batch norm here, the reverse of the
        # other branch — confirm this ordering is intended.
        x = self._relu(x, self.hps.relu_leakiness)
        x = self._batch_norm('init_bn', x)
        orig_x = x
    else:
      with tf.variable_scope('residual_only_activation'):
        orig_x = x
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)

    with tf.variable_scope('sub1'):
      x = self._conv('conv1', x, 3, in_filter, out_filter, stride)

    with tf.variable_scope('sub2'):
      x = self._batch_norm('bn2', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1])

    with tf.variable_scope('sub_add'):
      # Parameter-free shortcut: average-pool to match the spatial size,
      # then zero-pad the channel dimension up to out_filter.
      if in_filter != out_filter:
        orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID')
        orig_x = tf.pad(
            orig_x, [[0, 0], [0, 0], [0, 0],
                     [(out_filter-in_filter)//2, (out_filter-in_filter)//2]])
      x += orig_x

    tf.logging.info('image after unit %s', x.get_shape())
    return x
223
224def _bottleneck_residual(self, x, in_filter, out_filter, stride,
225activate_before_residual=False):
226"""Bottleneck residual unit with 3 sub layers."""
227if activate_before_residual:
228with tf.variable_scope('common_bn_relu'):
229x = self._batch_norm('init_bn', x)
230x = self._relu(x, self.hps.relu_leakiness)
231orig_x = x
232else:
233with tf.variable_scope('residual_bn_relu'):
234orig_x = x
235x = self._batch_norm('init_bn', x)
236x = self._relu(x, self.hps.relu_leakiness)
237
238with tf.variable_scope('sub1'):
239x = self._conv('conv1', x, 1, in_filter, out_filter/4, stride)
240
241with tf.variable_scope('sub2'):
242x = self._batch_norm('bn2', x)
243x = self._relu(x, self.hps.relu_leakiness)
244x = self._conv('conv2', x, 3, out_filter/4, out_filter/4, [1, 1, 1, 1])
245
246with tf.variable_scope('sub3'):
247x = self._batch_norm('bn3', x)
248x = self._relu(x, self.hps.relu_leakiness)
249x = self._conv('conv3', x, 1, out_filter/4, out_filter, [1, 1, 1, 1])
250
251with tf.variable_scope('sub_add'):
252if in_filter != out_filter:
253orig_x = self._conv('project', orig_x, 1, in_filter, out_filter, stride)
254x += orig_x
255
256tf.logging.info('image after unit %s', x.get_shape())
257return x
258
259def decay(self):
260"""L2 weight decay loss."""
261costs = []
262for var in tf.trainable_variables():
263if var.op.name.find(r'DW') > 0:
264costs.append(tf.nn.l2_loss(var))
265
266return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))
267
268def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
269"""Convolution."""
270with tf.variable_scope(name):
271n = filter_size * filter_size * out_filters
272kernel = tf.get_variable(
273'DW', [filter_size, filter_size, in_filters, out_filters],
274tf.float32, initializer=tf.random_normal_initializer(
275stddev=np.sqrt(2.0/n)))
276return tf.nn.conv2d(x, kernel, strides, padding='SAME')
277
278def _relu(self, x, leakiness=0.0):
279"""Relu, with optional leaky support."""
280return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
281
  def _fully_connected(self, x, out_dim):
    """FullyConnected layer for final output.

    Args:
      x: Input tensor; flattened to [batch_size, -1] first.
      out_dim: Number of output units.

    Returns:
      [batch_size, out_dim] tensor computed as x @ w + b.
    """
    x = tf.reshape(x, [self.hps.batch_size, -1])
    w = tf.get_variable(
        'DW', [x.get_shape()[1], out_dim],
        initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
    # Biases start at zero (constant_initializer default).
    b = tf.get_variable('biases', [out_dim],
                        initializer=tf.constant_initializer())
    return tf.nn.xw_plus_b(x, w, b)
291
  def _global_avg_pool(self, x):
    """Average over the spatial dimensions: [N, H, W, C] -> [N, C]."""
    assert x.get_shape().ndims == 4
    return tf.reduce_mean(x, [1, 2])
295