google-research
101 строка · 3.2 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Provides data for the Cifar10 dataset."""

33from __future__ import absolute_import
34from __future__ import division
35from __future__ import print_function
36
37import os
38
39import tensorflow as tf
40
41from tensorflow.contrib.slim.python.slim.data import dataset
42from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
43
# Glob pattern for one split's data files; the '%s' placeholder is filled in
# with the split name (e.g. 'cifar10_train-*').
_FILE_PATTERN = 'cifar10_%s-*'

# Fallback directory used by get_split() when no dataset_dir is given. An empty
# string makes the file pattern resolve relative to the working directory.
_DATASET_DIR = ''

# Number of examples in each recognized split.
_SPLITS_TO_SIZES = dict(train=50000, test=10000)

# Number of label classes (labels are described below as integers 0..9).
_NUM_CLASSES = 10

# Human-readable summary of each item, forwarded to the Dataset tuple.
# NOTE(review): 'image/format' is described here, but get_split() registers no
# decoder handler under that item name -- confirm whether it should be decoded
# or removed from this table.
_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A [32 x 32 x 3] color image.',
    'label': 'A single integer between 0 and 9',
    'image/format': 'a string indicating the image format.',
}
def get_split(split_name, dataset_dir=None, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading cifar10.

  Args:
    split_name: A train/test split name; must be a key of `_SPLITS_TO_SIZES`.
    dataset_dir: The base directory of the dataset sources. When None,
      `_DATASET_DIR` is used.
    file_pattern: Optional file pattern, relative to `dataset_dir`, matching
      the split's data files. It must contain exactly one `%s`, which is
      replaced by `split_name`. When None, `_FILE_PATTERN` is used.
    reader: Optional reader class handed to the `Dataset`. When None,
      `tf.TFRecordReader` is used.

  Returns:
    A `Dataset` namedtuple. Image tensors are integers in [0, 255].

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in _SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  # None means "use the module default" for every optional argument, so that
  # callers which forward optional settings unchanged get the old behavior.
  if dataset_dir is None:
    dataset_dir = _DATASET_DIR
  if file_pattern is None:
    file_pattern = _FILE_PATTERN
  if reader is None:
    reader = tf.TFRecordReader

  data_sources = os.path.join(dataset_dir, file_pattern % split_name)

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value=''),
      # The label is parsed as a length-1 int64 vector; records that lack it
      # fall back to the default, i.e. class 0.
      'image/class/label': tf.FixedLenFeature(
          [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
  }

  items_to_handlers = {
      # No feature keys are passed, so the Image handler relies on its default
      # keys (presumably 'image/encoded' / 'image/format' -- verify against
      # tfexample_decoder.Image).
      'image': tfexample_decoder.Image(shape=[32, 32, 3]),
      'label': tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  return dataset.Dataset(
      data_sources=data_sources,
      reader=reader,
      decoder=decoder,
      num_samples=_SPLITS_TO_SIZES[split_name],
      num_classes=_NUM_CLASSES,
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS)