# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input pipeline."""

import collections
import os
from deeplab import feature_extractor
from deeplab import preprocess_utils
import tensorflow.compat.v1 as tf

LABEL_ID = "label"
IMAGE_ID = "image"
REF_EXP_ID = "ref_exp"
ELEMENTS_TEXT_ID = "elements_text"
ELEMENTS_NEIGHBORS_ID = "elements_neighbors"
ELEMENTS_REF_MATCH_ID = "elements_ref_match"
ELEMENTS_BOX_ID = "elements_box"
ELEMENTS_EXIST_ID = "elements_exists"
ELEMENTS_MASK_ID = "elements_mask"
IMAGE_PAD_WEIGHTS_ID = "image_pad_weights"
ELEMENTS_TYPE_ID = "elements_type"
SELECTED_CANDIDATE_ID = "selected_candidate"

GROUNDTRUTH_XMIN_ID = "groundTruth/bbox/xmin"
GROUNDTRUTH_XMAX_ID = "groundTruth/bbox/xmax"
GROUNDTRUTH_YMIN_ID = "groundTruth/bbox/ymin"
GROUNDTRUTH_YMAX_ID = "groundTruth/bbox/ymax"

# Information that changes from dataset to dataset.
DatasetDescriptor = collections.namedtuple(
    "DatasetDescriptor",
    [
        "subfolder",
        "num_classes",  # Number of semantic classes.
        "ignore_label",  # Ignore label value.
        "label_id",
        "image_id",
        "elements_box_id",
        "elements_text_id",
        "is_tfrecord",
        "has_candidate",
        "has_elements_boxes",
    ])


def get_resize_dim(width, height, image_size):
  """Calculates the size of each dimension so the aspect ratio is not changed.

  Args:
    width: The original width of the image.
    height: The original height of the image.
    image_size: The desired max image size.

  Returns:
    A tuple of resized dimensions, (resized_width, resized_height).
  """
  max_ = tf.maximum(width, height)
  ratio = tf.to_float(max_) / float(image_size)

  new_width = tf.to_float(width) / ratio
  new_height = tf.to_float(height) / ratio

  return tf.to_int32(new_width), tf.to_int32(new_height)
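# A minimal usage sketch (illustrative, not part of the original file): for
# an 800x600 image and image_size=512, the longer side is scaled to 512 and
# the shorter side keeps the aspect ratio:
#
#   new_w, new_h = get_resize_dim(tf.constant(800), tf.constant(600), 512)
#   with tf.Session() as sess:
#     print(sess.run([new_w, new_h]))  # [512, 384]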


def resize_im(image, image_size, pad_val, channels, features=None):
  """Resizes the image, padding it to an image_size x image_size square.

  Args:
    image: Image to resize.
    image_size: The desired max image size.
    pad_val: The value to pad with.
    channels: The number of channels in the image.
    features: Other features to resize.

  Returns:
    Resized image with possible padded regions,
    and possibly the resized elements boxes.
  """
  [height, width, got_channels] = preprocess_utils.resolve_shape(image, rank=3)

  # get_resize_dim takes (width, height); passing (height, width) and
  # unpacking (new_height, new_width) applies the same swap to both sides,
  # so the result stays consistent.
  new_height, new_width = get_resize_dim(height, width, image_size)

  image = tf.reshape(image, [height, width, -1])
  image = tf.cond(
      tf.logical_and(channels == 3, tf.equal(got_channels, 1)),
      true_fn=lambda: tf.image.grayscale_to_rgb(image),
      false_fn=lambda: image,
  )

  image = tf.image.resize_images(image, [new_height, new_width])

  image = preprocess_utils.pad_to_bounding_box(image, 0, 0, image_size,
                                               image_size, pad_val)
  if features is not None:
    width, height = tf.to_float(width), tf.to_float(height)
    max_dim = tf.to_float(tf.maximum(width, height))
    features[ELEMENTS_BOX_ID] = features[ELEMENTS_BOX_ID] / max_dim
    if GROUNDTRUTH_XMIN_ID in features:
      features[GROUNDTRUTH_XMIN_ID] *= width / max_dim
      features[GROUNDTRUTH_XMAX_ID] *= width / max_dim
      features[GROUNDTRUTH_YMIN_ID] *= height / max_dim
      features[GROUNDTRUTH_YMAX_ID] *= height / max_dim
  return image
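# Illustrative sketch (hypothetical values): pad a decoded [H, W, 3] uint8
# image to a 512x512 square filled with the model's mean pixel. The
# "xception_65" variant is an assumption, mirroring flags.model_variant.
#
#   mean_pixel = tf.reshape(
#       feature_extractor.mean_pixel("xception_65"), [1, 1, 3])
#   padded = resize_im(tf.cast(raw_image, tf.float32), 512, mean_pixel, 3)
#   # padded has shape [512, 512, 3]; the region outside the resized image
#   # is filled with mean_pixel.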


def assert_or_warn(condition, message, is_assert):
  """Raises an error or prints a warning when the condition does not hold."""
  if is_assert:
    return tf.Assert(condition, message)
  else:
    return tf.cond(condition, lambda: condition,
                   lambda: tf.Print(condition, message))
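# Sketch: the same check can either fail hard or only log, depending on the
# flag. `boxes` is a placeholder tensor of normalized box coordinates.
#
#   check = assert_or_warn(
#       tf.greater_equal(tf.reduce_min(boxes), -.001),
#       ["Bounding box is negative"], is_assert=False)
#   with tf.control_dependencies([check]):
#     boxes = tf.identity(boxes)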


refer_descriptor = DatasetDescriptor(
    subfolder="",
    num_classes=2,
    ignore_label=255,
    label_id="mask",
    image_id=IMAGE_ID,
    elements_text_id=ELEMENTS_TEXT_ID,
    elements_box_id=ELEMENTS_BOX_ID,
    is_tfrecord=True,
    has_candidate=True,
    has_elements_boxes=True,
)

dataset_descriptors = {
    "default":
        DatasetDescriptor(
            subfolder="",
            num_classes=2,
            ignore_label=255,
            label_id="mask",
            image_id=IMAGE_ID,
            elements_text_id=ELEMENTS_TEXT_ID,
            elements_box_id=ELEMENTS_BOX_ID,
            is_tfrecord=True,
            has_candidate=False,
            has_elements_boxes=True,
        )
}
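# To support another dataset, register a new descriptor here. The values
# below are illustrative placeholders, not a real configuration:
#
#   dataset_descriptors["my_dataset"] = DatasetDescriptor(
#       subfolder="my_dataset",
#       num_classes=2,
#       ignore_label=255,
#       label_id="mask",
#       image_id=IMAGE_ID,
#       elements_text_id=ELEMENTS_TEXT_ID,
#       elements_box_id=ELEMENTS_BOX_ID,
#       is_tfrecord=True,
#       has_candidate=False,
#       has_elements_boxes=True,
#   )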


def convert_string_neighbors(string_neighbors):
  """Converts strings of "0"/"1" characters into a boolean neighbor matrix."""
  split = tf.string_split(string_neighbors, "")
  string_dense = tf.sparse_tensor_to_dense(split, default_value="0")
  num = tf.string_to_number(string_dense, out_type=tf.int32)
  bool_neigh = tf.cast(num, tf.bool)
  return bool_neigh
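# Example: each string encodes one row of a 0/1 adjacency matrix. For two
# elements (illustrative input):
#
#   neighbors = convert_string_neighbors(tf.constant(["01", "10"]))
#   # -> [[False, True], [True, False]], a tf.bool tensor of shape [2, 2].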


def input_fn_dataset(dataset, flags):
  """Gets the model input from the given dataset."""
  features = {}
  dataset_descriptor = dataset_descriptors[flags.dataset]

  def process_label(label):
    """Preprocesses the label."""
    label = tf.image.decode_image(label, channels=1)
    ignore_label = 255
    label = tf.cast(label, tf.int32)

    if flags.preprocess_divide_label:
      label //= 255  # Integer division keeps the mask values in {0, 1}.

    label = resize_im(label, flags.image_size, ignore_label, 1)
    label = tf.cast(label, tf.int32)
    return label

  def _parse_function(*args):
    """Parses the tf example."""
    serialized_example = args[-1]

    context_feature_names = {
        dataset_descriptor.image_id: tf.FixedLenFeature([], tf.string),
    }
    sequence_feature_names = {}
    if flags.use_ref_exp:
      context_feature_names[REF_EXP_ID] = tf.FixedLenFeature([], tf.string)

    if flags.use_labels:
      if dataset_descriptor.has_candidate:
        context_feature_names[SELECTED_CANDIDATE_ID] = tf.FixedLenFeature(
            [], tf.int64)
        sequence_feature_names[ELEMENTS_MASK_ID] = tf.FixedLenSequenceFeature(
            [], tf.string)
      else:
        context_feature_names[dataset_descriptor.label_id] = tf.FixedLenFeature(
            [], tf.string)

    if dataset_descriptor.has_elements_boxes:
      sequence_feature_names[
          dataset_descriptor.elements_box_id] = tf.FixedLenSequenceFeature(
              [4], dtype=tf.float32)
    if flags.use_elements_texts:
      sequence_feature_names[
          dataset_descriptor.elements_text_id] = tf.FixedLenSequenceFeature(
              [], dtype=tf.string)
    if flags.use_elements_neighbors:
      sequence_feature_names[
          ELEMENTS_NEIGHBORS_ID] = tf.FixedLenSequenceFeature(
              [], dtype=tf.string)
    if flags.use_elements_ref_match:
      sequence_feature_names[
          ELEMENTS_REF_MATCH_ID] = tf.FixedLenSequenceFeature(
              [], dtype=tf.string)

    if flags.use_groundtruth_box:
      context_feature_names[GROUNDTRUTH_XMIN_ID] = tf.FixedLenFeature(
          [], tf.float32)
      context_feature_names[GROUNDTRUTH_XMAX_ID] = tf.FixedLenFeature(
          [], tf.float32)
      context_feature_names[GROUNDTRUTH_YMIN_ID] = tf.FixedLenFeature(
          [], tf.float32)
      context_feature_names[GROUNDTRUTH_YMAX_ID] = tf.FixedLenFeature(
          [], tf.float32)

    context_features, sequence_features = tf.parse_single_sequence_example(
        serialized_example,
        context_features=context_feature_names,
        sequence_features=sequence_feature_names,
    )

    features.update(context_features)
    features.update(sequence_features)

    if flags.use_elements_texts:
      features[ELEMENTS_TEXT_ID] = features.pop(
          dataset_descriptor.elements_text_id)
    if dataset_descriptor.has_elements_boxes:
      features[ELEMENTS_BOX_ID] = features.pop(
          dataset_descriptor.elements_box_id)

    image = features.pop(dataset_descriptor.image_id)
    image = tf.image.decode_image(image, channels=3)

    image = tf.cast(image, tf.float32)
    mean_pixel = tf.reshape(
        feature_extractor.mean_pixel(flags.model_variant), [1, 1, 3])

    features[IMAGE_PAD_WEIGHTS_ID] = tf.ones_like(image[:, :, 0:1])
    features[IMAGE_PAD_WEIGHTS_ID] = resize_im(features[IMAGE_PAD_WEIGHTS_ID],
                                               flags.image_size, 0, 1)
    features[IMAGE_PAD_WEIGHTS_ID] = tf.squeeze(features[IMAGE_PAD_WEIGHTS_ID],
                                                2)

    if dataset_descriptor.has_elements_boxes:
      image = resize_im(image, flags.image_size, mean_pixel, 3, features)
    else:
      image = resize_im(image, flags.image_size, mean_pixel, 3)

    if flags.use_labels:
      if dataset_descriptor.has_candidate:
        features[ELEMENTS_MASK_ID] = tf.map_fn(
            process_label,
            features.pop(ELEMENTS_MASK_ID),
            parallel_iterations=128,
            dtype=tf.int32,
            name="mask_map")
        features[LABEL_ID] = tf.gather_nd(features[ELEMENTS_MASK_ID],
                                          [features[SELECTED_CANDIDATE_ID]])
      else:
        label = features.pop(dataset_descriptor.label_id)
        label = process_label(label)
        features[LABEL_ID] = label

    if flags.use_elements_texts:
      features[ELEMENTS_EXIST_ID] = tf.ones_like(
          features[ELEMENTS_TEXT_ID], dtype=tf.int32)
    elif dataset_descriptor.has_elements_boxes:
      features[ELEMENTS_EXIST_ID] = tf.ones(
          tf.shape(features[ELEMENTS_BOX_ID])[:1], dtype=tf.int32)

    if flags.use_elements_neighbors:
      features[ELEMENTS_NEIGHBORS_ID] = convert_string_neighbors(
          features[ELEMENTS_NEIGHBORS_ID])

    features[IMAGE_ID] = image

    return features

  dataset = dataset.map(
      _parse_function,
      num_parallel_calls=flags.dataset_threads).prefetch(flags.batch_size)

  padded_shapes = {
      IMAGE_ID: [None, None, None],
  }
  if flags.use_labels:
    padded_shapes[LABEL_ID] = [None, None, None]
    if flags.use_groundtruth_box:
      padded_shapes[GROUNDTRUTH_XMIN_ID] = []
      padded_shapes[GROUNDTRUTH_XMAX_ID] = []
      padded_shapes[GROUNDTRUTH_YMIN_ID] = []
      padded_shapes[GROUNDTRUTH_YMAX_ID] = []
  if flags.use_elements_texts:
    padded_shapes[ELEMENTS_TEXT_ID] = [None]
    padded_shapes[ELEMENTS_EXIST_ID] = [None]
  if dataset_descriptor.has_elements_boxes:
    padded_shapes[ELEMENTS_BOX_ID] = [None, None]
    padded_shapes[ELEMENTS_EXIST_ID] = [None]
  if flags.use_elements_neighbors:
    padded_shapes[ELEMENTS_NEIGHBORS_ID] = [None, None]
  if flags.use_elements_ref_match:
    padded_shapes[ELEMENTS_REF_MATCH_ID] = [None]

  padded_shapes[IMAGE_PAD_WEIGHTS_ID] = [None, None]

  if flags.use_ref_exp:
    padded_shapes.update({
        REF_EXP_ID: [],
    })
  if dataset_descriptor.has_candidate:
    padded_shapes.update({
        SELECTED_CANDIDATE_ID: [],
        ELEMENTS_MASK_ID: [None, None, None, None],
    })

  dataset = dataset.padded_batch(flags.batch_size, padded_shapes=padded_shapes)
  dataset = dataset.prefetch(1)

  try:
    iterator = dataset.make_one_shot_iterator()
    feature_map = iterator.get_next()
  except ValueError:
    # This means the input pipeline uses placeholders, probably because it is
    # in inference mode.
    feature_map = tf.data.experimental.get_single_element(dataset)

  feature_map[IMAGE_ID] = tf.reshape(
      feature_map[IMAGE_ID], [-1, flags.image_size, flags.image_size, 3])

  assert_ops = []
  if dataset_descriptor.has_elements_boxes:
    assert_ops.append(
        assert_or_warn(
            tf.greater_equal(
                tf.reduce_min(feature_map[ELEMENTS_BOX_ID]), -.001), [
                    "Bounding box is negative",
                    tf.reduce_min(feature_map[ELEMENTS_BOX_ID])
                ], flags.incorrect_boxes_as_errors))

    assert_ops.append(
        assert_or_warn(
            tf.less_equal(
                tf.reduce_max(feature_map[ELEMENTS_BOX_ID][:, :, 0] +
                              feature_map[ELEMENTS_BOX_ID][:, :, 2]), 1.001),
            [
                "Bounding box x dim is too large.",
                tf.reduce_max(feature_map[ELEMENTS_BOX_ID][:, :, 0] +
                              feature_map[ELEMENTS_BOX_ID][:, :, 2])
            ], flags.incorrect_boxes_as_errors))

    assert_ops.append(
        assert_or_warn(
            tf.less_equal(
                tf.reduce_max(feature_map[ELEMENTS_BOX_ID][:, :, 1] +
                              feature_map[ELEMENTS_BOX_ID][:, :, 3]), 1.001),
            [
                "Bounding box y dim is too large.",
                tf.reduce_max(feature_map[ELEMENTS_BOX_ID][:, :, 1] +
                              feature_map[ELEMENTS_BOX_ID][:, :, 3])
            ], flags.incorrect_boxes_as_errors))

  with tf.control_dependencies(assert_ops):
    if dataset_descriptor.has_elements_boxes:
      feature_map[ELEMENTS_BOX_ID].set_shape([None, None, 4])
      feature_map[ELEMENTS_EXIST_ID] = tf.cast(feature_map[ELEMENTS_EXIST_ID],
                                               tf.bool)
    if flags.use_labels:
      if flags.output_mode == "segment" or flags.output_mode == "regression":
        feature_map[LABEL_ID] = tf.reshape(
            feature_map[LABEL_ID], [-1, flags.image_size, flags.image_size, 1])
  return feature_map
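# Usage sketch (hypothetical): `flags` must provide the attributes read
# above (dataset, image_size, batch_size, dataset_threads, model_variant,
# the use_* booleans, incorrect_boxes_as_errors, output_mode, ...). The
# file name is a placeholder.
#
#   raw_dataset = tf.data.TFRecordDataset("train.tfrecord")
#   feature_map = input_fn_dataset(raw_dataset, flags)
#   # feature_map[IMAGE_ID] has shape
#   # [batch_size, flags.image_size, flags.image_size, 3].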


def get_input_fn(flags):
  """Returns input_fn."""

  def input_fn():
    """Reads the input features from files."""
    dataset_descriptor = dataset_descriptors[flags.dataset]

    with tf.variable_scope("input"):
      pattern = os.path.join(
          os.path.join(flags.dataset_dir, dataset_descriptor.subfolder),
          flags.split + "*")
      print("Pattern", pattern)
      dataset = tf.data.Dataset.list_files(pattern)

      dataset = dataset.shuffle(buffer_size=flags.file_shuffle_buffer_size)
      dataset = dataset.repeat()

      def prefetch_map_fn(filename):
        if dataset_descriptor.is_tfrecord:
          return tf.data.TFRecordDataset(filename).prefetch(flags.batch_size)
        else:
          return tf.data.SSTableDataset(filename).prefetch(flags.batch_size)

      dataset = dataset.interleave(
          prefetch_map_fn, cycle_length=100, block_length=flags.batch_size)

      print("shuffle buffer size", flags.shuffle_buffer_size)
      dataset = dataset.shuffle(buffer_size=flags.shuffle_buffer_size)

      return input_fn_dataset(dataset, flags)

  return input_fn
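# Sketch of typical consumption (assumes an Estimator model_fn defined
# elsewhere; `my_model_fn` and the step count are placeholders):
#
#   estimator = tf.estimator.Estimator(model_fn=my_model_fn)
#   estimator.train(input_fn=get_input_fn(flags), max_steps=1000)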


def get_serving_input_receiver_fn(flags):
  """Returns serving_input_receiver_fn."""

  def serving_input_receiver_fn():
    """Used for exporting the model. Expects a serialized tf.Example."""
    serialized_tf_example = tf.placeholder(
        dtype=tf.string, shape=[None], name="input")
    receiver_tensors = serialized_tf_example
    dataset = tf.data.Dataset.from_tensor_slices(serialized_tf_example)

    features = input_fn_dataset(dataset, flags)

    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

  return serving_input_receiver_fn
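# Export sketch (assumes a trained estimator; the export directory is a
# placeholder, and estimator.export_saved_model is the TF >= 1.13 name of
# the export API):
#
#   estimator.export_saved_model(
#       "/tmp/exported_model", get_serving_input_receiver_fn(flags))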