# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrapper for datasets."""

import functools
import os
import re

import tensorflow as tf
import tensorflow_datasets as tfds

from coltran.utils import datasets_utils


def resize_to_square(image, resolution=32, train=True):
  """Crops the image to a square and resizes it for generative modeling."""

  # Crop a square-shaped image by shortening the longer side.
  image_shape = tf.shape(image)
  height, width, channels = image_shape[0], image_shape[1], image_shape[2]
  side_size = tf.minimum(height, width)
  cropped_shape = tf.stack([side_size, side_size, channels])
  if train:
    image = tf.image.random_crop(image, cropped_shape)
  else:
    image = tf.image.resize_with_crop_or_pad(
        image, target_height=side_size, target_width=side_size)

  image = datasets_utils.change_resolution(
      image, res=resolution, method='area')
  return image
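

# Illustrative sketch (not part of the original pipeline): at train time a
# 480x640 image yields a random 480x480 crop; at eval time the crop is
# central. Both are then area-resized to `resolution`:
#
#   img = tf.random.uniform((480, 640, 3), maxval=255, dtype=tf.float32)
#   train_img = resize_to_square(img, resolution=64, train=True)   # (64, 64, 3)
#   eval_img = resize_to_square(img, resolution=64, train=False)   # (64, 64, 3)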


def preprocess(example, train=True, resolution=256):
  """Applies a random crop (train) or central crop (eval) to the image."""
  image = example

  is_label = False
  if isinstance(example, dict):
    image = example['image']
    is_label = 'label' in example.keys()

  image = resize_to_square(image, train=train, resolution=resolution)

  # Keeping the 'file_name' key triggers a hard-to-debug TPU error, so only
  # the keys needed downstream are copied over.
  example_copy = dict()
  example_copy['image'] = image
  example_copy['targets'] = image
  if is_label:
    example_copy['label'] = example['label']
  return example_copy
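

# Illustrative sketch: `preprocess` accepts either a bare image tensor or a
# TFDS-style dict and always returns a dict with 'image' and 'targets':
#
#   example = {'image': tf.zeros((300, 400, 3), tf.uint8), 'label': 1}
#   out = preprocess(example, train=False, resolution=256)
#   # out.keys() == {'image', 'targets', 'label'}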


def get_gen_dataset(data_dir, batch_size):
  """Converts a directory of generated TFRecords into a TF Dataset."""

  def parse_example(example_proto, res=64):
    features = {'image': tf.io.FixedLenFeature([res*res*3], tf.int64)}
    example = tf.io.parse_example(example_proto, features=features)
    image = tf.reshape(example['image'], (res, res, 3))
    return {'targets': image}

  # Order the provided records as gen0.tfrecords, gen1.tfrecords, ...
  def tf_record_name_to_num(x):
    x = x.split('.')[0]
    x = re.split(r'(\d+)', x)
    return int(x[1])

  assert data_dir is not None
  records = tf.io.gfile.listdir(data_dir)
  max_num = max(records, key=tf_record_name_to_num)
  max_num = tf_record_name_to_num(max_num)

  records = []
  for record in range(max_num + 1):
    path = os.path.join(data_dir, f'gen{record}.tfrecords')
    records.append(path)

  tf_dataset = tf.data.TFRecordDataset(records)
  tf_dataset = tf_dataset.map(parse_example, num_parallel_calls=100)
  tf_dataset = tf_dataset.batch(batch_size=batch_size)
  tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE)
  return tf_dataset
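

# Illustrative sketch: the helper extracts the numeric suffix of a record
# name, so records are gathered in numeric rather than lexicographic order:
#
#   tf_record_name_to_num('gen12.tfrecords')  # -> 12
#   # 'gen10' sorts before 'gen2' lexicographically, but 2 < 10 here.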


def create_gen_dataset_from_images(image_dir):
  """Creates a dataset from the images in the provided directory."""

  def load_image(path):
    image_str = tf.io.read_file(path)
    return tf.image.decode_image(image_str, channels=3)

  child_files = tf.io.gfile.listdir(image_dir)
  files = [os.path.join(image_dir, file) for file in child_files]
  files = tf.convert_to_tensor(files, dtype=tf.string)
  dataset = tf.data.Dataset.from_tensor_slices((files))
  return dataset.map(load_image, num_parallel_calls=100)
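

# Illustrative usage (hypothetical path): each element is a decoded HxWx3
# uint8 image tensor; shapes may vary per element until `preprocess` runs.
#
#   ds = create_gen_dataset_from_images('/tmp/generated_images')
#   for image in ds.take(1):
#     print(image.shape)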


def get_imagenet(subset, read_config):
  """Gets the requested subset of the ImageNet dataset."""
  train = subset == 'train'
  num_val_examples = 0 if subset == 'eval_train' else 10000
  if subset == 'test':
    ds = tfds.load('imagenet2012', split='validation', shuffle_files=False)
  else:
    # Split off 10000 samples from the ImageNet train split for validation.
    ds, info = tfds.load('imagenet2012', split='train', with_info=True,
                         shuffle_files=train, read_config=read_config)
    num_train = info.splits['train'].num_examples - num_val_examples
    if train:
      ds = ds.take(num_train)
    elif subset == 'valid':
      ds = ds.skip(num_train)
  return ds
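

# Illustrative sketch: `read_config` lets multiple workers shard the train
# split (the names below are standard TF/TFDS APIs, not defined in this file):
#
#   context = tf.distribute.InputContext(
#       num_input_pipelines=2, input_pipeline_id=0)
#   read_config = tfds.ReadConfig(input_context=context)
#   train_ds = get_imagenet('train', read_config)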


def get_dataset(name,
                config,
                batch_size,
                subset,
                read_config=None,
                data_dir=None):
  """Wrapper around TF-Datasets.

  * Setting `config.random_channel` to True adds
      ds['targets_slice'] - One of the 3 channels, picked at random.
      ds['channel_index'] - Index of the randomly picked channel.
  * Setting `config.downsample` to True adds
      ds['targets_64'] - Input downsampled to 64x64 using tf.resize.
      ds['targets_64_up_back'] - 'targets_64' upsampled using tf.resize.

  Args:
    name: 'imagenet' or 'custom'.
    config: dict-like config.
    batch_size: batch size.
    subset: 'train', 'eval_train', 'valid' or 'test'.
    read_config: optional, a tfds.ReadConfig instance. This is used for
      sharding across multiple workers.
    data_dir: data directory, used for the custom dataset.

  Returns:
    dataset: TF Dataset.
  """
  downsample = config.get('downsample', False)
  random_channel = config.get('random_channel', False)
  downsample_res = config.get('downsample_res', 64)
  downsample_method = config.get('downsample_method', 'area')
  num_epochs = config.get('num_epochs', -1)
  data_dir = config.get('data_dir') or data_dir
  auto = tf.data.AUTOTUNE
  train = subset == 'train'

  if name == 'imagenet':
    ds = get_imagenet(subset, read_config)
  elif name == 'custom':
    assert data_dir is not None
    ds = create_gen_dataset_from_images(data_dir)
  else:
    raise ValueError(f'Expected dataset in [imagenet, custom]. Got {name}')

  ds = ds.map(
      lambda x: preprocess(x, train=train), num_parallel_calls=100)
  if train and random_channel:
    ds = ds.map(datasets_utils.random_channel_slice)
  if downsample:
    downsample_part = functools.partial(
        datasets_utils.downsample_and_upsample,
        train=train,
        downsample_res=downsample_res,
        upsample_res=256,
        method=downsample_method)
    ds = ds.map(downsample_part, num_parallel_calls=100)

  if train:
    ds = ds.repeat(num_epochs)
    ds = ds.shuffle(buffer_size=128)
  ds = ds.batch(batch_size, drop_remainder=True)
  ds = ds.prefetch(auto)
  return ds
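

# Illustrative usage (hypothetical config values and path): building a
# training pipeline over a directory of custom images. A plain dict works as
# `config` here because only `config.get` is used.
#
#   config = {'downsample': True, 'downsample_res': 64, 'random_channel': True}
#   ds = get_dataset('custom', config, batch_size=8, subset='train',
#                    data_dir='/path/to/images')
#   batch = next(iter(ds))
#   # batch keys: 'image', 'targets', 'targets_slice', 'channel_index',
#   # 'targets_64', 'targets_64_up_back'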