google-research
81 строка · 2.8 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Getting an input function that will give input and label tensors."""
17
18from tensor2tensor import problems19import tensorflow.compat.v1 as tf20from tensorflow.compat.v1 import estimator as tf_estimator21
22
def get_input(
    batch_size=50,
    augmented=False,
    data='cifar10',
    mode=tf_estimator.ModeKeys.TRAIN,
    repeat_num=None,
    data_format='HWC'):
  """Returns an input function for the estimator framework.

  Args:
    batch_size: batch size for training or testing
    augmented: whether data augmentation is used
    data: a string that specifies the dataset, must be cifar10
      or cifar100
    mode: indicates whether the input is for training or testing,
      needs to be a member of tf.estimator.ModeKeys
    repeat_num: how many times the dataset is repeated; passed to
      `dataset.repeat`, and forced to 1 whenever mode is not TRAIN
    data_format: order of the data's axis; 'CHW' transposes the batched
      inputs with permutation (0, 3, 1, 2), any other value leaves the
      inputs in the order produced by the dataset

  Returns:
    an input function that takes no arguments and returns a tuple of
    (inputs, one-hot labels)
  """
  assert data in ('cifar10', 'cifar100'), (
      "data must be 'cifar10' or 'cifar100'")
  class_num = 10 if data == 'cifar10' else 100
  # From here on `data` holds the 'image_'-prefixed problem family name.
  data = 'image_' + data

  # Evaluation / prediction only ever needs a single pass over the data.
  if mode != tf_estimator.ModeKeys.TRAIN:
    repeat_num = 1

  problem_name = data
  if data == 'image_cifar10' and not augmented:
    problem_name = 'image_cifar10_plain'

  def preprocess(example):
    """Perform per image standardization on a single image."""
    image = example['inputs']
    image.set_shape([32, 32, 3])
    image = tf.cast(image, tf.float32)
    example['inputs'] = tf.image.per_image_standardization(image)
    return example

  def input_data():
    """Input function to be returned."""
    prob = problems.problem(problem_name)
    if data == 'image_cifar100':
      # With augmentation the problem's own preprocessing is used; otherwise
      # only the local per-image standardization is applied.
      dataset = prob.dataset(mode, preprocess=augmented)
      if not augmented:
        dataset = dataset.map(map_func=preprocess)
    else:
      # NOTE(review): for cifar10 the problem's preprocessing is always
      # disabled here, so `augmented` only selects the problem name --
      # confirm this is intended.
      dataset = prob.dataset(mode, preprocess=False)
      dataset = dataset.map(map_func=preprocess)

    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(repeat_num)
    features = dataset.make_one_shot_iterator().get_next()
    if data_format == 'CHW':
      features['inputs'] = tf.transpose(features['inputs'], (0, 3, 1, 2))

    # Squeeze only the singleton label axis. A bare tf.squeeze would also
    # drop the batch axis whenever a batch holds a single example.
    # Assumes targets are batched as [batch, 1] -- TODO confirm against the
    # tensor2tensor problem definition.
    labels = tf.squeeze(tf.one_hot(features['targets'], class_num), axis=1)
    return features['inputs'], labels

  return input_data