# Source: google-research repository (referring-expression input pipeline).
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Input pipeline."""
17
18import collections19import os20from deeplab import feature_extractor21from deeplab import preprocess_utils22import tensorflow.compat.v1 as tf23
# Keys used in the feature dictionaries produced by this input pipeline.
LABEL_ID = "label"
IMAGE_ID = "image"
REF_EXP_ID = "ref_exp"
ELEMENTS_TEXT_ID = "elements_text"
ELEMENTS_NEIGHBORS_ID = "elements_neighbors"
ELEMENTS_REF_MATCH_ID = "elements_ref_match"
ELEMENTS_BOX_ID = "elements_box"
ELEMENTS_EXIST_ID = "elements_exists"
ELEMENTS_MASK_ID = "elements_mask"
IMAGE_PAD_WEIGHTS_ID = "image_pad_weights"
ELEMENTS_TYPE_ID = "elements_type"
SELECTED_CANDIDATE_ID = "selected_candidate"

# Keys for the ground-truth bounding box coordinates. When present, they are
# rescaled to the padded-square coordinate system in resize_im().
GROUNDTRUTH_XMIN_ID = "groundTruth/bbox/xmin"
GROUNDTRUTH_XMAX_ID = "groundTruth/bbox/xmax"
GROUNDTRUTH_YMIN_ID = "groundTruth/bbox/ymin"
GROUNDTRUTH_YMAX_ID = "groundTruth/bbox/ymax"
# Information that changes from dataset to dataset.
DatasetDescriptor = collections.namedtuple(
    "DatasetDescriptor",
    [
        "subfolder",  # Subfolder of flags.dataset_dir holding the data files.
        "num_classes",  # Number of semantic classes.
        "ignore_label",  # Ignore label value.
        "label_id",  # Feature key of the encoded segmentation label.
        "image_id",  # Feature key of the encoded image.
        "elements_box_id",  # Sequence-feature key of per-element boxes.
        "elements_text_id",  # Sequence-feature key of per-element texts.
        "is_tfrecord",  # True: files are TFRecord; False: SSTable.
        "has_candidate",  # True: examples carry candidate masks + selection.
        "has_elements_boxes",  # True: examples carry element bounding boxes.
    ])
58
def get_resize_dim(width, height, image_size):
  """Calculates the size of each dimension so the aspect ratio is not changed.

  Args:
    width: The original width of the image.
    height: The original height of the image.
    image_size: The desired max image size.

  Returns:
    A tuple of resized dimensions, (resized_width, resized_height).
  """
  # Scale so that the longest side becomes exactly image_size.
  longest_side = tf.maximum(width, height)
  scale_factor = tf.to_float(longest_side) / float(image_size)
  resized_w = tf.to_float(width) / scale_factor
  resized_h = tf.to_float(height) / scale_factor
  return tf.to_int32(resized_w), tf.to_int32(resized_h)
78
def resize_im(image, image_size, pad_val, channels, features=None):
  """Resizes the image to fit in an image_size x image_size padded square.

  The aspect ratio is preserved; the remainder of the square is filled with
  `pad_val`. Grayscale inputs are converted to RGB when 3 channels are
  requested.

  Args:
    image: Image to resize.
    image_size: The desired max image size.
    pad_val: The value to pad with.
    channels: The number of channels in the image.
    features: Other features to resize. If given, the element boxes (and the
      ground-truth box, when present) are rescaled in place to the padded
      coordinate system.

  Returns:
    Resized image with possible padded regions,
    and possibly the resized elements boxes.
  """
  height, width, got_channels = preprocess_utils.resolve_shape(image, rank=3)

  resized_h, resized_w = get_resize_dim(height, width, image_size)

  image = tf.reshape(image, [height, width, -1])
  # Promote single-channel input to RGB when the caller asked for 3 channels.
  needs_rgb = tf.logical_and(channels == 3, tf.equal(got_channels, 1))
  image = tf.cond(
      needs_rgb,
      true_fn=lambda: tf.image.grayscale_to_rgb(image),
      false_fn=lambda: image,
  )

  image = tf.image.resize_images(image, [resized_h, resized_w])
  image = preprocess_utils.pad_to_bounding_box(
      image, 0, 0, image_size, image_size, pad_val)

  if features is not None:
    width_f = tf.to_float(width)
    height_f = tf.to_float(height)
    max_dim = tf.to_float(tf.maximum(width_f, height_f))
    # Boxes are normalized by the longest original side so that they line up
    # with the padded square image.
    features[ELEMENTS_BOX_ID] = features[ELEMENTS_BOX_ID] / max_dim
    if GROUNDTRUTH_XMIN_ID in features:
      features[GROUNDTRUTH_XMIN_ID] *= width_f / max_dim
      features[GROUNDTRUTH_XMAX_ID] *= width_f / max_dim
      features[GROUNDTRUTH_YMIN_ID] *= height_f / max_dim
      features[GROUNDTRUTH_YMAX_ID] *= height_f / max_dim
  return image
119
def assert_or_warn(condition, message, is_assert):
  """Errors or prints a warning when the condition is met."""
  if not is_assert:
    # Warning mode: pass the condition through, printing `message` whenever
    # it is False.
    return tf.cond(condition, lambda: condition,
                   lambda: tf.Print(condition, message))
  return tf.Assert(condition, message)
# Descriptor for "refer"-style data: candidate element masks plus the index of
# the selected candidate.
# NOTE(review): not registered in `dataset_descriptors`; presumably consumed
# or registered by code outside this file — verify before removing.
refer_descriptor = DatasetDescriptor(
    subfolder="",
    num_classes=2,
    ignore_label=255,
    label_id="mask",
    image_id=IMAGE_ID,
    elements_text_id=ELEMENTS_TEXT_ID,
    elements_box_id=ELEMENTS_BOX_ID,
    is_tfrecord=True,
    has_candidate=True,
    has_elements_boxes=True,
)
# Mapping from the --dataset flag value to the corresponding descriptor.
dataset_descriptors = {
    "default":
        DatasetDescriptor(
            subfolder="",
            num_classes=2,
            ignore_label=255,
            label_id="mask",
            image_id=IMAGE_ID,
            elements_text_id=ELEMENTS_TEXT_ID,
            elements_box_id=ELEMENTS_BOX_ID,
            is_tfrecord=True,
            has_candidate=False,
            has_elements_boxes=True,
        )
}
def convert_string_neighbors(string_neighbors):
  """Converts "0"/"1" character strings into a boolean neighbor matrix."""
  # Split each string into its individual characters, padding ragged rows
  # with "0" (i.e. "not a neighbor").
  chars = tf.string_split(string_neighbors, "")
  dense_chars = tf.sparse_tensor_to_dense(chars, default_value="0")
  as_ints = tf.string_to_number(dense_chars, out_type=tf.int32)
  return tf.cast(as_ints, tf.bool)
def input_fn_dataset(dataset, flags):
  """Gets the model input from the given dataset.

  Parses serialized tf.SequenceExamples, decodes and resizes images/labels,
  pads and batches the features, and applies bounding-box sanity checks.

  Args:
    dataset: A tf.data.Dataset of serialized example strings.
    flags: Flags controlling the pipeline (dataset name, image_size,
      batch_size, which optional features to parse, etc.).

  Returns:
    feature_map: Dict of batched tensors keyed by the *_ID constants above.
  """
  # NOTE(review): this dict is closed over and mutated by _parse_function
  # below; tf.data traces that function once, so the batched tensors come from
  # the mapped dataset, not from this dict directly.
  features = {}
  dataset_descriptor = dataset_descriptors[flags.dataset]

  def process_label(label):
    """Preprocesses the label: decode, optionally rescale, resize, pad."""
    label = tf.image.decode_image(label, channels=1)
    ignore_label = 255  # Padded regions receive the ignore label.
    label = tf.cast(label, tf.int32)

    if flags.preprocess_divide_label:
      # Maps {0, 255}-encoded masks to {0, 1}. True division yields a float
      # tensor here; the cast below restores int32.
      label /= 255

    label = resize_im(label, flags.image_size, ignore_label, 1)
    label = tf.cast(label, tf.int32)
    return label

  def _parse_function(*args):
    """Parses the tf example."""
    # The serialized example is the last argument (key/value datasets pass
    # the key first).
    serialized_example = args[-1]

    # Context (per-example) features: always the encoded image.
    context_feature_names = {
        dataset_descriptor.image_id: tf.FixedLenFeature([], tf.string),
    }
    sequence_feature_names = {}
    if flags.use_ref_exp:
      context_feature_names[REF_EXP_ID] = tf.FixedLenFeature([], tf.string)

    if flags.use_labels:
      if dataset_descriptor.has_candidate:
        # Candidate setup: index of the selected candidate plus a sequence of
        # encoded candidate masks.
        context_feature_names[SELECTED_CANDIDATE_ID] = tf.FixedLenFeature(
            [], tf.int64)
        sequence_feature_names[ELEMENTS_MASK_ID] = tf.FixedLenSequenceFeature(
            [], tf.string)
      else:
        # Single encoded segmentation label.
        context_feature_names[dataset_descriptor.label_id] = tf.FixedLenFeature(
            [], tf.string)

    if dataset_descriptor.has_elements_boxes:
      sequence_feature_names[
          dataset_descriptor.elements_box_id] = tf.FixedLenSequenceFeature(
              [4], dtype=tf.float32)
    if flags.use_elements_texts:
      sequence_feature_names[
          dataset_descriptor.elements_text_id] = tf.FixedLenSequenceFeature(
              [], dtype=tf.string)
    if flags.use_elements_neighbors:
      sequence_feature_names[
          ELEMENTS_NEIGHBORS_ID] = tf.FixedLenSequenceFeature(
              [], dtype=tf.string)
    if flags.use_elements_ref_match:
      sequence_feature_names[
          ELEMENTS_REF_MATCH_ID] = tf.FixedLenSequenceFeature(
              [], dtype=tf.string)

    if flags.use_groundtruth_box:
      context_feature_names[GROUNDTRUTH_XMIN_ID] = tf.FixedLenFeature(
          [], tf.float32)
      context_feature_names[GROUNDTRUTH_XMAX_ID] = tf.FixedLenFeature(
          [], tf.float32)
      context_feature_names[GROUNDTRUTH_YMIN_ID] = tf.FixedLenFeature(
          [], tf.float32)
      context_feature_names[GROUNDTRUTH_YMAX_ID] = tf.FixedLenFeature(
          [], tf.float32)

    context_features, sequence_features = tf.parse_single_sequence_example(
        serialized_example,
        context_features=context_feature_names,
        sequence_features=sequence_feature_names,
    )

    features.update(context_features)
    features.update(sequence_features)

    # Normalize dataset-specific feature keys to the canonical *_ID keys.
    if flags.use_elements_texts:
      features[ELEMENTS_TEXT_ID] = features.pop(
          dataset_descriptor.elements_text_id)
    if dataset_descriptor.has_elements_boxes:
      features[ELEMENTS_BOX_ID] = features.pop(
          dataset_descriptor.elements_box_id)

    image = features.pop(dataset_descriptor.image_id)
    image = tf.image.decode_image(image, channels=3)

    image = tf.cast(image, tf.float32)
    # Pad value is the model's mean pixel, so padded regions look "neutral".
    mean_pixel = tf.reshape(
        feature_extractor.mean_pixel(flags.model_variant), [1, 1, 3])

    # Mask of valid (non-padded) pixels: ones over the original image extent,
    # zeros over the padding added by resize_im.
    features[IMAGE_PAD_WEIGHTS_ID] = tf.ones_like(image[:, :, 0:1])
    features[IMAGE_PAD_WEIGHTS_ID] = resize_im(features[IMAGE_PAD_WEIGHTS_ID],
                                               flags.image_size, 0, 1)
    features[IMAGE_PAD_WEIGHTS_ID] = tf.squeeze(features[IMAGE_PAD_WEIGHTS_ID],
                                                2)

    if dataset_descriptor.has_elements_boxes:
      # Passing `features` also rescales the element/groundtruth boxes.
      image = resize_im(image, flags.image_size, mean_pixel, 3, features)
    else:
      image = resize_im(image, flags.image_size, mean_pixel, 3)

    if flags.use_labels:
      if dataset_descriptor.has_candidate:
        # Decode every candidate mask, then pick the selected candidate's
        # mask as the label.
        features[ELEMENTS_MASK_ID] = tf.map_fn(
            process_label,
            features.pop(ELEMENTS_MASK_ID),
            parallel_iterations=128,
            dtype=tf.int32,
            name="mask_map")
        features[LABEL_ID] = tf.gather_nd(features[ELEMENTS_MASK_ID],
                                          [features[SELECTED_CANDIDATE_ID]])
      else:
        label = features.pop(dataset_descriptor.label_id)
        label = process_label(label)
        features[LABEL_ID] = label

    # Per-element existence indicator (all ones pre-padding; padded_batch
    # fills the rest with zeros, i.e. "does not exist").
    if flags.use_elements_texts:
      features[ELEMENTS_EXIST_ID] = tf.ones_like(
          features[ELEMENTS_TEXT_ID], dtype=tf.int32)
    elif dataset_descriptor.has_elements_boxes:
      features[ELEMENTS_EXIST_ID] = tf.ones(
          tf.shape(features[ELEMENTS_BOX_ID])[:1], dtype=tf.int32)

    if flags.use_elements_neighbors:
      features[ELEMENTS_NEIGHBORS_ID] = convert_string_neighbors(
          features[ELEMENTS_NEIGHBORS_ID])

    features[IMAGE_ID] = image

    return features

  dataset = dataset.map(
      _parse_function,
      num_parallel_calls=flags.dataset_threads).prefetch(flags.batch_size)

  # Padded shapes for padded_batch; None dims are padded to the batch max.
  padded_shapes = {
      IMAGE_ID: [None, None, None],
  }
  if flags.use_labels:
    padded_shapes[LABEL_ID] = [None, None, None]
  if flags.use_groundtruth_box:
    padded_shapes[GROUNDTRUTH_XMIN_ID] = []
    padded_shapes[GROUNDTRUTH_XMAX_ID] = []
    padded_shapes[GROUNDTRUTH_YMIN_ID] = []
    padded_shapes[GROUNDTRUTH_YMAX_ID] = []
  if flags.use_elements_texts:
    padded_shapes[ELEMENTS_TEXT_ID] = [None]
    padded_shapes[ELEMENTS_EXIST_ID] = [None]
  if dataset_descriptor.has_elements_boxes:
    padded_shapes[ELEMENTS_BOX_ID] = [None, None]
    padded_shapes[ELEMENTS_EXIST_ID] = [None]
  if flags.use_elements_neighbors:
    padded_shapes[ELEMENTS_NEIGHBORS_ID] = [None, None]
  if flags.use_elements_ref_match:
    padded_shapes[ELEMENTS_REF_MATCH_ID] = [None]

  padded_shapes[IMAGE_PAD_WEIGHTS_ID] = [None, None]

  if flags.use_ref_exp:
    padded_shapes.update({
        REF_EXP_ID: [],
    })
  if dataset_descriptor.has_candidate:
    padded_shapes.update({
        SELECTED_CANDIDATE_ID: [],
        ELEMENTS_MASK_ID: [None, None, None, None],
    })

  dataset = dataset.padded_batch(flags.batch_size, padded_shapes=padded_shapes)
  dataset = dataset.prefetch(1)

  try:
    iterator = dataset.make_one_shot_iterator()
    feature_map = iterator.get_next()
  except ValueError:
    # This means the input pipeline uses placeholders probably because it's in
    # inference mode.
    feature_map = tf.contrib.data.get_single_element(dataset)

  feature_map[IMAGE_ID] = tf.reshape(
      feature_map[IMAGE_ID], [-1, flags.image_size, flags.image_size, 3])

  # Sanity-check the (normalized) boxes: non-negative and within [0, 1] in
  # both dimensions, with a small tolerance for float error.
  assert_ops = []
  if dataset_descriptor.has_elements_boxes:
    assert_ops.append(
        assert_or_warn(
            tf.greater_equal(
                tf.reduce_min(feature_map[ELEMENTS_BOX_ID]), -.001), [
                    "Bounding box is negative",
                    tf.reduce_min(feature_map[ELEMENTS_BOX_ID])
                ], flags.incorrect_boxes_as_errors))

    # Boxes are [x, y, w, h]: x + w must not exceed 1.
    assert_ops.append(
        assert_or_warn(
            tf.less_equal(
                tf.reduce_max(feature_map[ELEMENTS_BOX_ID][:, :, 0] +
                              feature_map[ELEMENTS_BOX_ID][:, :, 2]), 1.001),
            [
                "Bounding box x dim is too large.",
                tf.reduce_max(feature_map[ELEMENTS_BOX_ID][:, :, 0] +
                              feature_map[ELEMENTS_BOX_ID][:, :, 2])
            ], flags.incorrect_boxes_as_errors))

    # Likewise y + h must not exceed 1.
    assert_ops.append(
        assert_or_warn(
            tf.less_equal(
                tf.reduce_max(feature_map[ELEMENTS_BOX_ID][:, :, 1] +
                              feature_map[ELEMENTS_BOX_ID][:, :, 3]), 1.001),
            [
                "Bounding box y dim is too large.",
                tf.reduce_max(feature_map[ELEMENTS_BOX_ID][:, :, 1] +
                              feature_map[ELEMENTS_BOX_ID][:, :, 3])
            ], flags.incorrect_boxes_as_errors))

  with tf.control_dependencies(assert_ops):
    if dataset_descriptor.has_elements_boxes:
      feature_map[ELEMENTS_BOX_ID].set_shape([None, None, 4])
      feature_map[ELEMENTS_EXIST_ID] = tf.cast(feature_map[ELEMENTS_EXIST_ID],
                                               tf.bool)
    if flags.use_labels:
      if flags.output_mode == "segment" or flags.output_mode == "regression":
        feature_map[LABEL_ID] = tf.reshape(
            feature_map[LABEL_ID], [-1, flags.image_size, flags.image_size, 1])
    return feature_map
def get_input_fn(flags):
  """Returns input_fn.

  Args:
    flags: Flags with dataset_dir, dataset, split, batch_size and shuffle
      buffer sizes.

  Returns:
    A zero-argument input_fn that builds the training dataset and returns the
    batched feature map from input_fn_dataset.
  """

  def input_fn():
    """Reads the input features from files."""
    dataset_descriptor = dataset_descriptors[flags.dataset]

    with tf.variable_scope("input"):
      # Single os.path.join instead of the nested call; equivalent result.
      pattern = os.path.join(flags.dataset_dir, dataset_descriptor.subfolder,
                             flags.split + "*")
      # Fixed: the original Python 2 `print "Pattern", pattern` statement is a
      # syntax error on Python 3, which tensorflow.compat.v1 requires.
      print("Pattern", pattern)
      dataset = tf.data.Dataset.list_files(pattern)

      # Shuffle file order, then repeat indefinitely for training.
      dataset = dataset.shuffle(buffer_size=flags.file_shuffle_buffer_size)
      dataset = dataset.repeat()

      def prefetch_map_fn(filename):
        # Read one file's records, prefetching a batch's worth.
        if dataset_descriptor.is_tfrecord:
          return tf.data.TFRecordDataset(filename).prefetch(flags.batch_size)
        else:
          return tf.data.SSTableDataset(filename).prefetch(flags.batch_size)

      # Interleave records from many files to decorrelate examples.
      dataset = dataset.interleave(
          prefetch_map_fn, cycle_length=100, block_length=flags.batch_size)

      print("shuffle buffer size", flags.shuffle_buffer_size)
      dataset = dataset.shuffle(buffer_size=flags.shuffle_buffer_size)

    return input_fn_dataset(dataset, flags)

  return input_fn
425
def get_serving_input_receiver_fn(flags):
  """Returns serving_input_receiver_fn."""

  def serving_input_receiver_fn():
    """Used for exporting the model. Expects a serialized tf.Example."""
    # A batch of serialized examples arrives via this placeholder, which also
    # serves as the receiver tensor.
    serialized = tf.placeholder(dtype=tf.string, shape=[None], name="input")
    examples = tf.data.Dataset.from_tensor_slices(serialized)
    parsed_features = input_fn_dataset(examples, flags)
    return tf.estimator.export.ServingInputReceiver(parsed_features,
                                                    serialized)

  return serving_input_receiver_fn