TheAlgorithms-Python
343 lines · 11.7 KB
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data (deprecated).

This module and all its submodules are deprecated.
"""

import gzip
import os
import typing
import urllib
import urllib.request

import numpy as np
from tensorflow.python.framework import dtypes, random_seed
from tensorflow.python.platform import gfile
from tensorflow.python.util.deprecation import deprecated

class _Datasets(typing.NamedTuple):
    """The train/validation/test splits returned by `read_data_sets`."""

    train: "_DataSet"
    validation: "_DataSet"
    test: "_DataSet"

# CVDF mirror of http://yann.lecun.com/exdb/mnist/
DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"

41def _read32(bytestream):42dt = np.dtype(np.uint32).newbyteorder(">")43return np.frombuffer(bytestream.read(4), dtype=dt)[0]44
45
@deprecated(None, "Please use tf.data to implement this functionality.")
def _extract_images(f):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth].

    Args:
      f: A file object that can be passed into a gzip reader.

    Returns:
      data: A 4D uint8 numpy array [index, y, x, depth].

    Raises:
      ValueError: If the bytestream does not start with 2051.

    """
    print("Extracting", f.name)
    with gzip.GzipFile(fileobj=f) as stream:
        # The IDX image header is: magic, image count, row count, column count.
        magic = _read32(stream)
        if magic != 2051:
            raise ValueError(
                "Invalid magic number %d in MNIST image file: %s" % (magic, f.name)
            )
        count = _read32(stream)
        height = _read32(stream)
        width = _read32(stream)
        raw = stream.read(height * width * count)
    pixels = np.frombuffer(raw, dtype=np.uint8)
    # Single-channel images: depth dimension is always 1.
    return pixels.reshape(count, height, width, 1)

@deprecated(None, "Please use tf.one_hot on tensors.")
def _dense_to_one_hot(labels_dense, num_classes):
    """Convert class labels from scalars to one-hot vectors."""
    count = labels_dense.shape[0]
    one_hot = np.zeros((count, num_classes))
    # Flat index of the "hot" cell for each row: row offset + class column.
    hot_positions = np.arange(count) * num_classes + labels_dense.ravel()
    one_hot.flat[hot_positions] = 1
    return one_hot

@deprecated(None, "Please use tf.data to implement this functionality.")
def _extract_labels(f, one_hot=False, num_classes=10):
    """Extract the labels into a 1D uint8 numpy array [index].

    Args:
      f: A file object that can be passed into a gzip reader.
      one_hot: Does one hot encoding for the result.
      num_classes: Number of classes for the one hot encoding.

    Returns:
      labels: a 1D uint8 numpy array.

    Raises:
      ValueError: If the bytestream doesn't start with 2049.
    """
    print("Extracting", f.name)
    with gzip.GzipFile(fileobj=f) as stream:
        # The IDX label header is: magic, item count.
        magic = _read32(stream)
        if magic != 2049:
            raise ValueError(
                "Invalid magic number %d in MNIST label file: %s" % (magic, f.name)
            )
        count = _read32(stream)
        labels = np.frombuffer(stream.read(count), dtype=np.uint8)
    if one_hot:
        return _dense_to_one_hot(labels, num_classes)
    return labels

class _DataSet:
    """Container class for a _DataSet (deprecated).

    THIS CLASS IS DEPRECATED.

    Holds an images/labels pair and serves mini-batches via `next_batch`,
    shuffling once before the first epoch and again at every epoch boundary.
    """

    @deprecated(
        None,
        "Please use alternatives such as official/mnist/_DataSet.py"
        " from tensorflow/models.",
    )
    def __init__(
        self,
        images,
        labels,
        fake_data=False,
        one_hot=False,
        dtype=dtypes.float32,
        reshape=True,
        seed=None,
    ):
        """Construct a _DataSet.

        one_hot arg is used only if fake_data is true. `dtype` can be either
        `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
        `[0, 1]`. Seed arg provides for convenient deterministic testing.

        Args:
          images: The images (assumed [num examples, rows, columns, depth]
            when `reshape` is True — see the depth == 1 assert below).
          labels: The labels.
          fake_data: Ignore images and labels, use fake data.
          one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
            False).
          dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
            range [0,255]. float32 output has range [0,1].
          reshape: Bool. If True returned images are returned flattened to vectors.
          seed: The random seed to use.

        Raises:
          TypeError: If `dtype` is neither uint8 nor float32.
        """
        seed1, seed2 = random_seed.get_seed(seed)
        # If op level seed is not set, use whatever graph level seed is returned
        self._rng = np.random.default_rng(seed1 if seed is None else seed2)
        dtype = dtypes.as_dtype(dtype).base_dtype
        if dtype not in (dtypes.uint8, dtypes.float32):
            raise TypeError("Invalid image dtype %r, expected uint8 or float32" % dtype)
        if fake_data:
            # Synthetic mode: no real arrays are stored; batch sizes only.
            self._num_examples = 10000
            self.one_hot = one_hot
        else:
            assert (
                images.shape[0] == labels.shape[0]
            ), f"images.shape: {images.shape} labels.shape: {labels.shape}"
            self._num_examples = images.shape[0]

            # Convert shape from [num examples, rows, columns, depth]
            # to [num examples, rows*columns] (assuming depth == 1)
            if reshape:
                assert images.shape[3] == 1
                images = images.reshape(
                    images.shape[0], images.shape[1] * images.shape[2]
                )
            if dtype == dtypes.float32:
                # Convert from [0, 255] -> [0.0, 1.0].
                images = images.astype(np.float32)
                images = np.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        # Epoch bookkeeping used by next_batch.
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False, shuffle=True):
        """Return the next `batch_size` examples from this data set."""
        if fake_data:
            # 784 = 28*28 flattened MNIST image of constant pixels.
            fake_image = [1] * 784
            fake_label = [1] + [0] * 9 if self.one_hot else 0
            return (
                [fake_image for _ in range(batch_size)],
                [fake_label for _ in range(batch_size)],
            )
        start = self._index_in_epoch
        # Shuffle for the first epoch
        if self._epochs_completed == 0 and start == 0 and shuffle:
            perm0 = np.arange(self._num_examples)
            self._rng.shuffle(perm0)
            self._images = self.images[perm0]
            self._labels = self.labels[perm0]
        # Go to the next epoch
        if start + batch_size > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Get the rest examples in this epoch
            rest_num_examples = self._num_examples - start
            images_rest_part = self._images[start : self._num_examples]
            labels_rest_part = self._labels[start : self._num_examples]
            # Shuffle the data
            if shuffle:
                perm = np.arange(self._num_examples)
                self._rng.shuffle(perm)
                self._images = self.images[perm]
                self._labels = self.labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size - rest_num_examples
            end = self._index_in_epoch
            # Batch straddles the epoch boundary: tail of old epoch + head of new.
            images_new_part = self._images[start:end]
            labels_new_part = self._labels[start:end]
            return (
                np.concatenate((images_rest_part, images_new_part), axis=0),
                np.concatenate((labels_rest_part, labels_new_part), axis=0),
            )
        else:
            self._index_in_epoch += batch_size
            end = self._index_in_epoch
            return self._images[start:end], self._labels[start:end]

