google-research
384 lines · 16.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""No padding inception FCN base network for a 911x911 receptive field.
17
18This is a variant of inception v3 FCN that takes a larger receptive field and
19predicts a larger patch size.
20"""
21import tensorflow.compat.v1 as tf
22import tf_slim as slim
23
24# The downsampling factor of the network.
25MODEL_DOWNSAMPLE_FACTOR = 2**4
26
27
28def _trim_border_px(inputs, n):
29"""Crop n pixels around the border of inputs.
30
31Args:
32inputs: a tensor of size [batch_size, height, width, channels].
33n: an integer for number of pixels to crop.
34
35Returns:
36cropped tensor.
37Raises:
38ValueError: if cropping leads to empty output tensor.
39"""
40if n > min(inputs.shape[1], inputs.shape[2]) // 2:
41raise ValueError(
42'n (%d) can not be greater than or equal to half of the input shape.' %
43n)
44return inputs[:, n:-n, n:-n, :]
45
46
def nopad_inception_v3_base_911(inputs,
                                min_depth=16,
                                depth_multiplier=1.0,
                                num_final_1x1_conv=0,
                                scope=None):
  """Constructs a no padding Inception v3 network from inputs.

  All convolutions and pools use VALID padding, so each layer shrinks the
  spatial extent; mismatched branch outputs are center-cropped with
  _trim_border_px before concatenation. The spatial-size comments below
  track a 911 x 911 input.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels]. Must be
      floating point. If a pretrained checkpoint is used, pixel values should
      be the same as during training.
    min_depth: Minimum depth value (number of channels) for all convolution
      ops. Enforced when depth_multiplier < 1, and not an active constraint
      when depth_multiplier >= 1.
    depth_multiplier: Float multiplier for the depth (number of channels) for
      all convolution ops. The value must be greater than zero. Typical usage
      will be to set this value in (0, 1) to reduce the number of parameters
      or computation cost of the model.
    num_final_1x1_conv: Int, number of final 1x1 conv layers.
    scope: Optional variable_scope.

  Returns:
    tensor_out: output tensor.
    end_points: a set of activations for external use, for example summaries
      or losses.

  Raises:
    ValueError: if depth_multiplier <= 0
  """
  # end_points will collect relevant activations for external use, for example
  # summaries or losses.
  end_points = {}

  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')
  depth = lambda d: max(int(d * depth_multiplier), min_depth)

  with tf.variable_scope(scope, 'NopadInceptionV3', [inputs]):
    # Stem: plain conv/pool layers downsampling 911 -> 111.
    with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                        stride=1,
                        padding='VALID'):
      # 911 x 911 x 3
      end_point = 'Conv2d_1a_3x3'
      net = slim.conv2d(inputs, depth(32), [3, 3], stride=2, scope=end_point)
      end_points[end_point] = net
      # 455 x 455 x 32
      end_point = 'Conv2d_2a_3x3'
      net = slim.conv2d(net, depth(32), [3, 3], scope=end_point)
      end_points[end_point] = net
      # 453 x 453 x 32
      end_point = 'Conv2d_2b_3x3'
      net = slim.conv2d(net, depth(64), [3, 3], scope=end_point)
      end_points[end_point] = net
      # 451 x 451 x 64
      end_point = 'MaxPool_3a_3x3'
      net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
      end_points[end_point] = net
      # 225 x 225 x 64
      end_point = 'Conv2d_3b_1x1'
      net = slim.conv2d(net, depth(80), [1, 1], scope=end_point)
      end_points[end_point] = net
      # 225 x 225 x 80.
      end_point = 'Conv2d_4a_3x3'
      net = slim.conv2d(net, depth(192), [3, 3], scope=end_point)
      end_points[end_point] = net
      # 223 x 223 x 192.
      end_point = 'MaxPool_5a_3x3'
      net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
      end_points[end_point] = net
      # 111 x 111 x 192.

    # Inception blocks
    with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                        stride=1,
                        padding='VALID'):
      # Mixed_5b: 107 x 107 x 256.
      end_point = 'Mixed_5b'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(
              branch_1, depth(64), [5, 5], scope='Conv2d_0b_5x5')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(32), [1, 1], scope='Conv2d_0b_1x1')
        # VALID padding shrinks each branch differently; crop the larger
        # branches so spatial dims match before channel concat.
        net = tf.concat(
            [
                _trim_border_px(branch_0, 2),  # branch_0: 111 x 111 x 64
                branch_1,  # branch_1: 107 x 107 x 64
                branch_2,  # branch_2: 107 x 107 x 96
                _trim_border_px(branch_3, 1)  # branch_3: 109 x 109 x 32
            ],
            3)
      end_points[end_point] = net

      # Mixed_5c: 103 x 103 x 288.
      end_point = 'Mixed_5c'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          # NOTE: the unusual scope names below ('Conv2d_0b_1x1',
          # 'Conv_1_0c_5x5') mirror the original Inception v3 checkpoint
          # variable names and must not be normalized.
          branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0b_1x1')
          branch_1 = slim.conv2d(
              branch_1, depth(64), [5, 5], scope='Conv_1_0c_5x5')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(64), [1, 1], scope='Conv2d_0b_1x1')
        net = tf.concat(
            [
                _trim_border_px(branch_0, 2),  # branch_0: 107 x 107 x 64
                branch_1,  # branch_1: 103 x 103 x 64
                branch_2,  # branch_2: 103 x 103 x 96
                _trim_border_px(branch_3, 1)  # branch_3: 105 x 105 x 64
            ],
            3)
      end_points[end_point] = net

      # Mixed_5d: 99 x 99 x 288.
      end_point = 'Mixed_5d'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(
              branch_1, depth(64), [5, 5], scope='Conv2d_0b_5x5')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(64), [1, 1], scope='Conv2d_0b_1x1')
        net = tf.concat(
            [
                _trim_border_px(branch_0, 2),  # branch_0: 103 x 103 x 64
                branch_1,  # branch_1: 99 x 99 x 64
                branch_2,  # branch_2: 99 x 99 x 96
                _trim_border_px(branch_3, 1)  # branch_3: 101 x 101 x 64
            ],
            3)
      end_points[end_point] = net

      # Mixed_6a: 49 x 49 x 768. Downsampling block (stride 2).
      end_point = 'Mixed_6a'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(
              net,
              depth(384), [3, 3],
              stride=2,
              padding='VALID',
              scope='Conv2d_1a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(
              branch_1,
              depth(96), [3, 3],
              stride=2,
              padding='VALID',
              scope='Conv2d_1a_1x1')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.max_pool2d(
              net, [3, 3], stride=2, padding='VALID', scope='MaxPool_1a_3x3')
        # All three branches shrink identically here; no trimming needed.
        net = tf.concat(
            [
                branch_0,  # branch_0: 49 x 49 x 384
                branch_1,  # branch_1: 49 x 49 x 96
                branch_2,  # branch_2: 49 x 49 x 288
            ],
            3)
      end_points[end_point] = net

      # Mixed_6b: 37 x 37 x 768.
      end_point = 'Mixed_6b'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(128), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(
              branch_1, depth(128), [1, 7], scope='Conv2d_0b_1x7')
          branch_1 = slim.conv2d(
              branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(128), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(
              branch_2, depth(128), [7, 1], scope='Conv2d_0b_7x1')
          branch_2 = slim.conv2d(
              branch_2, depth(128), [1, 7], scope='Conv2d_0c_1x7')
          branch_2 = slim.conv2d(
              branch_2, depth(128), [7, 1], scope='Conv2d_0d_7x1')
          branch_2 = slim.conv2d(
              branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
        net = tf.concat(
            [
                _trim_border_px(branch_0, 6),  # branch_0: 49 x 49 x 192
                _trim_border_px(branch_1, 3),  # branch_1: 43 x 43 x 192
                branch_2,  # branch_2: 37 x 37 x 192
                _trim_border_px(branch_3, 5)  # branch_3: 47 x 47 x 192
            ],
            3)
      end_points[end_point] = net

      # Mixed_6c: 25 x 25 x 768.
      end_point = 'Mixed_6c'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(
              branch_1, depth(160), [1, 7], scope='Conv2d_0b_1x7')
          branch_1 = slim.conv2d(
              branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(
              branch_2, depth(160), [7, 1], scope='Conv2d_0b_7x1')
          branch_2 = slim.conv2d(
              branch_2, depth(160), [1, 7], scope='Conv2d_0c_1x7')
          branch_2 = slim.conv2d(
              branch_2, depth(160), [7, 1], scope='Conv2d_0d_7x1')
          branch_2 = slim.conv2d(
              branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
        net = tf.concat(
            [
                _trim_border_px(branch_0, 6),  # branch_0: 37 x 37 x 192
                _trim_border_px(branch_1, 3),  # branch_1: 31 x 31 x 192
                branch_2,  # branch_2: 25 x 25 x 192
                _trim_border_px(branch_3, 5)  # branch_3: 35 x 35 x 192
            ],
            3)
      end_points[end_point] = net

      # Mixed_6d: 13 x 13 x 768.
      end_point = 'Mixed_6d'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(
              branch_1, depth(160), [1, 7], scope='Conv2d_0b_1x7')
          branch_1 = slim.conv2d(
              branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(
              branch_2, depth(160), [7, 1], scope='Conv2d_0b_7x1')
          branch_2 = slim.conv2d(
              branch_2, depth(160), [1, 7], scope='Conv2d_0c_1x7')
          branch_2 = slim.conv2d(
              branch_2, depth(160), [7, 1], scope='Conv2d_0d_7x1')
          branch_2 = slim.conv2d(
              branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
        net = tf.concat(
            [
                _trim_border_px(branch_0, 6),  # branch_0: 25 x 25 x 192
                _trim_border_px(branch_1, 3),  # branch_1: 19 x 19 x 192
                branch_2,  # branch_2: 13 x 13 x 192
                _trim_border_px(branch_3, 5)  # branch_3: 23 x 23 x 192
            ],
            3)
      end_points[end_point] = net

      # Mixed_6e: 1 x 1 x 768.
      end_point = 'Mixed_6e'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(
              branch_1, depth(192), [1, 7], scope='Conv2d_0b_1x7')
          branch_1 = slim.conv2d(
              branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(
              branch_2, depth(192), [7, 1], scope='Conv2d_0b_7x1')
          branch_2 = slim.conv2d(
              branch_2, depth(192), [1, 7], scope='Conv2d_0c_1x7')
          branch_2 = slim.conv2d(
              branch_2, depth(192), [7, 1], scope='Conv2d_0d_7x1')
          branch_2 = slim.conv2d(
              branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(
              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
        net = tf.concat(
            [
                _trim_border_px(branch_0, 6),  # branch_0: 13 x 13 x 192
                _trim_border_px(branch_1, 3),  # branch_1: 7 x 7 x 192
                branch_2,  # branch_2: 1 x 1 x 192
                _trim_border_px(branch_3, 5)  # branch_3: 11 x 11 x 192
            ],
            3)
      end_points[end_point] = net

    # Optional trailing 1x1 conv stack.
    for i in range(num_final_1x1_conv):
      # BUG FIX: reassign `net` so each extra conv actually feeds the next
      # layer; the original discarded the conv output, leaving dead ops and
      # storing the unconvolved tensor in end_points.
      net = slim.conv2d(
          net, depth(256), [1, 1], scope='Final_Conv2d_{}_1x1'.format(i))
      end_points['Final_Conv2d_{}_1x1'.format(i)] = net
  return net, end_points
385