google-research

Форк
0
/
data_util.py 
81 строка · 2.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Getting a input function that will give input and label tensors."""
17

18
from tensor2tensor import problems
19
import tensorflow.compat.v1 as tf
20
from tensorflow.compat.v1 import estimator as tf_estimator
21

22

23
def get_input(
24
    batch_size=50,
25
    augmented=False,
26
    data='cifar10',
27
    mode=tf_estimator.ModeKeys.TRAIN,
28
    repeat_num=None,
29
    data_format='HWC'):
30
  """Returns a input function for the estimator framework.
31

32
  Args:
33
    batch_size: batch size for training or testing
34
    augmented:  whether data augmentation is used
35
    data:       a string that specifies the dataset, must be cifar10
36
                  or cifar100
37
    mode:       indicates whether the input is for training or testing,
38
                  needs to be a member of tf.estimator.ModeKeys
39
    repeat_num: how many times the dataset is repeated
40
    data_format: order of the data's axis
41

42
  Returns:
43
    an input function
44
  """
45
  assert data == 'cifar10' or data == 'cifar100'
46
  class_num = 10 if data == 'cifar10' else 100
47
  data = 'image_' + data
48

49
  if mode != tf_estimator.ModeKeys.TRAIN:
50
    repeat_num = 1
51

52
  problem_name = data
53
  if data == 'image_cifar10' and not augmented:
54
    problem_name = 'image_cifar10_plain'
55

56
  def preprocess(example):
57
    """Perform per image standardization on a single image."""
58
    image = example['inputs']
59
    image.set_shape([32, 32, 3])
60
    image = tf.cast(image, tf.float32)
61
    example['inputs'] = tf.image.per_image_standardization(image)
62
    return example
63

64
  def input_data():
65
    """Input function to be returned."""
66
    prob = problems.problem(problem_name)
67
    if data == 'image_cifar100':
68
      dataset = prob.dataset(mode, preprocess=augmented)
69
      if not augmented: dataset = dataset.map(map_func=preprocess)
70
    else:
71
      dataset = prob.dataset(mode, preprocess=False)
72
      dataset = dataset.map(map_func=preprocess)
73

74
    dataset = dataset.batch(batch_size)
75
    dataset = dataset.repeat(repeat_num)
76
    dataset = dataset.make_one_shot_iterator().get_next()
77
    if data_format == 'CHW':
78
      dataset['inputs'] = tf.transpose(dataset['inputs'], (0, 3, 1, 2))
79
    return dataset['inputs'], tf.squeeze(tf.one_hot(dataset['targets'],
80
                                                    class_num))
81
  return input_data
82

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.