google-research
113 строк · 3.8 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities for loading the Pathfinder dataset."""
17
18from lra.lra_benchmarks.data import pathfinder19import tensorflow as tf20import tensorflow_datasets as tfds21
22
23# Please set this variable to the path for the LRA pathfinder data.
24_PATHFINDER_TFDS_PATH = None25
26
27
# Passed as `num_parallel_calls` to `Dataset.map` below so that tf.data picks
# the level of parallelism itself instead of using a hard-coded value.
AUTOTUNE = tf.data.experimental.AUTOTUNE
30
def load(n_devices=1,
         batch_size=256,
         resolution=32,
         normalize=False,
         difficulty='easy'):
  """Get Pathfinder dataset splits.

  The TFDS split named by `difficulty` is sliced 80% / 10% / 10% into train,
  validation and test sets. Examples whose encoded image is empty are dropped,
  the remaining images are decoded from PNG, the training set is shuffled, and
  every split is batched with incomplete final batches dropped.

  Args:
    n_devices: Number of devices used. Must be positive and divide
      `batch_size` evenly. Default: 1
    batch_size: Batch size. Default: 256
    resolution: Resolution of the images. One of 32, 64, 128 or 256.
    normalize: If True, the decoded images are divided by 255 so that they
      have float elements in [0,1]. Otherwise they are int32 tensors.
    difficulty: Controls the number of distractor paths. One of 'easy',
      'intermediate' or 'hard'.

  Returns:
    (train_dataset, val_dataset, test_dataset, num_classes,
    vocab_size, input_shape). Each dataset yields batched dicts with the keys
    'inputs' (decoded image) and 'targets' (label). num_classes is 2,
    vocab_size is 256 and input_shape is (resolution, resolution).

  Raises:
    ValueError: If `_PATHFINDER_TFDS_PATH` has not been set, if `n_devices`
      is not positive or does not divide `batch_size`, or if `difficulty` or
      `resolution` is not one of the supported values.
  """

  if _PATHFINDER_TFDS_PATH is None:
    raise ValueError(
        'You must set _PATHFINDER_TFDS_PATH above to your pathfinder data path.'
    )

  # Reject non-positive device counts explicitly; otherwise n_devices=0 would
  # surface as a bare ZeroDivisionError from the modulo below.
  if n_devices < 1:
    raise ValueError(
        f'n_devices must be a positive integer, got {n_devices!r}.')

  if batch_size % n_devices:
    raise ValueError("Batch size %d isn't divided evenly by n_devices %d" %
                     (batch_size, n_devices))

  if difficulty not in ['easy', 'intermediate', 'hard']:
    raise ValueError("difficulty must be in ['easy', 'intermediate', 'hard'].")

  # Each supported resolution has its own builder class, `Pathfinder<side>`.
  # The class is looked up only for the matching resolution, so an unrelated
  # builder is never touched.
  for side in (32, 64, 128, 256):
    if resolution == side:
      builder_cls = getattr(pathfinder, f'Pathfinder{side}')
      builder = builder_cls(data_dir=_PATHFINDER_TFDS_PATH)
      inputs_shape = (side, side)
      break
  else:
    # The loop finished without a match: unsupported resolution.
    raise ValueError('Resolution must be in [32, 64, 128, 256].')

  def _load_split(split_spec):
    """Returns the raw (still PNG-encoded) examples for a TFDS split spec."""
    # Skip TFDS's own image decoding so that empty images can be filtered out
    # on the encoded bytes before any decode is attempted.
    raw = builder.as_dataset(
        split=split_spec, decoders={'image': tfds.decode.SkipDecoding()})

    # Filter out examples with empty images:
    return raw.filter(lambda x: tf.strings.length((x['image'])) > 0)

  def _decode(example):
    """Decodes one raw example into an {'inputs', 'targets'} dict."""
    image = tf.cast(tf.image.decode_png(example['image']), dtype=tf.int32)
    if normalize:
      image = image / 255
    return {'inputs': image, 'targets': example['label']}

  def _prepare(span):
    """Loads, filters and decodes the given slice of the chosen difficulty."""
    return _load_split(f'{difficulty}{span}').map(
        _decode, num_parallel_calls=AUTOTUNE)

  train_dataset = _prepare('[:80%]')
  val_dataset = _prepare('[80%:90%]')
  test_dataset = _prepare('[90%:]')

  # TODO(gnegiar): Don't shuffle and batch here.
  # Let the train.py file convert datapoints to graph representation
  # and cache before shuffling and batching.
  train_dataset = train_dataset.shuffle(
      buffer_size=256 * 8, reshuffle_each_iteration=True)

  train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
  val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
  test_dataset = test_dataset.batch(batch_size, drop_remainder=True)

  num_classes = 2
  vocab_size = 256
  return (train_dataset, val_dataset, test_dataset, num_classes, vocab_size,
          inputs_shape)