# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=logging-format-interpolation
# pylint: disable=g-long-lambda
# pylint: disable=logging-not-lazy
# pylint: disable=protected-access
21r"""Data."""
22
23import os24
25from absl import logging26
27import numpy as np28
29import tensorflow.compat.v1 as tf # tf30
31from differentiable_data_selection import augment32
33CIFAR_PATH = ''34CIFAR_MEAN = np.array([0.491400, 0.482158, 0.4465309], np.float32) * 255.35CIFAR_STDDEV = np.array([0.247032, 0.243485, 0.26159], np.float32) * 255.36


################################################################################
#                                                                              #
#                              OUTSIDE INTERFACE                               #
#                                                                              #
################################################################################


def _dataset_service(params, dataset, start_index=None, final_index=None):
  """Wrap `dataset` into `dataset_service`."""
  # Pass-through in this release; `start_index` and `final_index` are unused.
  return dataset


def get_image_mean_and_std(params):
  """Returns the per-channel image mean and stddev for `params.dataset_name`."""
  if params.dataset_name.lower().startswith('cifar'):
    return CIFAR_MEAN, CIFAR_STDDEV
  else:
    raise ValueError(f'Unknown dataset_name `{params.dataset_name}`')


def convert_and_normalize(params, images):
  """Subtract mean and divide stddev depending on the dataset."""
  dtype = tf.bfloat16 if params.use_bfloat16 else tf.float32
  if 'cifar' in params.dataset_name.lower():
    images = tf.cast(images, dtype)
  else:
    images = tf.image.convert_image_dtype(images, dtype)
  # Broadcast shape: [1, 1, 1, 3] for NHWC batches, [1, 1, 3] for one image.
  shape = [1, 1, 1, 3] if len(images.shape.as_list()) == 4 else [1, 1, 3]

  mean, stddev = get_image_mean_and_std(params)
  mean = tf.reshape(tf.cast(mean, images.dtype), shape)
  stddev = tf.reshape(tf.cast(stddev, images.dtype), shape)
  images = (images - mean) / stddev

  return images


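# Example (illustrative): for a uint8 CIFAR batch `x` of shape [N, 32, 32, 3],
# `convert_and_normalize(params, x)` casts to float32 (or bfloat16) and
# applies (x - CIFAR_MEAN) / CIFAR_STDDEV with the statistics broadcast
# as [1, 1, 1, 3].
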
def get_eval_size(params):
  """Returns the padded eval size for `params.dataset_name`."""
  eval_sizes = {
      'cifar10': 10000,
      'cifar10_dds': 10000,
  }
  if params.dataset_name.lower() not in eval_sizes.keys():
    raise ValueError(f'Unknown dataset_name `{params.dataset_name}`')
  eval_size = eval_sizes[params.dataset_name.lower()]
  return compute_num_padded_data(params, eval_size)


def build_eval_dataset(params,
                       batch_size=None,
                       num_workers=None,
                       worker_index=None):
  """Builds `eval_data` depending on `params.dataset_name`."""
  if params.dataset_name.lower() in ['cifar10', 'cifar10_dds']:
    eval_data = cifar10_eval(
        params, batch_size=batch_size, eval_mode='test',
        num_workers=num_workers, worker_index=worker_index)
  else:
    raise ValueError(f'Unknown dataset_name `{params.dataset_name}`')

  return eval_data


def build_train_infeeds(params):
  """Create the TPU infeed ops."""
  dev_assign = params.device_assignment
  host_to_tpus = {}
  for replica_id in range(params.num_replicas):
    host_device = dev_assign.host_device(replica=replica_id, logical_core=0)
    tpu_ordinal = dev_assign.tpu_ordinal(replica=replica_id, logical_core=0)
    logging.info(f'replica_id={replica_id} '
                 f'host_device={host_device} '
                 f'tpu_ordinal={tpu_ordinal}')

    if host_device not in host_to_tpus:
      host_to_tpus[host_device] = [tpu_ordinal]
    else:
      assert tpu_ordinal not in host_to_tpus[host_device]
      host_to_tpus[host_device].append(tpu_ordinal)

  infeed_ops = []
  infeed_graphs = []
  num_inputs = len(host_to_tpus)
  for i, (host, tpus) in enumerate(host_to_tpus.items()):
    infeed_graph = tf.Graph()
    infeed_graphs.append(infeed_graph)
    with infeed_graph.as_default():
      def enqueue_fn(host_device=host, input_index=i, device_ordinals=tpus):
        """Builds the enqueue ops for one host and its TPU cores."""
        worker_infeed_ops = []
        with tf.device(host_device):
          dataset = build_train_dataset(
              params=params,
              batch_size=params.train_batch_size // num_inputs,
              num_inputs=num_inputs,
              input_index=input_index)
          inputs = tf.data.make_one_shot_iterator(dataset).get_next()

          # Split the per-host batch across this host's TPU cores.
          num_splits = len(device_ordinals)
          if len(device_ordinals) > 1:
            inputs = [tf.split(v, num_splits, 0) for v in inputs]
          else:
            inputs = [[v] for v in inputs]
          input_dtypes = [v[0].dtype for v in inputs]
          input_shapes = [v[0].shape for v in inputs]
          params.add_hparam('train_dtypes', input_dtypes)
          params.add_hparam('train_shapes', input_shapes)
          for j, device_ordinal in enumerate(device_ordinals):
            worker_infeed_ops.append(tf.raw_ops.InfeedEnqueueTuple(
                inputs=[v[j] for v in inputs],
                shapes=input_shapes,
                device_ordinal=device_ordinal))
        return worker_infeed_ops

      def _body(i):
        with tf.control_dependencies(enqueue_fn()):
          return i+1

      # One host-side loop enqueues `params.save_every` batches per run.
      infeed_op = tf.while_loop(
          lambda step: tf.less(step, tf.cast(params.save_every, step.dtype)),
          _body, [0], parallel_iterations=1, name='train_infeed').op
      infeed_ops.append(infeed_op)

  return infeed_ops, infeed_graphs


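# A minimal sketch (not part of this file; names like `tpu_master` are
# assumptions) of how the trainer might consume these infeeds: each graph runs
# in its own session, typically on a background thread, so host-side
# enqueueing overlaps with TPU execution:
#
#   infeed_ops, infeed_graphs = build_train_infeeds(params)
#   for op, graph in zip(infeed_ops, infeed_graphs):
#     with graph.as_default():
#       sess = tf.Session(tpu_master)  # `tpu_master` is hypothetical here.
#       threading.Thread(target=sess.run, args=(op,), daemon=True).start()
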
def build_eval_infeeds(params):
  """Create the TPU infeed ops."""

  eval_size = get_eval_size(params)
  num_eval_steps = eval_size // params.eval_batch_size

  dev_assign = params.device_assignment
  host_to_tpus = {}
  for replica_id in range(params.num_replicas):
    host_device = dev_assign.host_device(replica=replica_id, logical_core=0)
    tpu_ordinal = dev_assign.tpu_ordinal(replica=replica_id, logical_core=0)

    if host_device not in host_to_tpus:
      host_to_tpus[host_device] = [tpu_ordinal]
    else:
      assert tpu_ordinal not in host_to_tpus[host_device]
      host_to_tpus[host_device].append(tpu_ordinal)

  infeed_ops = []
  infeed_graphs = []
  num_inputs = len(host_to_tpus)
  for i, (host, tpus) in enumerate(host_to_tpus.items()):
    infeed_graph = tf.Graph()
    infeed_graphs.append(infeed_graph)
    with infeed_graph.as_default():
      def enqueue_fn(host_device=host, input_index=i, device_ordinals=tpus):
        """Builds the enqueue ops for one host and its TPU cores."""
        worker_infeed_ops = []
        with tf.device(host_device):
          dataset = build_eval_dataset(
              params,
              batch_size=params.eval_batch_size // num_inputs,
              num_workers=num_inputs,
              worker_index=input_index)
          inputs = tf.data.make_one_shot_iterator(dataset).get_next()

          num_splits = len(device_ordinals)
          if len(device_ordinals) > 1:
            inputs = [tf.split(v, num_splits, 0) for v in inputs]
          else:
            inputs = [[v] for v in inputs]
          input_dtypes = [v[0].dtype for v in inputs]
          input_shapes = [v[0].shape for v in inputs]
          params.add_hparam('eval_dtypes', input_dtypes)
          params.add_hparam('eval_shapes', input_shapes)
          for j, device_ordinal in enumerate(device_ordinals):
            worker_infeed_ops.append(tf.raw_ops.InfeedEnqueueTuple(
                inputs=[v[j] for v in inputs],
                shapes=input_shapes,
                device_ordinal=device_ordinal))
        return worker_infeed_ops

      def _body(i):
        with tf.control_dependencies(enqueue_fn()):
          return i+1

      infeed_op = tf.while_loop(
          lambda step: tf.less(step, tf.cast(num_eval_steps, step.dtype)),
          _body, [0], parallel_iterations=1, name='eval_infeed').op
      infeed_ops.append(infeed_op)

  return infeed_ops, infeed_graphs, eval_size


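# Note: the `eval_size` returned above is already padded (see
# `compute_num_padded_data`), so `eval_size // params.eval_batch_size` enqueue
# steps cover every real example; the padding examples carry sample weight 0
# and are ignored by weighted metrics (see `_add_sample_weight`).
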
def build_train_dataset(params, batch_size=None, num_inputs=1, input_index=0):
  """Builds `train_data` depending on `params.dataset_name`."""
  del num_inputs
  del input_index
  if params.dataset_name.lower() == 'cifar10':
    dataset = cifar10_train(params, batch_size)
    dataset = _dataset_service(params, dataset)
  elif params.dataset_name.lower() == 'cifar10_dds':
    dataset = cifar10_dds(params, batch_size)
    dataset = _dataset_service(params, dataset)
  else:
    raise ValueError(f'Unknown dataset_name `{params.dataset_name}`')

  return dataset


def _smallest_mul(a, m):
  """Rounds `a` up to the smallest multiple of `m`."""
  return (a + m-1) // m * m


def compute_num_padded_data(params, data_size):
  """Pads `data_size` up to a multiple of the eval batch size, plus one batch."""
  return (_smallest_mul(data_size, params.eval_batch_size)
          + params.eval_batch_size)


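# Example: with data_size=10000 and params.eval_batch_size=1024 (illustrative
# values), _smallest_mul(10000, 1024) = 10240, so the padded size is
# 10240 + 1024 = 11264, i.e. exactly 11 full eval batches.
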
def _add_sample_weight(params, dataset, num_images):
  """Maps a `Dataset` of `d_0, d_1, ...` to compute its `sample_weights`."""

  if 'eval_image_size' in params:
    image_size = max(params.image_size, params.eval_image_size)
  else:
    image_size = params.image_size

  # Dummy examples with sample weight 0 pad the dataset out to `num_images`.
  dtype = tf.bfloat16 if params.use_bfloat16 else tf.float32
  dummy_dataset = tf.data.Dataset.from_tensors((
      tf.zeros([image_size, image_size, 3], dtype=dtype),
      tf.zeros([params.num_classes], dtype=tf.float32),
      tf.constant(0., dtype=tf.float32),
  )).repeat()

  # Real examples carry sample weight 1.
  def _transform(images, labels):
    return images, labels, tf.constant(1., dtype=tf.float32)
  dataset = dataset.map(_transform, tf.data.experimental.AUTOTUNE)
  dataset = dataset.concatenate(dummy_dataset)
  dataset = dataset.take(num_images // params.num_workers).cache().repeat()

  return dataset


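# A consumer can then mask out the padding when averaging metrics, e.g.
# (hypothetical names, not from this file):
#   images, labels, weights = iterator.get_next()
#   accuracy = tf.reduce_sum(weights * correct) / tf.reduce_sum(weights)
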
def _optimize_dataset(dataset):
  """Routines to optimize `Dataset`'s speed."""
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  options = tf.data.Options()
  options.experimental_optimization.parallel_batch = True
  options.experimental_optimization.map_fusion = True
  options.experimental_optimization.map_parallelization = True
  dataset = dataset.with_options(options)
  dataset = dataset.prefetch(1)
  return dataset


def _flip_and_jitter(x, replace_value=0):
  """Flip left/right and jitter."""
  x = tf.image.random_flip_left_right(x)
  image_size = min([x.shape[0], x.shape[1]])
  pad_size = image_size // 8
  # Pad-and-crop jitter: e.g. a 32x32 image is padded by 4 on each side to
  # 40x40, then randomly cropped back to 32x32.
  x = tf.pad(x, paddings=[[pad_size, pad_size], [pad_size, pad_size], [0, 0]],
             constant_values=replace_value)
  x = tf.image.random_crop(x, [image_size, image_size, 3])
  x.set_shape([image_size, image_size, 3])
  return x


def _jitter(x, replace_value=0):
  """Jitter only, without the left/right flip."""
  image_size = min([x.shape[0], x.shape[1]])
  pad_size = image_size // 8
  x = tf.pad(x, paddings=[[pad_size, pad_size], [pad_size, pad_size], [0, 0]],
             constant_values=replace_value)
  x = tf.image.random_crop(x, [image_size, image_size, 3])
  x.set_shape([image_size, image_size, 3])
  return x


################################################################################
#                                                                              #
#                                   CIFAR-10                                   #
#                                                                              #
################################################################################


def _cifar10_parser(params, value, training):
  """Cifar10 parser."""
  image_size = params.image_size
  # Each CIFAR-10 binary record is 1 label byte followed by a 3x32x32 image
  # in channel-major (CHW) order.
  value = tf.io.decode_raw(value, tf.uint8)
  label = tf.cast(value[0], tf.int32)
  label = tf.one_hot(label, depth=params.num_classes, dtype=tf.float32)
  image = tf.reshape(value[1:], [3, 32, 32])  # uint8
  image = tf.transpose(image, [1, 2, 0])
  if image_size != 32:
    image = tf.image.resize_bicubic([image], [image_size, image_size])[0]
  image.set_shape([image_size, image_size, 3])

  if training:
    if params.use_augment:
      # 128 is mid-gray, matching the fill value used for jitter and cutout.
      aug = augment.RandAugment(cutout_const=image_size//8,
                                translate_const=image_size//8,
                                magnitude=params.augment_magnitude)
      image = _flip_and_jitter(image, 128)
      image = aug.distort(image)
      image = augment.cutout(image, pad_size=image_size//4, replace=128)
    else:
      image = _flip_and_jitter(image, 128)
  image = convert_and_normalize(params, image)
  return image, label


def cifar10_train(params, batch_size=None):
  """Load CIFAR-10 data."""
  shuffle_size = batch_size * 16

  filenames = [os.path.join(CIFAR_PATH, 'train.bin')]
  record_bytes = 1 + (3 * 32 * 32)
  dataset = tf.data.FixedLengthRecordDataset(filenames, record_bytes)
  # The first 5000 records are held out as the validation set (see
  # `cifar10_dds`), so training skips them.
  dataset = dataset.skip(5000).cache()
  dataset = dataset.map(
      lambda x: _cifar10_parser(params, x, training=True),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.shuffle(shuffle_size).repeat()
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
  dataset = _optimize_dataset(dataset)

  return dataset


def cifar10_dds(params, batch_size=None):
  """Load CIFAR-10 train and validation streams for DDS."""
  shuffle_size = batch_size * 16

  filenames = [os.path.join(CIFAR_PATH, 'train.bin')]
  record_bytes = 1 + (3 * 32 * 32)
  dataset = tf.data.FixedLengthRecordDataset(filenames, record_bytes)

  train_dataset = dataset.skip(5000).cache()
  train_dataset = train_dataset.map(
      lambda x: _cifar10_parser(params, x, training=True),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  train_dataset = train_dataset.shuffle(shuffle_size).repeat()
  train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
  train_dataset = train_dataset.batch(batch_size=batch_size,
                                      drop_remainder=True)

  valid_dataset = dataset.take(5000).cache()
  valid_dataset = valid_dataset.map(
      lambda x: _cifar10_parser(params, x, training=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  valid_dataset = valid_dataset.shuffle(shuffle_size).repeat()
  valid_dataset = valid_dataset.prefetch(tf.data.experimental.AUTOTUNE)
  valid_dataset = valid_dataset.batch(batch_size=batch_size,
                                      drop_remainder=True)

  dataset = tf.data.Dataset.zip((train_dataset, valid_dataset))
  dataset = dataset.map(lambda a, b: tuple([a[0], a[1], b[0], b[1]]),
                        tf.data.experimental.AUTOTUNE)

  dataset = _optimize_dataset(dataset)
  return dataset


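# Each element of the zipped dataset above is a 4-tuple
# (train_images, train_labels, valid_images, valid_labels), so every infeed
# step carries one training batch and one validation batch for DDS.
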
def cifar10_eval(params, batch_size=None, eval_mode=None,
                 num_workers=None, worker_index=None):
  """Load CIFAR-10 data."""

  if batch_size is None:
    batch_size = params.eval_batch_size

  if eval_mode == 'valid':
    filenames = [os.path.join(CIFAR_PATH, 'val.bin')]
  elif eval_mode == 'test':
    filenames = [os.path.join(CIFAR_PATH, 'test_batch.bin')]
  else:
    raise ValueError(f'Unknown eval_mode {eval_mode}')

  record_bytes = 1 + (3 * 32 * 32)
  dataset = tf.data.FixedLengthRecordDataset(filenames, record_bytes)
  # Shard across hosts so each worker evaluates a distinct slice.
  if num_workers is not None and worker_index is not None:
    dataset = dataset.shard(num_workers, worker_index)

  dataset = dataset.map(
      lambda x: _cifar10_parser(params, x, training=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = _add_sample_weight(params, dataset,
                               num_images=get_eval_size(params))

  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
  dataset = _optimize_dataset(dataset)

  return dataset