google-research / datasets.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrapper for datasets."""

import functools
import os
import re
import tensorflow as tf
import tensorflow_datasets as tfds
from coltran.utils import datasets_utils


def resize_to_square(image, resolution=32, train=True):
  """Preprocess the image in a way that is OK for generative modeling."""

  # Crop a square-shaped image by shortening the longer side.
  image_shape = tf.shape(image)
  height, width, channels = image_shape[0], image_shape[1], image_shape[2]
  side_size = tf.minimum(height, width)
  cropped_shape = tf.stack([side_size, side_size, channels])
  if train:
    image = tf.image.random_crop(image, cropped_shape)
  else:
    image = tf.image.resize_with_crop_or_pad(
        image, target_height=side_size, target_width=side_size)

  image = datasets_utils.change_resolution(image, res=resolution, method='area')
  return image


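# Illustrative sketch (not part of the original API): with train=True the
# square crop location is random; with train=False the central square is
# taken. Either way the result is a `resolution`-sized square produced by
# area resizing. The helper name below is hypothetical.
def _example_resize_to_square():
  image = tf.random.uniform((480, 640, 3), maxval=256, dtype=tf.int32)
  out = resize_to_square(image, resolution=32, train=False)
  return out  # A 32x32x3 image resized from the central 480x480 square.

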
def preprocess(example, train=True, resolution=256):
  """Apply a random crop (train) or a central crop (eval) to the image."""
  image = example

  is_label = False
  if isinstance(example, dict):
    image = example['image']
    is_label = 'label' in example.keys()

  image = resize_to_square(image, train=train, resolution=resolution)

  # Keeping the 'file_name' key triggers a hard-to-debug TPU error, so copy
  # only the needed keys into a fresh dict.
  example_copy = dict()
  example_copy['image'] = image
  example_copy['targets'] = image
  if is_label:
    example_copy['label'] = example['label']
  return example_copy


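# Illustrative sketch: `preprocess` maps a TFDS-style example dict to the
# keys downstream models read ('image', 'targets' and, when present,
# 'label'); every other key, e.g. 'file_name', is dropped. The helper name
# is hypothetical.
def _example_preprocess():
  example = {
      'image': tf.zeros((300, 500, 3), dtype=tf.uint8),
      'label': tf.constant(7, dtype=tf.int64),
      'file_name': tf.constant('img.jpeg'),  # Dropped by preprocess.
  }
  out = preprocess(example, train=False, resolution=256)
  return out  # Keys: 'image', 'targets' (256x256x3 each) and 'label'.

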
def get_gen_dataset(data_dir, batch_size):
  """Converts a list of generated TFRecords into a TF Dataset."""

  def parse_example(example_proto, res=64):
    features = {'image': tf.io.FixedLenFeature([res*res*3], tf.int64)}
    example = tf.io.parse_example(example_proto, features=features)
    image = tf.reshape(example['image'], (res, res, 3))
    return {'targets': image}

  # Maps a generated record name like 'gen12.tfrecords' to its index.
  def tf_record_name_to_num(x):
    x = x.split('.')[0]
    x = re.split(r'(\d+)', x)
    return int(x[1])

  assert data_dir is not None
  records = tf.io.gfile.listdir(data_dir)
  max_num = max(records, key=tf_record_name_to_num)
  max_num = tf_record_name_to_num(max_num)

  records = []
  for record in range(max_num + 1):
    path = os.path.join(data_dir, f'gen{record}.tfrecords')
    records.append(path)

  tf_dataset = tf.data.TFRecordDataset(records)
  tf_dataset = tf_dataset.map(parse_example, num_parallel_calls=100)
  tf_dataset = tf_dataset.batch(batch_size=batch_size)
  tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE)
  return tf_dataset


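# Illustrative sketch of the record format `get_gen_dataset` expects: files
# named 'gen{i}.tfrecords', each example holding a flattened int64 image of
# length res * res * 3 under the 'image' key. The helper name and the
# default directory path are hypothetical.
def _example_write_gen_record(data_dir='/tmp/gen_data', res=64):
  image = tf.zeros((res, res, 3), dtype=tf.int64)
  flat = tf.reshape(image, [-1]).numpy().tolist()
  feature = {
      'image': tf.train.Feature(int64_list=tf.train.Int64List(value=flat))
  }
  example = tf.train.Example(features=tf.train.Features(feature=feature))
  with tf.io.TFRecordWriter(os.path.join(data_dir, 'gen0.tfrecords')) as w:
    w.write(example.SerializeToString())

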
def create_gen_dataset_from_images(image_dir):
  """Creates a dataset from the provided directory."""
  def load_image(path):
    image_str = tf.io.read_file(path)
    return tf.image.decode_image(image_str, channels=3)

  child_files = tf.io.gfile.listdir(image_dir)
  files = [os.path.join(image_dir, file) for file in child_files]
  files = tf.convert_to_tensor(files, dtype=tf.string)
  dataset = tf.data.Dataset.from_tensor_slices(files)
  return dataset.map(load_image, num_parallel_calls=100)


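# Illustrative sketch: the custom path assumes `image_dir` contains only
# decodable image files, since tf.io.gfile.listdir returns every child.
# The helper name and path are hypothetical.
def _example_custom_images(image_dir='/tmp/images'):
  ds = create_gen_dataset_from_images(image_dir)
  for image in ds.take(1):
    print(image.shape)  # (height, width, 3), uint8.

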
def get_imagenet(subset, read_config):
  """Gets the ImageNet dataset."""
  train = subset == 'train'
  num_val_examples = 0 if subset == 'eval_train' else 10000
  if subset == 'test':
    ds = tfds.load('imagenet2012', split='validation', shuffle_files=False)
  else:
    # Split off 10000 samples from the ImageNet train split for validation.
    ds, info = tfds.load('imagenet2012', split='train', with_info=True,
                         shuffle_files=train, read_config=read_config)
    num_train = info.splits['train'].num_examples - num_val_examples
    if train:
      ds = ds.take(num_train)
    elif subset == 'valid':
      ds = ds.skip(num_train)
  return ds


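# Illustrative sketch: 'valid' yields the last 10000 examples of the TFDS
# train split, while 'test' maps to the official ImageNet validation split.
# The helper name and seed are hypothetical.
def _example_imagenet_valid():
  read_config = tfds.ReadConfig(shuffle_seed=0)
  return get_imagenet('valid', read_config)

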
def get_dataset(name,
                config,
                batch_size,
                subset,
                read_config=None,
                data_dir=None):
  """Wrapper around TF-Datasets.

  * Setting `config.random_channel` to True adds:
    ds['targets_slice'] - One of the three channels, picked at random.
    ds['channel_index'] - Index of the randomly picked channel.
  * Setting `config.downsample` to True adds:
    ds['targets_64'] - Input downsampled via tf.image.resize, 64x64 by default.
    ds['targets_64_up_back'] - 'targets_64' upsampled back via tf.image.resize.

  Args:
    name: 'imagenet' or 'custom'.
    config: dict-like config; see the optional keys read below.
    batch_size: batch size.
    subset: 'train', 'eval_train', 'valid' or 'test'.
    read_config: optional tfds.ReadConfig instance, used for sharding across
                 multiple workers.
    data_dir: data directory, used for the custom dataset.

  Returns:
    dataset: TF Dataset.
  """
  downsample = config.get('downsample', False)
  random_channel = config.get('random_channel', False)
  downsample_res = config.get('downsample_res', 64)
  downsample_method = config.get('downsample_method', 'area')
  num_epochs = config.get('num_epochs', -1)
  data_dir = config.get('data_dir') or data_dir
  auto = tf.data.AUTOTUNE
  train = subset == 'train'

  if name == 'imagenet':
    ds = get_imagenet(subset, read_config)
  elif name == 'custom':
    assert data_dir is not None
    ds = create_gen_dataset_from_images(data_dir)
  else:
    raise ValueError(f'Expected dataset in [imagenet, custom]. Got {name}')

  ds = ds.map(
      lambda x: preprocess(x, train=train), num_parallel_calls=100)
  if train and random_channel:
    ds = ds.map(datasets_utils.random_channel_slice)
  if downsample:
    downsample_part = functools.partial(
        datasets_utils.downsample_and_upsample,
        train=train,
        downsample_res=downsample_res,
        upsample_res=256,
        method=downsample_method)
    ds = ds.map(downsample_part, num_parallel_calls=100)

  if train:
    ds = ds.repeat(num_epochs)
    ds = ds.shuffle(buffer_size=128)
  ds = ds.batch(batch_size, drop_remainder=True)
  ds = ds.prefetch(auto)
  return ds


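# Illustrative sketch of a typical training pipeline; a plain dict stands in
# for the config object used by the trainers (anything with a dict-style
# .get works). Running it requires imagenet2012 prepared in TFDS. The helper
# name is hypothetical.
def _example_get_dataset():
  config = {'downsample': True, 'downsample_res': 64}
  ds = get_dataset('imagenet', config, batch_size=8, subset='train')
  batch = next(iter(ds))
  # Per the docstring above: batch['targets'] is (8, 256, 256, 3),
  # batch['targets_64'] is (8, 64, 64, 3), and batch['targets_64_up_back']
  # is (8, 256, 256, 3).
  return batch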
