google-research
313 lines · 11.3 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities to load and save data."""
17
18import functools19import json20import os21import random22from typing import Optional23
24from layout_gvt_public.datasets import datasets_info25import numpy as np26import tensorflow as tf27
28
# Default discretization values.
# Number of buckets used to quantize x/width and y/height coordinates.
DEFAULT_RESOLUTION_WIDTH = 32
DEFAULT_RESOLUTION_HEIGHT = 32
33
34def _clamp(value):35"""Truncates `value` to the [0, 1] range."""36return np.clip(value, 0.0, 1.0)37
38
def _normalize_entries(documents, shuffle=False):
  """Normalizes the bounding box annotations to the [0, 1] range.

  Args:
    documents: A sequence of document entries as a dictionary with the
      'children' key containing a sequence of bounding boxes.
    shuffle: Whether to randomly (but deterministically) shuffle the objects
      within each document.

  Returns:
    The same sequence as `documents` with coordinates normalized by the
    document width and height. The order of the documents themselves is
    preserved; within each document the children are either shuffled (when
    `shuffle` is set) or sorted by their (y, x) center position.
  """
  normalized_documents = []

  def _normalize_entry(element, document_width, document_height):
    # Rescale the box center and size by the document dimensions, clamping
    # the results into [0, 1].
    return {
        **element,
        "center": [
            _clamp(element["center"][0] / document_width),
            _clamp(element["center"][1] / document_height)
        ],
        "width": _clamp(element["width"] / document_width),
        "height": _clamp(element["height"] / document_height)
    }

  for document in documents:
    children = document["children"]
    document_width = float(document["width"])
    document_height = float(document["height"])
    normalize_fn = functools.partial(
        _normalize_entry,
        document_width=document_width,
        document_height=document_height)
    normalized_children = [normalize_fn(c) for c in children]
    if shuffle:
      # A fixed seed keeps the shuffled order reproducible across runs.
      random.Random(0).shuffle(normalized_children)
    else:
      # Reading order: top-to-bottom, then left-to-right.
      normalized_children = sorted(
          normalized_children, key=lambda e: (e["center"][1], e["center"][0]))

    normalized_document = {**document, "children": normalized_children}
    normalized_documents.append(normalized_document)

  return normalized_documents
84
def get_dataset(batch_size,
                dataset_folder,
                n_devices,
                ds_file,
                max_length,
                add_bos=True,
                dataset_name="RICO",
                shuffle=False):
  """Obtains a dataset from preprocessed json data.

  Args:
    batch_size: Number of samples in one batch.
    dataset_folder: Dataset folder.
    n_devices: Number of devices we use.
    ds_file: The data file name.
    max_length: The maximum length of the input sequence.
    add_bos: Whether to add bos and eos to the input sequence.
    dataset_name: The name of the dataset.
    shuffle: Shuffle objects or not.

  Returns:
    A tuple of:
      - One tf dataset.
      - Vocab size in this dataset.
      - `pos_info`: [offset, range size] pairs for the class, width, height,
        center_x and center_y token ranges, in that order.

  Raises:
    AssertionError: If `batch_size` is not divisible by `n_devices`.
  """
  # Every device must receive an equal shard of each batch.
  assert batch_size % n_devices == 0
  ds_path = os.path.join(dataset_folder, ds_file)
  dataset = LayoutDataset(dataset_name, ds_path, add_bos, shuffle)

  class_range = [dataset.offset_class, dataset.number_classes]
  center_x_range = [dataset.offset_center_x, dataset.resolution_w]
  center_y_range = [dataset.offset_center_y, dataset.resolution_h]
  width_range = [dataset.offset_width, dataset.resolution_w]
  height_range = [dataset.offset_height, dataset.resolution_h]
  # Order matches the per-element token layout produced by LayoutDataset.
  pos_info = [
      class_range, width_range, height_range, center_x_range, center_y_range
  ]
  ds = dataset.setup_tf_dataset(
      batch_size, max_length, group_data_by_size=False)
  vocab_size = dataset.get_vocab_size()
  return ds, vocab_size, pos_info
127
def get_all_dataset(batch_size,
                    dataset_folder,
                    n_devices,
                    add_bos,
                    max_length,
                    dataset_name="RICO",
                    shuffle=False):
  """Creates datasets for various splits, such as train, valid and test.

  Args:
    batch_size: batch size of dataset loader.
    dataset_folder: path of dataset.
    n_devices: how many devices we will train our model on.
    add_bos: whether to add bos to the input sequence.
    max_length: the maximum length of input sequence.
    dataset_name: the name of the input dataset.
    shuffle: shuffle objects or not.

  Returns:
    datasets for various splits, the size of vocabulary and asset information.
  """
  # All splits share every argument except the file name and shuffling.
  make_split = functools.partial(
      get_dataset,
      batch_size,
      dataset_folder,
      n_devices,
      max_length=max_length,
      add_bos=add_bos,
      dataset_name=dataset_name)
  # Only the training split may be shuffled; eval/test order stays fixed.
  train_ds, vocab_size, pos_info = make_split("train.json", shuffle=shuffle)
  eval_ds, _, _ = make_split("val.json", shuffle=False)
  test_ds, _, _ = make_split("test.json", shuffle=False)
  return train_ds, eval_ds, test_ds, vocab_size, pos_info
161
class LayoutDataset:
  """Dataset for layout generation.

  Loads preprocessed layout annotations from a json file, normalizes and
  discretizes the bounding boxes, and exposes them as flat token sequences
  via a tf.data pipeline.
  """

  def __init__(self,
               dataset_name,
               path,
               add_bos=True,
               shuffle=False,
               resolution_w=DEFAULT_RESOLUTION_WIDTH,
               resolution_h=DEFAULT_RESOLUTION_HEIGHT,
               limit=22):
    """Sets up the dataset instance, and computes the vocabulary.

    Args:
      dataset_name: The name of the input dataset, such as RICO.
      path: Path to the json file with the data. Raises ValueError if the
        data is faulty.
      add_bos: Whether to add bos and eos to the input sequence.
      shuffle: Whether to shuffle objects within a layout.
      resolution_w: Discretization resolution to use for x and width
        coordinates.
      resolution_h: Discretization resolution to use for y and height
        coordinates.
      limit: Maximum number of elements in a layout.
    """
    with open(path, "r") as f:
      data = json.load(f)
    self.add_bos = add_bos
    self.dataset_name = datasets_info.DatasetName(dataset_name)
    self.data = _normalize_entries(data, shuffle)
    self.number_classes = datasets_info.get_number_classes(self.dataset_name)
    self.id_to_label = datasets_info.get_id_to_label_map(self.dataset_name)
    # Reserved special symbols.
    self.pad_idx, self.bos_idx, self.eos_idx = 0, 1, 2

    self.resolution_w = resolution_w
    self.resolution_h = resolution_h
    # Ids of pad, bos, eos and unk are 0, 1, 2, 3, so class ids start from 4.
    # The vocabulary is laid out as contiguous token ranges:
    #   [special][classes][center_x][center_y][width][height].
    self.offset_class = 4
    self.offset_center_x = self.offset_class + self.number_classes
    # center_x tokens span `resolution_w` values.
    self.offset_center_y = self.offset_center_x + self.resolution_w
    # BUG FIX: center_y tokens span `resolution_h` values, so the width range
    # must start `resolution_h` (not `resolution_w`) after `offset_center_y`.
    # Likewise width tokens span `resolution_w` values. The original code
    # swapped the two strides; the errors cancel only when the resolutions
    # are equal (as they are for the defaults), so default behavior is
    # unchanged while non-square grids now get non-overlapping ranges.
    self.offset_width = self.offset_center_y + self.resolution_h
    self.offset_height = self.offset_width + self.resolution_w
    self.limit = limit
    self.shuffle = shuffle

  def get_vocab_size(self):
    # Special symbols + num_classes +
    # all possible number of x, y, width and height positions.
    return self.offset_class + self.number_classes + (
        self.resolution_w + self.resolution_h) * 2

  def _convert_entry_to_model_format(self, entries):
    """Converts a dataset entry to one sequence.

    E.g.:
    --> [BOS, entry1_pos, entry2_pos, ..., EOS]
    --> entry1_pos = class_id, width, height, center_x, center_y

    Args:
      entries: The sequence of bounding boxes to parse.

    Returns:
      One numpy array which contains positions of all items in the input
      entry. When `add_bos` is set, the first and last tokens are the BOS
      and EOS symbols. Following previous works, we discretize the position
      information according to resolution_w and resolution_h.
    """
    processed_entry = []
    # Layouts are truncated to at most `limit` elements.
    for box in entries[:self.limit]:
      category_id = box["category_id"]
      center = box["center"]
      width = box["width"]
      height = box["height"]
      class_id = category_id + self.offset_class
      # Coordinates are assumed normalized to [0, 1]; map them onto
      # [0, resolution - 1] bins, then shift into the vocabulary range.
      discrete_x = round(center[0] *
                         (self.resolution_w - 1)) + self.offset_center_x
      discrete_y = round(center[1] *
                         (self.resolution_h - 1)) + self.offset_center_y
      # Clip the width and height of assets to at least 1 bin.
      discrete_width = round(
          np.clip(width * (self.resolution_w - 1), 1.,
                  self.resolution_w - 1)) + self.offset_width
      discrete_height = round(
          np.clip(height * (self.resolution_h - 1), 1.,
                  self.resolution_h - 1)) + self.offset_height
      processed_entry.extend(
          [class_id, discrete_width, discrete_height, discrete_x, discrete_y])
    if self.add_bos:
      # Wrap the whole sequence with bos and eos symbols.
      processed_entry = [self.bos_idx] + processed_entry + [self.eos_idx]
    return np.array(processed_entry, dtype=np.int32)

  def boxes_iterator(self):
    """Reads the dataset and produces a generator with each preprocessed entry.

    Yields:
      Preprocessed entry format for VTN model.
    """
    if self.shuffle:
      # Copy before shuffling so `self.data` keeps its original order.
      data = list(self.data)
      random.shuffle(data)
    else:
      data = self.data

    for entry in data:
      yield self._convert_entry_to_model_format(entry["children"])

  def setup_tf_dataset(self, batch_size, max_length, group_data_by_size):
    """Instantiates a tf.data.Dataset from this dataset's entries.

    Args:
      batch_size: The dataset instance is batched using this value.
      max_length: The maximum length of input sequence.
      group_data_by_size: If true, the data is batched grouping entries by
        their sequence length.

    Returns:
      An initialized dataset instance containing the same data as produced
      by `boxes_iterator`.
    """
    bucket_boundaries = (32, 52, 72, 102, 132, 152, 172, 192)
    # `boxes_iterator` is already a zero-argument callable.
    dataset = tf.data.Dataset.from_generator(
        self.boxes_iterator,
        output_types=tf.int32,
        output_shapes=tf.TensorShape([None]))
    if group_data_by_size:
      dataset = dataset.apply(
          tf.data.experimental.bucket_by_sequence_length(
              element_length_func=lambda x: tf.shape(x)[0],
              bucket_boundaries=bucket_boundaries,
              # One batch size per bucket, including the overflow bucket.
              bucket_batch_sizes=[
                  batch_size for _ in range(len(bucket_boundaries) + 1)
              ]))
    else:
      # Pad every sequence to `max_length` with the pad symbol.
      dataset = dataset.padded_batch(
          batch_size,
          padding_values=self.pad_idx,
          padded_shapes=max_length)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)