google-research / pathfinder_data.py

# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for loading the Pathfinder dataset."""

from lra.lra_benchmarks.data import pathfinder
import tensorflow as tf
import tensorflow_datasets as tfds


# Please set this variable to the path for the LRA pathfinder data.
_PATHFINDER_TFDS_PATH = None
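# For example (hypothetical path; point it at wherever your TFDS-prepared
# LRA pathfinder files actually live):
# _PATHFINDER_TFDS_PATH = '/data/lra_release/pathfinder_tfds'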


AUTOTUNE = tf.data.experimental.AUTOTUNE


def load(n_devices=1,
         batch_size=256,
         resolution=32,
         normalize=False,
         difficulty='easy'):
  """Get Pathfinder dataset splits.

  Args:
    n_devices: Number of devices used. Default: 1.
    batch_size: Batch size.
    resolution: Resolution of the images. Either 32, 64, 128 or 256.
    normalize: If True, the images have float elements in [0, 1].
    difficulty: Controls the number of distractor paths. One of 'easy',
      'intermediate' or 'hard'.

  Returns:
    (train_dataset, val_dataset, test_dataset, num_classes,
    vocab_size, input_shape)
  """

  if _PATHFINDER_TFDS_PATH is None:
    raise ValueError(
        'You must set _PATHFINDER_TFDS_PATH above to your pathfinder data path.'
    )

  # Batches are later sharded across devices, so batch_size must divide
  # evenly by n_devices.
  if batch_size % n_devices:
    raise ValueError("Batch size %d isn't evenly divisible by n_devices %d" %
                     (batch_size, n_devices))

  if difficulty not in ['easy', 'intermediate', 'hard']:
    raise ValueError("difficulty must be in ['easy', 'intermediate', 'hard'].")

  if resolution == 32:
    builder = pathfinder.Pathfinder32(data_dir=_PATHFINDER_TFDS_PATH)
    inputs_shape = (32, 32)
  elif resolution == 64:
    builder = pathfinder.Pathfinder64(data_dir=_PATHFINDER_TFDS_PATH)
    inputs_shape = (64, 64)
  elif resolution == 128:
    builder = pathfinder.Pathfinder128(data_dir=_PATHFINDER_TFDS_PATH)
    inputs_shape = (128, 128)
  elif resolution == 256:
    builder = pathfinder.Pathfinder256(data_dir=_PATHFINDER_TFDS_PATH)
    inputs_shape = (256, 256)
  else:
    raise ValueError('Resolution must be in [32, 64, 128, 256].')
  def get_split(difficulty):
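    # SkipDecoding leaves x['image'] as the raw encoded PNG bytes, so records
    # can be filtered cheaply here; PNG decoding is deferred to decode() below.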
    ds = builder.as_dataset(
        split=difficulty, decoders={'image': tfds.decode.SkipDecoding()})

    # Filter out examples with empty images:
    ds = ds.filter(lambda x: tf.strings.length(x['image']) > 0)

    return ds

  # Each difficulty is a single TFDS split; carve out train/val/test as
  # 80%/10%/10% slices of it.
  train_dataset = get_split(f'{difficulty}[:80%]')
  val_dataset = get_split(f'{difficulty}[80%:90%]')
  test_dataset = get_split(f'{difficulty}[90%:]')

  def decode(x):
    decoded = {
        'inputs': tf.cast(tf.image.decode_png(x['image']), dtype=tf.int32),
        'targets': x['label']
    }
    if normalize:
      # True division promotes the int32 pixels to floats in [0, 1].
      decoded['inputs'] = decoded['inputs'] / 255
    return decoded

  train_dataset = train_dataset.map(decode, num_parallel_calls=AUTOTUNE)
  val_dataset = val_dataset.map(decode, num_parallel_calls=AUTOTUNE)
  test_dataset = test_dataset.map(decode, num_parallel_calls=AUTOTUNE)

  # TODO(gnegiar): Don't shuffle and batch here.
  # Let the train.py file convert datapoints to graph representation
  # and cache before shuffling and batching.
  train_dataset = train_dataset.shuffle(
      buffer_size=256 * 8, reshuffle_each_iteration=True)

  train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
  val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
  test_dataset = test_dataset.batch(batch_size, drop_remainder=True)

  # Pathfinder is a binary classification task over 8-bit pixel values:
  # num_classes=2, vocab_size=256.
  return train_dataset, val_dataset, test_dataset, 2, 256, inputs_shape
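
# A minimal usage sketch (not part of the original module). It assumes
# _PATHFINDER_TFDS_PATH has been set above and the TFDS-prepared LRA
# pathfinder files are available locally:
#
#   train_ds, val_ds, test_ds, num_classes, vocab_size, input_shape = load(
#       n_devices=1, batch_size=32, resolution=32, difficulty='easy')
#   for batch in train_ds.take(1):
#     # For grayscale PNGs, batch['inputs'] has shape (32, 32, 32, 1)
#     # and batch['targets'] has shape (32,).
#     print(batch['inputs'].shape, batch['targets'].shape)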
