google-research
107 строк · 3.6 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# Copyright 2016 Google Inc. All Rights Reserved.
17#
18# Licensed under the Apache License, Version 2.0 (the "License");
19# you may not use this file except in compliance with the License.
20# You may obtain a copy of the License at
21#
22# http://www.apache.org/licenses/LICENSE-2.0
23#
24# Unless required by applicable law or agreed to in writing, software
25# distributed under the License is distributed on an "AS IS" BASIS,
26# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27# See the License for the specific language governing permissions and
28# limitations under the License.
29# ==============================================================================
30
31"""Provides data for the Cifar100 dataset."""
32
33from __future__ import absolute_import34from __future__ import division35from __future__ import print_function36
37import os38
39import tensorflow as tf40
41from tensorflow.contrib.slim.python.slim.data import dataset42from tensorflow.contrib.slim.python.slim.data import tfexample_decoder43
# Filename pattern of the dataset files; `%s` is filled with the split name.
_FILE_PATTERN = 'cifar100_%s-*'

# Module-level default dataset directory (an empty string).
_DATASET_DIR = ''

# Number of examples in each recognized split.
_SPLITS_TO_SIZES = dict(train=50000, test=10000)

# Number of label classes reported to the `Dataset`.
_NUM_CLASSES = 100

# Human-readable descriptions handed to the `Dataset` as
# `items_to_descriptions`.
_ITEMS_TO_DESCRIPTIONS = dict([
    ('image', 'A [32 x 32 x 3] color image.'),
    ('image/class/label', 'A single integer between 0 and 99.'),
    ('image/format', 'a string indicating the image format.'),
    ('image/class/fine_label', 'A single integer between 0 and 99.'),
])
58
59
def get_split(split_name, dataset_dir=None):
  """Gets a dataset tuple with instructions for reading cifar100.

  Args:
    split_name: A train/test split name; must be a key of
      `_SPLITS_TO_SIZES`.
    dataset_dir: The base directory of the dataset sources. When `None`
      (the default), the module-level `_DATASET_DIR` is used instead.

  Returns:
    A `dataset.Dataset` whose `data_sources` is the file pattern for the
    requested split. Image tensors are integers in [0, 255].

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in _SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  # The signature allows `dataset_dir=None`, but os.path.join raises a
  # TypeError on None, so fall back to the module-level default directory.
  if dataset_dir is None:
    dataset_dir = _DATASET_DIR
  file_pattern = os.path.join(dataset_dir, _FILE_PATTERN % split_name)

  # Features parsed out of every serialized tf.Example record.
  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/class/label': tf.FixedLenFeature(
          [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
      'image/class/fine_label': tf.FixedLenFeature(
          [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
  }

  # The 'train' split exposes 'image/class/label' as its label; every other
  # split exposes 'image/class/fine_label'.
  # NOTE(review): the reason for the train/test asymmetry is not visible in
  # this file -- confirm it is intentional before unifying the two.
  if split_name == 'train':
    label_key = 'image/class/label'
  else:
    label_key = 'image/class/fine_label'

  items_to_handlers = {
      'image': tfexample_decoder.Image(shape=[32, 32, 3]),
      'label': tfexample_decoder.Tensor(label_key),
  }

  decoder = tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  return dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=_SPLITS_TO_SIZES[split_name],
      num_classes=_NUM_CLASSES,
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS)