# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Segmentation results visualization on a given set of images.

See model.py for more details and usage.
"""

import math
import os.path
import time
import common  # pylint: disable=unused-import
from deeplab import save_annotation
import model
import model_input
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import app
from tensorflow.contrib import slim as contrib_slim

flags = tf.app.flags
FLAGS = flags.FLAGS

# Settings for log directories.

flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')

flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')

# Settings for evaluating the model.

flags.DEFINE_integer('batch_size', 32,
                     'The number of images in each batch during evaluation.')

flags.DEFINE_integer('num_vis_examples', 32,
                     'Number of examples for visualization.')

flags.DEFINE_integer('eval_interval_secs', 60,
                     'How often (in seconds) to run evaluation.')

flags.DEFINE_multi_float('eval_scales', [1.0],
                         'The scales to resize images for evaluation.')

flags.DEFINE_bool('add_flipped_images', False,
                  'Whether to add flipped images for evaluation.')

flags.DEFINE_string('split', 'val',
                    'Which split of the dataset is used for evaluation.')

flags.DEFINE_string('colormap_type', 'pascal', 'Visualization colormap type.')
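
# Flags referenced below but not defined in this file (e.g. image_size,
# model_variant, output_mode, master, dataset) are assumed to be defined at
# import time by common.py, hence the unused-import pragma above.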

# The html template directory.
_HTML_TEMPLATE_DIR = '.'

# The folder where semantic segmentation predictions are saved.
_SEMANTIC_PREDICTION_SAVE_FOLDER = 'segmentation_results'

# The folder where raw semantic segmentation predictions are saved.
_RAW_SEMANTIC_PREDICTION_SAVE_FOLDER = 'raw_segmentation_results'

# The format to save image.
_IMAGE_FORMAT = '%06d_image'

# The format to save prediction.
_PREDICTION_FORMAT = '%06d_prediction'

# The format to save label.
_LABEL_FORMAT = '%06d_label'
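
# For example, _IMAGE_FORMAT % 12 yields '000012_image', so saved files sort
# lexicographically by image id.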


def _process_batch(sess, samples, semantic_predictions, labels,
                   image_id_offset, save_dir):
  """Evaluates a single batch qualitatively.

  Args:
    sess: TensorFlow session.
    samples: The input features.
    semantic_predictions: Model predictions.
    labels: Ground truth labels.
    image_id_offset: Image id offset for indexing images.
    save_dir: The directory where the predictions will be saved.

  Returns:
    The referring expressions.
  """
  # Evaluate the input pipeline and the model once to fetch one batch of
  # images, referring expressions, predictions, and ground truth labels.
  (original_images, new_refs, semantic_predictions, labels) = sess.run(
      [samples['image'], samples['ref_exp'], semantic_predictions, labels])

  num_image = semantic_predictions.shape[0]
  for i in range(num_image):
    original_image = np.squeeze(original_images[i])
    semantic_prediction = np.squeeze(semantic_predictions[i])
    label = np.squeeze(labels[i])

    # Save image.
    save_annotation.save_annotation(
        original_image,
        save_dir,
        _IMAGE_FORMAT % (image_id_offset + i),
        add_colormap=False)

    # Save prediction.
    save_annotation.save_annotation(
        semantic_prediction,
        save_dir,
        _PREDICTION_FORMAT % (image_id_offset + i),
        add_colormap=True,
        colormap_type=FLAGS.colormap_type)

    # Save ground truth label.
    save_annotation.save_annotation(
        label,
        save_dir,
        _LABEL_FORMAT % (image_id_offset + i),
        add_colormap=True,
        colormap_type=FLAGS.colormap_type)

  return new_refs.tolist()


# main() builds the input pipeline and the model graph once, then loops
# forever: it waits for a new checkpoint, restores it, and writes
# visualizations for num_vis_examples examples.
def main(unused_argv):
  # Get dataset-dependent information.
  # Prepare for visualization.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(save_dir)
  raw_save_dir = os.path.join(FLAGS.vis_logdir,
                              _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(raw_save_dir)
  num_vis_examples = FLAGS.num_vis_examples

  print('Visualizing on set', FLAGS.split)

  g = tf.Graph()
  with g.as_default():
    samples = model_input.get_input_fn(FLAGS)()
    outputs_to_num_classes = model.get_output_to_num_classes(FLAGS)

    # Get model segmentation predictions.
    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions, probs = model.predict_labels(
          samples['image'],
          samples,
          FLAGS,
          outputs_to_num_classes=outputs_to_num_classes,
          image_pyramid=FLAGS.image_pyramid,
          merge_method=FLAGS.merge_method,
          atrous_rates=FLAGS.atrous_rates,
          add_image_level_feature=FLAGS.add_image_level_feature,
          aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
          aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
          multi_grid=FLAGS.multi_grid,
          depth_multiplier=FLAGS.depth_multiplier,
          output_stride=FLAGS.output_stride,
          decoder_output_stride=FLAGS.decoder_output_stride,
          decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
          crop_size=[FLAGS.image_size, FLAGS.image_size],
          logits_kernel_size=FLAGS.logits_kernel_size,
          model_variant=FLAGS.model_variant)
    else:
      tf.logging.info('Performing multi-scale test.')
      predictions, probs = model.predict_labels_multi_scale(
          samples['image'],
          samples,
          FLAGS,
          outputs_to_num_classes=outputs_to_num_classes,
          eval_scales=FLAGS.eval_scales,
          add_flipped_images=FLAGS.add_flipped_images,
          merge_method=FLAGS.merge_method,
          atrous_rates=FLAGS.atrous_rates,
          add_image_level_feature=FLAGS.add_image_level_feature,
          aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
          aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
          multi_grid=FLAGS.multi_grid,
          depth_multiplier=FLAGS.depth_multiplier,
          output_stride=FLAGS.output_stride,
          decoder_output_stride=FLAGS.decoder_output_stride,
          decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
          crop_size=[FLAGS.image_size, FLAGS.image_size],
          logits_kernel_size=FLAGS.logits_kernel_size,
          model_variant=FLAGS.model_variant)

    if FLAGS.output_mode == 'segment':
      predictions = tf.squeeze(
          tf.cast(predictions[FLAGS.output_mode], tf.int32))
      probs = probs[FLAGS.output_mode]

      labels = tf.squeeze(tf.cast(samples['label'], tf.int32))
      # Mask out pixels marked with the dataset's ignore_label so they do not
      # show up in either the label or the prediction visualizations.
      weights = tf.cast(
          tf.not_equal(
              labels,
              model_input.dataset_descriptors[FLAGS.dataset].ignore_label),
          tf.int32)

      labels *= weights
      predictions *= weights

      tf.train.get_or_create_global_step()
      saver = tf.train.Saver(contrib_slim.get_variables_to_restore())
      # The Supervisor is only used to create sessions and restore
      # checkpoints; summary writing is disabled.
      sv = tf.train.Supervisor(
          graph=g,
          logdir=FLAGS.vis_logdir,
          init_op=tf.global_variables_initializer(),
          summary_op=None,
          summary_writer=None,
          global_step=None,
          saver=saver)
      num_batches = int(math.ceil(num_vis_examples / float(FLAGS.batch_size)))
      last_checkpoint = None

      # Infinite loop to visualize the results whenever a new checkpoint is
      # created.
      while True:
        last_checkpoint = contrib_slim.evaluation.wait_for_new_checkpoint(
            FLAGS.checkpoint_dir, last_checkpoint)
        start = time.time()
        print('Starting visualization at ' +
              time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
        print('Visualizing with model', last_checkpoint)

        with sv.managed_session(
            FLAGS.master, start_standard_services=False) as sess:
          # sv.start_queue_runners(sess)
          sv.saver.restore(sess, last_checkpoint)

          image_id_offset = 0
          refs = []
          for batch in range(num_batches):
            print('Visualizing batch', batch + 1, 'of', num_batches)
            refs.extend(
                _process_batch(
                    sess=sess,
                    samples=samples,
                    semantic_predictions=predictions,
                    labels=labels,
                    image_id_offset=image_id_offset,
                    save_dir=save_dir))
            image_id_offset += FLAGS.batch_size

        print('Finished visualization at ' +
              time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
        # Sleep until the next evaluation interval before polling for a new
        # checkpoint.
        time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
        if time_to_next_eval > 0:
          time.sleep(time_to_next_eval)


if __name__ == '__main__':
  flags.mark_flag_as_required('checkpoint_dir')
  flags.mark_flag_as_required('vis_logdir')
  app.run()