# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Segmentation results visualization on a given set of images.
17
18See model.py for more details and usage.
19"""

import math
import os.path
import time
import common  # pylint: disable=unused-import
from deeplab import save_annotation
import model
import model_input
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import app
from tensorflow.contrib import slim as contrib_slim

flags = tf.app.flags
FLAGS = flags.FLAGS

# Settings for log directories.

flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')

flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')

# Settings for evaluating the model.

flags.DEFINE_integer('batch_size', 32,
                     'The number of images in each batch during evaluation.')

flags.DEFINE_integer('num_vis_examples', 32,
                     'Number of examples for visualization.')

flags.DEFINE_integer('eval_interval_secs', 60,
                     'How often (in seconds) to run evaluation.')

flags.DEFINE_multi_float('eval_scales', [1.0],
                         'The scales to resize images for evaluation.')

flags.DEFINE_bool('add_flipped_images', False,
                  'Whether to add flipped images for evaluation.')

flags.DEFINE_string('split', 'val',
                    'Which split of the dataset is used for evaluation.')

flags.DEFINE_string('colormap_type', 'pascal', 'Visualization colormap type.')

# The html template directory.
_HTML_TEMPLATE_DIR = '.'

# The folder where semantic segmentation predictions are saved.
_SEMANTIC_PREDICTION_SAVE_FOLDER = 'segmentation_results'

# The folder where raw semantic segmentation predictions are saved.
_RAW_SEMANTIC_PREDICTION_SAVE_FOLDER = 'raw_segmentation_results'

# The format to save image.
_IMAGE_FORMAT = '%06d_image'

# The format to save prediction.
_PREDICTION_FORMAT = '%06d_prediction'

_LABEL_FORMAT = '%06d_label'


def _process_batch(sess, samples, semantic_predictions, labels,
                   image_id_offset, save_dir):
  """Evaluates one single batch qualitatively.

  Args:
    sess: TensorFlow session.
    samples: The input features.
    semantic_predictions: Model predictions.
    labels: Ground truth labels.
    image_id_offset: Image id offset for indexing images.
    save_dir: The directory where the predictions will be saved.

  Returns:
    The referring expressions.
  """
  (original_images, new_refs, semantic_predictions, labels) = sess.run(
      [samples['image'], samples['ref_exp'], semantic_predictions, labels])

  num_image = semantic_predictions.shape[0]
  for i in range(num_image):
    original_image = np.squeeze(original_images[i])
    semantic_prediction = np.squeeze(semantic_predictions[i])
    label = np.squeeze(labels[i])

    # Save image.
    save_annotation.save_annotation(
        original_image,
        save_dir,
        _IMAGE_FORMAT % (image_id_offset + i),
        add_colormap=False)

    # Save prediction.
    save_annotation.save_annotation(
        semantic_prediction,
        save_dir,
        _PREDICTION_FORMAT % (image_id_offset + i),
        add_colormap=True,
        colormap_type=FLAGS.colormap_type)

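    # Save ground-truth label.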
    save_annotation.save_annotation(
        label,
        save_dir,
        _LABEL_FORMAT % (image_id_offset + i),
        add_colormap=True,
        colormap_type=FLAGS.colormap_type)

  return new_refs.tolist()


def main(unused_argv):
  # Get dataset-dependent information.
  # Prepare for visualization.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(save_dir)
  raw_save_dir = os.path.join(FLAGS.vis_logdir,
                              _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(raw_save_dir)
  num_vis_examples = FLAGS.num_vis_examples

  print('Visualizing on set', FLAGS.split)

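  # Build the input pipeline and the model graph used for visualization.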
  g = tf.Graph()
  with g.as_default():
    samples = model_input.get_input_fn(FLAGS)()
    outputs_to_num_classes = model.get_output_to_num_classes(FLAGS)

    # Get model segmentation predictions.
    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions, probs = model.predict_labels(
          samples['image'],
          samples,
          FLAGS,
          outputs_to_num_classes=outputs_to_num_classes,
          image_pyramid=FLAGS.image_pyramid,
          merge_method=FLAGS.merge_method,
          atrous_rates=FLAGS.atrous_rates,
          add_image_level_feature=FLAGS.add_image_level_feature,
          aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
          aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
          multi_grid=FLAGS.multi_grid,
          depth_multiplier=FLAGS.depth_multiplier,
          output_stride=FLAGS.output_stride,
          decoder_output_stride=FLAGS.decoder_output_stride,
          decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
          crop_size=[FLAGS.image_size, FLAGS.image_size],
          logits_kernel_size=FLAGS.logits_kernel_size,
          model_variant=FLAGS.model_variant)
    else:
      tf.logging.info('Performing multi-scale test.')
      predictions, probs = model.predict_labels_multi_scale(
          samples['image'],
          samples,
          FLAGS,
          outputs_to_num_classes=outputs_to_num_classes,
          eval_scales=FLAGS.eval_scales,
          add_flipped_images=FLAGS.add_flipped_images,
          merge_method=FLAGS.merge_method,
          atrous_rates=FLAGS.atrous_rates,
          add_image_level_feature=FLAGS.add_image_level_feature,
          aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
          aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
          multi_grid=FLAGS.multi_grid,
          depth_multiplier=FLAGS.depth_multiplier,
          output_stride=FLAGS.output_stride,
          decoder_output_stride=FLAGS.decoder_output_stride,
          decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
          crop_size=[FLAGS.image_size, FLAGS.image_size],
          logits_kernel_size=FLAGS.logits_kernel_size,
          model_variant=FLAGS.model_variant)

    if FLAGS.output_mode == 'segment':
      predictions = tf.squeeze(
          tf.cast(predictions[FLAGS.output_mode], tf.int32))
      probs = probs[FLAGS.output_mode]

    labels = tf.squeeze(tf.cast(samples['label'], tf.int32))
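    # Mask out pixels marked with the dataset's ignore_label so they do not
    # contribute to the saved visualizations.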
    weights = tf.cast(
        tf.not_equal(
            labels,
            model_input.dataset_descriptors[FLAGS.dataset].ignore_label),
        tf.int32)

    labels *= weights
    predictions *= weights

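    # The Supervisor is used only for session management and checkpoint
    # restoration; summary writing is disabled.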
    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(contrib_slim.get_variables_to_restore())
    sv = tf.train.Supervisor(
        graph=g,
        logdir=FLAGS.vis_logdir,
        init_op=tf.global_variables_initializer(),
        summary_op=None,
        summary_writer=None,
        global_step=None,
        saver=saver)
    num_batches = int(math.ceil(num_vis_examples / float(FLAGS.batch_size)))
    last_checkpoint = None

    # Infinite loop to visualize the results when a new checkpoint is created.
    while True:
      last_checkpoint = contrib_slim.evaluation.wait_for_new_checkpoint(
          FLAGS.checkpoint_dir, last_checkpoint)
      start = time.time()
      print('Starting visualization at ' +
            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
      print('Visualizing with model', last_checkpoint)

      with sv.managed_session(
          FLAGS.master, start_standard_services=False) as sess:
        # sv.start_queue_runners(sess)
        sv.saver.restore(sess, last_checkpoint)

        image_id_offset = 0
        refs = []
        for batch in range(num_batches):
          print('Visualizing batch', batch + 1, num_batches)
          refs.extend(
              _process_batch(
                  sess=sess,
                  samples=samples,
                  semantic_predictions=predictions,
                  labels=labels,
                  image_id_offset=image_id_offset,
                  save_dir=save_dir))
          image_id_offset += FLAGS.batch_size

      print('Finished visualization at ' +
            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
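      # Sleep until the next eval_interval_secs window before polling again.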
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)


if __name__ == '__main__':
  flags.mark_flag_as_required('checkpoint_dir')
  flags.mark_flag_as_required('vis_logdir')
  app.run()