@deprecated(None, "Please write your own downloading logic.")
def _maybe_download(filename, work_directory, source_url):
    """Download the data from source url, unless it's already here.

    Args:
      filename: string, name of the file in the directory.
      work_directory: string, path to working directory.
      source_url: url to download from if file doesn't exist.

    Returns:
      Path to resulting file.
    """
    # Fix: the module only does `import urllib`, which does not guarantee the
    # `urllib.request` submodule is loaded; import the function explicitly so
    # the download works regardless of what other modules happen to import.
    from urllib.request import urlretrieve

    if not gfile.Exists(work_directory):
        gfile.MakeDirs(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not gfile.Exists(filepath):
        urlretrieve(source_url, filepath)  # noqa: S310
        with gfile.GFile(filepath) as f:
            size = f.size()
        # Only report on a fresh download, not on a cache hit.
        print("Successfully downloaded", filename, size, "bytes.")
    return filepath

@deprecated(None, "Please use alternatives such as: tensorflow_datasets.load('mnist')")
def read_data_sets(
    train_dir,
    fake_data=False,
    one_hot=False,
    dtype=dtypes.float32,
    reshape=True,
    validation_size=5000,
    seed=None,
    source_url=DEFAULT_SOURCE_URL,
):
    """Download (if needed) and parse the MNIST splits (deprecated).

    Args:
      train_dir: Directory where the gzipped IDX archives are cached.
      fake_data: If True, skip all I/O and return synthetic datasets.
      one_hot: If True, labels are one-hot encoded.
      dtype: Image dtype forwarded to `_DataSet` (uint8 or float32).
      reshape: If True, `_DataSet` flattens images to vectors.
      validation_size: Number of training examples carved off for validation.
      seed: Random seed forwarded to each `_DataSet`.
      source_url: Base URL to download from; empty falls back to
        DEFAULT_SOURCE_URL.

    Returns:
      A `_Datasets` tuple with `train`, `validation` and `test` members.

    Raises:
      ValueError: If `validation_size` is outside [0, len(train_images)].
    """
    if fake_data:

        def fake():
            return _DataSet(
                [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed
            )

        return _Datasets(train=fake(), validation=fake(), test=fake())

    source_url = source_url or DEFAULT_SOURCE_URL  # empty string -> default mirror

    def fetch(archive_name, extract):
        # Download one gzipped IDX archive (unless cached) and parse it.
        path = _maybe_download(archive_name, train_dir, source_url + archive_name)
        with gfile.Open(path, "rb") as handle:
            return extract(handle)

    def extract_labels(handle):
        return _extract_labels(handle, one_hot=one_hot)

    train_images = fetch("train-images-idx3-ubyte.gz", _extract_images)
    train_labels = fetch("train-labels-idx1-ubyte.gz", extract_labels)
    test_images = fetch("t10k-images-idx3-ubyte.gz", _extract_images)
    test_labels = fetch("t10k-labels-idx1-ubyte.gz", extract_labels)

    if not 0 <= validation_size <= len(train_images):
        raise ValueError(
            "Validation size should be between 0 and "
            f"{len(train_images)}. Received: {validation_size}."
        )

    # The validation split is carved off the front of the training data.
    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]

    options = {"dtype": dtype, "reshape": reshape, "seed": seed}

    return _Datasets(
        train=_DataSet(train_images, train_labels, **options),
        validation=_DataSet(validation_images, validation_labels, **options),
        test=_DataSet(test_images, test_labels, **options),
    )
