google-research

Форк
0
149 строк · 5.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""No-padding Inception FCN neural network.

This is a variant of Inception v3 with all padding removed. This change allows
the network to be trained and run for inference with different patch sizes
(Fully Convolutional Network, FCN mode) while producing the same inference
results. The network can be initialized for one of two receptive fields: 911
or 129.
"""
23

24
import tensorflow.compat.v1 as tf
25
import tf_slim as slim
26

27
import inception_base_129
28
import inception_base_911
29
import network
30
import network_params
31
import scope_utils
32

33

34
def get_inception_base_and_downsample_factor(receptive_field_size):
  """Get the Inception base network and its downsample factor.

  Args:
    receptive_field_size: the requested receptive field; only 911 and 129 are
      supported.

  Returns:
    A `(base_network_fn, downsample_factor)` tuple taken from either the
    `inception_base_911` or the `inception_base_129` module.

  Raises:
    ValueError: if `receptive_field_size` is neither 911 nor 129.
  """
  supported_sizes = (911, 129)
  if receptive_field_size not in supported_sizes:
    raise ValueError(
        f'Receptive field size should be 911 or 129. {receptive_field_size} was provided.'
    )
  if receptive_field_size == 911:
    return (inception_base_911.nopad_inception_v3_base_911,
            inception_base_911.MODEL_DOWNSAMPLE_FACTOR)
  # Only 129 remains after the guard above.
  return (inception_base_129.nopad_inception_v3_base_129,
          inception_base_129.MODEL_DOWNSAMPLE_FACTOR)
44

45

46
class InceptionV3FCN(network.Network):
  """A no pad, fully convolutional InceptionV3 model.

  The convolutional trunk is chosen by `inception_params.receptive_field_size`
  (911 or 129). A dropout layer and a 1x1 convolution producing `num_classes`
  logits are appended on top of it.
  """

  def __init__(
      self,
      inception_params,
      conv_scope_params,
      num_classes = 2,
      is_training = True,
  ):
    """Creates a no pad, fully convolutional InceptionV3 model.

    Args:
      inception_params: parameters specific to the InceptionV3. The fields read
        here are receptive_field_size, prelogit_dropout_keep_prob,
        depth_multiplier, min_depth and inception_fcn_stride.
      conv_scope_params: parameters used to configure the general convolution
        parameters used in the slim argument scope.
      num_classes: number of output classes from the model.
      is_training: whether the network should be built for training or
        inference.

    Raises:
      ValueError: if the receptive field size is not 911 or 129, if
        depth_multiplier is not greater than zero, or if a non-zero
        inception_fcn_stride is not a positive multiple of the base network's
        downsample factor.
    """
    super().__init__()
    self._num_classes = num_classes
    self._is_training = is_training
    self._network_base, self._downsample_factor = (
        get_inception_base_and_downsample_factor(
            inception_params.receptive_field_size))
    self._prelogit_dropout_keep_prob = inception_params.prelogit_dropout_keep_prob
    self._depth_multiplier = inception_params.depth_multiplier
    self._min_depth = inception_params.min_depth
    self._inception_fcn_stride = inception_params.inception_fcn_stride
    self._conv_scope_params = conv_scope_params
    if self._depth_multiplier <= 0:
      raise ValueError('param depth_multiplier should be greater than zero.')

    if not self._inception_fcn_stride:
      # No FCN stride requested: the logits conv runs with stride 1.
      self._logits_stride = 1
    else:
      stride = self._inception_fcn_stride
      factor = self._downsample_factor
      # The logits conv operates on features that are already downsampled by
      # `factor`, so the requested stride must map onto a whole, positive
      # number of feature-map steps. Previously int(stride / factor) silently
      # truncated: a stride below `factor` produced an invalid stride of 0,
      # and a non-multiple produced a different stride than requested.
      if stride < factor or stride % factor:
        raise ValueError(
            f'inception_fcn_stride ({stride}) should be a positive multiple '
            f'of the model downsample factor ({factor}).')
      self._logits_stride = int(stride // factor)

  def build(self, inputs):
    """Returns an InceptionV3FCN model with configurable conv2d normalization.

    Args:
      inputs: a map from input string names to tensors. Required:
        * IMAGES: a tensor of shape [batch, height, width, channels]

    Returns:
      A dictionary from network layer names to the corresponding layer
      activation Tensors. It contains the end points returned by the base
      network, plus:
        * PRE_LOGITS: activation layer preceding LOGITS (after dropout)
        * PROBABILITIES_TENSOR: softmax of the un-squeezed conv logits
        * LOGITS: the pre-softmax activations. Size [batch, num_classes] when
          the logits stride is 1 (spatial dims squeezed); otherwise the
          un-squeezed conv output.
        * PROBABILITIES: softmax probs, shaped like LOGITS
    """
    images = self._get_tensor(inputs, self.IMAGES, expected_rank=4)
    with slim.arg_scope(
        scope_utils.get_conv_scope(self._conv_scope_params, self._is_training)):
      net, end_points = self._network_base(
          images,
          min_depth=self._min_depth,
          depth_multiplier=self._depth_multiplier)
      # Final pooling and prediction
      with tf.variable_scope('Logits'):
        # NOTE(review): presumably spatially 1 x 1 when the input patch matches
        # the receptive field, and larger in FCN mode — confirm.
        net = slim.dropout(
            net,
            keep_prob=self._prelogit_dropout_keep_prob,
            is_training=self._is_training,
            scope='Dropout_1b')
        end_points[self.PRE_LOGITS] = net
        # A 1x1 conv acts as a per-location fully connected classifier; a
        # stride > 1 subsamples the output grid in FCN mode.
        logits = slim.conv2d(
            net,
            self._num_classes, [1, 1],
            activation_fn=None,
            normalizer_fn=None,
            stride=self._logits_stride,
            scope='Conv2d_1c_1x1')
      probabilities_tensor = tf.nn.softmax(logits)
      end_points[self.PROBABILITIES_TENSOR] = probabilities_tensor
      if self._logits_stride == 1:
        # Reshape to remove height and width.
        # NOTE(review): this assumes both spatial dims are 1. A non-zero
        # inception_fcn_stride equal to the downsample factor also gives a
        # logits stride of 1 — confirm such inputs are 1 x 1.
        end_points[self.LOGITS] = tf.squeeze(
            logits, [1, 2], name='SpatialSqueeze')
        end_points[self.PROBABILITIES] = tf.squeeze(
            probabilities_tensor, [1, 2], name='SpatialSqueeze')
      else:
        end_points[self.LOGITS] = logits
        end_points[self.PROBABILITIES] = probabilities_tensor
    return end_points
131

132

133
def get_inception_v3_fcn_network_fn(
    inception_params,
    conv_scope_params,
    num_classes = 2,
    is_training = True,
):
  """Returns a function that returns logits and endpoints for slim uptraining.

  Args:
    inception_params: parameters specific to the InceptionV3, forwarded to
      `InceptionV3FCN`.
    conv_scope_params: convolution arg-scope parameters, forwarded to
      `InceptionV3FCN`.
    num_classes: number of output classes from the model.
    is_training: whether the network should be built for training or
      inference.

  Returns:
    A callable `network_fn(images)`. Each call builds the network on `images`
    and returns a `(logits, end_points)` tuple, where `end_points` is the
    dictionary produced by `InceptionV3FCN.build`.
  """
  # The model object is created once; every network_fn call reuses its
  # configuration.
  net = InceptionV3FCN(inception_params, conv_scope_params, num_classes,
                       is_training)

  def network_fn(images):
    # Use the model's own key constants so the keys always match what `build`
    # reads (IMAGES) and writes (LOGITS), instead of hard-coded strings.
    endpoints = net.build({net.IMAGES: images})
    return endpoints[net.LOGITS], endpoints

  return network_fn
150

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.