input_pipeline.py

# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to load and save data."""

import functools
import json
import os
import random
from typing import Optional

from layout_gvt_public.datasets import datasets_info
import numpy as np
import tensorflow as tf


# Default discretization values.
DEFAULT_RESOLUTION_WIDTH = 32
DEFAULT_RESOLUTION_HEIGHT = 32


def _clamp(value):
  """Clamps `value` to the [0, 1] range."""
  return np.clip(value, 0.0, 1.0)


def _normalize_entries(documents, shuffle=False):
  """Normalizes the bounding box annotations to the [0, 1] range.

  Args:
    documents: A sequence of document entries as a dictionary with the
      'children' key containing a sequence of bounding boxes.
    shuffle: Whether to randomly shuffle the objects within each document.

  Returns:
    The same sequence as `documents` with coordinates normalized by the
    document width and height. The document order is preserved; children are
    shuffled (with a fixed seed) when `shuffle` is True, otherwise sorted by
    their center position.
  """
  normalized_documents = []

  def _normalize_entry(element, document_width, document_height):
    return {
        **element, "center": [
            _clamp(element["center"][0] / document_width),
            _clamp(element["center"][1] / document_height)
        ],
        "width": _clamp(element["width"] / document_width),
        "height": _clamp(element["height"] / document_height)
    }

  for document in documents:
    children = document["children"]
    document_width = float(document["width"])
    document_height = float(document["height"])
    normalize_fn = functools.partial(
        _normalize_entry,
        document_width=document_width,
        document_height=document_height)
    normalized_children = [normalize_fn(c) for c in children]
    if shuffle:
      random.Random(0).shuffle(normalized_children)
    else:
      normalized_children = sorted(
          normalized_children, key=lambda e: (e["center"][1], e["center"][0]))

    normalized_document = {**document, "children": normalized_children}
    normalized_documents.append(normalized_document)

  return normalized_documents
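
# Illustrative sketch (not part of the original module): raw entries are
# expected to carry pixel coordinates that `_normalize_entries` rescales into
# [0, 1]. The field names mirror the docstrings above; the concrete values
# below are made-up assumptions.
#
#   document = {
#       "width": 1440, "height": 2560,
#       "children": [{"category_id": 3, "center": [720, 1280],
#                     "width": 1440, "height": 256}],
#   }
#   _normalize_entries([document])[0]["children"][0]
#   # -> {"category_id": 3, "center": [0.5, 0.5], "width": 1.0, "height": 0.1}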


def get_dataset(batch_size,
                dataset_folder,
                n_devices,
                ds_file,
                max_length,
                add_bos=True,
                dataset_name="RICO",
                shuffle=False):
  """Obtains a dataset from preprocessed json data.

  Args:
    batch_size: Number of samples in one batch.
    dataset_folder: Dataset folder.
    n_devices: Number of devices we use.
    ds_file: The data file name.
    max_length: The maximum length of the input sequence.
    add_bos: Whether to add BOS and EOS tokens to the input sequence.
    dataset_name: The name of the dataset.
    shuffle: Whether to shuffle objects.

  Returns:
    A tf.data dataset, the vocabulary size, and the position information
    (token-id offset and range for each field).
  """
  assert batch_size % n_devices == 0
  ds_path = os.path.join(dataset_folder, ds_file)
  dataset = LayoutDataset(dataset_name, ds_path, add_bos, shuffle)

  class_range = [dataset.offset_class, dataset.number_classes]
  center_x_range = [dataset.offset_center_x, dataset.resolution_w]
  center_y_range = [dataset.offset_center_y, dataset.resolution_h]
  width_range = [dataset.offset_width, dataset.resolution_w]
  height_range = [dataset.offset_height, dataset.resolution_h]
  pos_info = [
      class_range, width_range, height_range, center_x_range, center_y_range
  ]
  ds = dataset.setup_tf_dataset(
      batch_size, max_length, group_data_by_size=False)
  vocab_size = dataset.get_vocab_size()
  return ds, vocab_size, pos_info
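
# Usage sketch (illustrative only; the dataset folder path below is an
# assumption, not part of this module):
#
#   ds, vocab_size, pos_info = get_dataset(
#       batch_size=64, dataset_folder="/path/to/dataset", n_devices=1,
#       ds_file="train.json", max_length=128)
#   # pos_info holds [offset, range] pairs for the class, width, height,
#   # center-x and center-y tokens, in that order.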


def get_all_dataset(batch_size,
                    dataset_folder,
                    n_devices,
                    add_bos,
                    max_length,
                    dataset_name="RICO",
                    shuffle=False):
  """Creates datasets for the train, validation and test splits.

  Args:
    batch_size: Batch size of the dataset loader.
    dataset_folder: Path of the dataset.
    n_devices: Number of devices we will train our model on.
    add_bos: Whether to add BOS to the input sequence.
    max_length: The maximum length of the input sequence.
    dataset_name: The name of the input dataset.
    shuffle: Whether to shuffle objects.

  Returns:
    Datasets for the three splits, the vocabulary size and the position
    information for asset tokens.
  """
  train_ds, vocab_size, pos_info = get_dataset(batch_size, dataset_folder,
                                               n_devices, "train.json",
                                               max_length,
                                               add_bos,
                                               dataset_name,
                                               shuffle)
  eval_ds, _, _ = get_dataset(batch_size, dataset_folder, n_devices, "val.json",
                              max_length, add_bos, dataset_name, shuffle=False)
  test_ds, _, _ = get_dataset(batch_size, dataset_folder, n_devices,
                              "test.json", max_length, add_bos, dataset_name,
                              shuffle=False)
  return train_ds, eval_ds, test_ds, vocab_size, pos_info


class LayoutDataset:
  """Dataset for layout generation."""

  def __init__(self,
               dataset_name,
               path,
               add_bos=True,
               shuffle=False,
               resolution_w=DEFAULT_RESOLUTION_WIDTH,
               resolution_h=DEFAULT_RESOLUTION_HEIGHT,
               limit=22):
    """Sets up the dataset instance and computes the vocabulary.

    Args:
      dataset_name: The name of the input dataset, such as RICO.
      path: Path to the json file with the data. Raises ValueError if the data
        is faulty.
      add_bos: Whether to add BOS and EOS tokens to the input sequence.
      shuffle: Whether to shuffle objects.
      resolution_w: Discretization resolution to use for x and width
        coordinates.
      resolution_h: Discretization resolution to use for y and height
        coordinates.
      limit: Maximum number of elements in a layout.
    """
    with open(path, "r") as f:
      data = json.load(f)
    self.add_bos = add_bos
    self.dataset_name = datasets_info.DatasetName(dataset_name)
    self.data = _normalize_entries(data, shuffle)
    self.number_classes = datasets_info.get_number_classes(self.dataset_name)
    self.id_to_label = datasets_info.get_id_to_label_map(self.dataset_name)
    self.pad_idx, self.bos_idx, self.eos_idx = 0, 1, 2

    self.resolution_w = resolution_w
    self.resolution_h = resolution_h
    # Ids of pad, bos, eos and unk are 0, 1, 2 and 3, so we start from 4.
    self.offset_class = 4
    self.offset_center_x = self.offset_class + self.number_classes
    self.offset_center_y = self.offset_center_x + self.resolution_w

    # center_y occupies resolution_h token ids and width occupies resolution_w
    # token ids.
    self.offset_width = self.offset_center_y + self.resolution_h
    self.offset_height = self.offset_width + self.resolution_w
    self.limit = limit
    self.shuffle = shuffle
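
  # Resulting token-id layout (derived from the offsets above):
  #   [0] pad  [1] bos  [2] eos  [3] unk
  #   [offset_class,    offset_center_x)  class ids
  #   [offset_center_x, offset_center_y)  discretized center-x ids
  #   [offset_center_y, offset_width)     discretized center-y ids
  #   [offset_width,    offset_height)    discretized width ids
  #   [offset_height,   vocab_size)       discretized height ids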

  def get_vocab_size(self):
    # Special symbols + num_classes +
    # all possible x, y, width and height positions.
    return self.offset_class + self.number_classes + (self.resolution_w +
                                                      self.resolution_h) * 2
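
  # Worked example (hedged; the class count below is an assumption, not taken
  # from datasets_info): with 25 element classes and the default 32x32
  # resolution, the vocabulary is 4 + 25 + (32 + 32) * 2 = 157 token ids.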

  def _convert_entry_to_model_format(self, entries):
    """Converts a dataset entry to one sequence.

    E.g.:
    --> [BOS, entry1_pos, entry2_pos, ..., EOS]
    --> entry1_pos = class_id, width, height, center_x, center_y

    Args:
      entries: The sequence of bounding boxes to parse.

    Returns:
      One numpy array containing the positions of all items in the input entry.
      The first and last tokens are the BOS and EOS symbols.
      Following previous work, we discretize the position information according
      to resolution_w and resolution_h.
    """
    processed_entry = []
    for box in entries[:self.limit]:
      category_id = box["category_id"]
      center = box["center"]
      width = box["width"]
      height = box["height"]
      class_id = category_id + self.offset_class
      discrete_x = round(center[0] *
                         (self.resolution_w - 1)) + self.offset_center_x
      discrete_y = round(center[1] *
                         (self.resolution_h - 1)) + self.offset_center_y
      # Clip the width and height of assets to be at least 1.
      discrete_width = round(
          np.clip(width * (self.resolution_w - 1), 1.,
                  self.resolution_w - 1)) + self.offset_width
      discrete_height = round(
          np.clip(height * (self.resolution_h - 1), 1.,
                  self.resolution_h - 1)) + self.offset_height
      processed_entry.extend(
          [class_id, discrete_width, discrete_height, discrete_x, discrete_y])
    if self.add_bos:
      # Add BOS and EOS to the input sequence.
      processed_entry = [self.bos_idx] + processed_entry + [self.eos_idx]
    return np.array(processed_entry, dtype=np.int32)
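
  # Illustrative example (values are made up, not taken from any dataset):
  # with the default 32x32 resolution, a normalized box
  #   {"category_id": 0, "center": [0.5, 0.5], "width": 0.5, "height": 0.25}
  # is mapped to
  #   class_id        = 0 + offset_class
  #   discrete_width  = round(0.5 * 31) + offset_width     (= 16 + offset_width)
  #   discrete_height = round(0.25 * 31) + offset_height   (= 8 + offset_height)
  #   discrete_x      = round(0.5 * 31) + offset_center_x  (= 16 + offset_center_x)
  #   discrete_y      = round(0.5 * 31) + offset_center_y  (= 16 + offset_center_y)
  # and the whole layout becomes [BOS, class, w, h, x, y, ..., EOS] when
  # add_bos is True.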

  def boxes_iterator(self):
    """Reads the dataset and yields each preprocessed entry.

    Yields:
      Preprocessed entries in the VTN model input format.
    """
    if self.shuffle:
      data = list(self.data)
      random.shuffle(data)
    else:
      data = self.data

    for entry in data:
      inputs = self._convert_entry_to_model_format(entry["children"])
      yield inputs

  def setup_tf_dataset(
      self,
      batch_size,
      max_length,
      group_data_by_size):
    """Instantiates a tf.data.Dataset from the boxes iterator.

    Args:
      batch_size: The dataset instance is batched using this value.
      max_length: The maximum length of the input sequence.
      group_data_by_size: If true, the data is batched grouping entries by
        their sequence length.

    Returns:
      An initialized dataset instance containing the same data as
      `boxes_iterator`.
    """
    bucket_boundaries = (32, 52, 72, 102, 132, 152, 172, 192)
    dataset = tf.data.Dataset.from_generator(
        functools.partial(self.boxes_iterator),
        output_types=tf.int32, output_shapes=tf.TensorShape([None]))
    if group_data_by_size:
      dataset = dataset.apply(
          tf.data.experimental.bucket_by_sequence_length(
              element_length_func=lambda x: tf.shape(x)[0],
              bucket_boundaries=bucket_boundaries,
              bucket_batch_sizes=[
                  batch_size for _ in range(len(bucket_boundaries) + 1)
              ]))
    else:
      dataset = dataset.padded_batch(
          batch_size,
          padding_values=self.pad_idx,
          padded_shapes=max_length)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
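

# Minimal smoke-test sketch, not part of the original pipeline. The folder
# path is an assumption; it must contain train.json, val.json and test.json in
# the format described above.
if __name__ == "__main__":
  _folder = "/tmp/rico_dataset"  # hypothetical path
  _train_ds, _eval_ds, _test_ds, _vocab, _pos = get_all_dataset(
      batch_size=8,
      dataset_folder=_folder,
      n_devices=1,
      add_bos=True,
      max_length=128)
  for _batch in _train_ds.take(1):
    # Each batch is an int32 tensor of shape [batch_size, max_length].
    print("batch shape:", _batch.shape, "vocab size:", _vocab)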