google-research
59 строк · 2.0 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""tf.data.Dataset interface to the CIFAR10 dataset."""

18from absl import logging19import tensorflow as tf20
# Side length, in pixels, of the square CIFAR-10 images read by this module.
IMG_DIM = 32

# Channels per decoded pixel; each image flattens to
# IMG_DIM * IMG_DIM * NUM_CHANNELS float values.
NUM_CHANNELS = 3

# Number of distinct class labels in CIFAR-10. Not referenced by the code in
# this file; presumably consumed by callers (e.g. to size a classifier head).
NUM_LABELS = 10
25
def dataset_randomized(pattern):
  """Builds a tf.data.Dataset of parsed CIFAR-10 records from TFRecord files.

  The list of files matching `pattern` is shuffled before reading, so the
  order in which whole files are consumed is randomized. Records *within* a
  file are still read in their stored order; callers that need example-level
  shuffling must add their own `shuffle` stage downstream.

  Args:
    pattern: Glob pattern (anything accepted by `tf.io.gfile.glob`) matching
      the TFRecord files to read.

  Returns:
    A `tf.data.Dataset` whose elements are dicts with keys:
      'image/class/label': scalar int64 tensor.
      'image/class/shuffled_label': scalar int64 tensor.
      'image/encoded': float32 vector of length
        IMG_DIM * IMG_DIM * NUM_CHANNELS holding the decoded PNG after
        per-image standardization.

  Raises:
    ValueError: If `pattern` matches no files.
  """
  filenames = tf.io.gfile.glob(pattern)
  if not filenames:
    # An empty match would otherwise fail later with an opaque TensorFlow
    # error (an empty tf.constant is float32, and shuffle() rejects a
    # zero-sized buffer), so report the bad pattern directly.
    raise ValueError('No input files match pattern: %r' % (pattern,))

  logging.info('*** Input Files ***')
  for input_file in filenames:
    logging.info(' %s', input_file)

  # Buffer holds every filename, so this is a full shuffle of the file list.
  ds_filenames = tf.data.Dataset.from_tensor_slices(tf.constant(filenames))
  ds_filenames = ds_filenames.shuffle(buffer_size=len(filenames))

  dataset = tf.data.TFRecordDataset(ds_filenames)

  # Create a description of the features.
  feature_description = {
      'image/class/label': tf.io.FixedLenFeature([], tf.int64),
      'image/class/shuffled_label': tf.io.FixedLenFeature([], tf.int64),
      'image/encoded': tf.io.FixedLenFeature([], tf.string),
  }

  def decode_image(image):
    """Decodes PNG bytes into a flat, per-image-standardized float32 vector."""
    # Force NUM_CHANNELS output channels so that a grayscale or RGBA PNG still
    # matches the fixed-size reshape below instead of failing at runtime.
    image = tf.io.decode_png(image, channels=NUM_CHANNELS)
    image = tf.cast(image, tf.float32)
    image = tf.image.per_image_standardization(image)
    image = tf.reshape(image, [IMG_DIM * IMG_DIM * NUM_CHANNELS])
    return image

  def parse_function(example_proto):
    """Parses one serialized tf.Example and decodes its image in place."""
    # Parse the input tf.Example proto using the feature description above.
    features = tf.io.parse_single_example(example_proto, feature_description)
    features['image/encoded'] = decode_image(features['image/encoded'])
    return features

  dataset = dataset.map(parse_function)
  return dataset