# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data (deprecated).

This module and all its submodules are deprecated.
"""

import gzip
import os
import typing
import urllib.request

import numpy as np
from tensorflow.python.framework import dtypes, random_seed
from tensorflow.python.platform import gfile
from tensorflow.python.util.deprecation import deprecated


class _Datasets(typing.NamedTuple):
    train: "_DataSet"
    validation: "_DataSet"
    test: "_DataSet"


# CVDF mirror of http://yann.lecun.com/exdb/mnist/
DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"


def _read32(bytestream):
    dt = np.dtype(np.uint32).newbyteorder(">")
    return np.frombuffer(bytestream.read(4), dtype=dt)[0]
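

# A minimal sketch of what _read32 decodes, assuming an in-memory stream in
# place of a gzip file: IDX headers store 32-bit integers big-endian, so the
# bytes 00 00 08 03 read back as 2051, the image-file magic number.
#
#   import io
#   magic = _read32(io.BytesIO(b"\x00\x00\x08\x03"))  # -> 2051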


@deprecated(None, "Please use tf.data to implement this functionality.")
def _extract_images(f):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth].

    Args:
      f: A file object that can be passed into a gzip reader.

    Returns:
      data: A 4D uint8 numpy array [index, y, x, depth].

    Raises:
      ValueError: If the bytestream does not start with 2051.

    """
    print("Extracting", f.name)
    with gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                "Invalid magic number %d in MNIST image file: %s" % (magic, f.name)
            )
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        buf = bytestream.read(rows * cols * num_images)
        data = np.frombuffer(buf, dtype=np.uint8)
        data = data.reshape(num_images, rows, cols, 1)
        return data
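

# For the standard MNIST archives, the returned array is expected to have
# shape (60000, 28, 28, 1) for the training images and (10000, 28, 28, 1)
# for the test images, with uint8 pixel values in [0, 255].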


@deprecated(None, "Please use tf.one_hot on tensors.")
def _dense_to_one_hot(labels_dense, num_classes):
    """Convert class labels from scalars to one-hot vectors."""
    num_labels = labels_dense.shape[0]
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot
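

# A small sketch of the conversion, with num_classes assumed to be 10 as in
# MNIST: each label becomes a row with a single 1 at the label's index.
#
#   _dense_to_one_hot(np.array([1, 3]), 10)
#   # -> [[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
#   #     [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]]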


@deprecated(None, "Please use tf.data to implement this functionality.")
def _extract_labels(f, one_hot=False, num_classes=10):
    """Extract the labels into a 1D uint8 numpy array [index].

    Args:
      f: A file object that can be passed into a gzip reader.
      one_hot: Does one hot encoding for the result.
      num_classes: Number of classes for the one hot encoding.

    Returns:
      labels: a 1D uint8 numpy array.

    Raises:
      ValueError: If the bytestream doesn't start with 2049.
    """
    print("Extracting", f.name)
    with gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                "Invalid magic number %d in MNIST label file: %s" % (magic, f.name)
            )
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = np.frombuffer(buf, dtype=np.uint8)
        if one_hot:
            return _dense_to_one_hot(labels, num_classes)
        return labels


class _DataSet:
    """Container class for a _DataSet (deprecated).

    THIS CLASS IS DEPRECATED.
    """

    @deprecated(
        None,
        "Please use alternatives such as official/mnist/_DataSet.py"
        " from tensorflow/models.",
    )
    def __init__(
        self,
        images,
        labels,
        fake_data=False,
        one_hot=False,
        dtype=dtypes.float32,
        reshape=True,
        seed=None,
    ):
        """Construct a _DataSet.

        one_hot arg is used only if fake_data is true.  `dtype` can be either
        `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
        `[0, 1]`.  Seed arg provides for convenient deterministic testing.

        Args:
          images: The images.
          labels: The labels.
          fake_data: Ignore images and labels; use fake data.
          one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
            False).
          dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
            range [0, 255]. `float32` output has range [0, 1].
          reshape: Bool. If True, returned images are flattened to vectors.
          seed: The random seed to use.
        """
        seed1, seed2 = random_seed.get_seed(seed)
        # If op level seed is not set, use whatever graph level seed is returned
        self._rng = np.random.default_rng(seed1 if seed is None else seed2)
        dtype = dtypes.as_dtype(dtype).base_dtype
        if dtype not in (dtypes.uint8, dtypes.float32):
            raise TypeError("Invalid image dtype %r, expected uint8 or float32" % dtype)
        if fake_data:
            self._num_examples = 10000
            self.one_hot = one_hot
        else:
            assert (
                images.shape[0] == labels.shape[0]
            ), f"images.shape: {images.shape} labels.shape: {labels.shape}"
            self._num_examples = images.shape[0]

            # Convert shape from [num examples, rows, columns, depth]
            # to [num examples, rows*columns] (assuming depth == 1)
            if reshape:
                assert images.shape[3] == 1
                images = images.reshape(
                    images.shape[0], images.shape[1] * images.shape[2]
                )
            if dtype == dtypes.float32:
                # Convert from [0, 255] -> [0.0, 1.0].
                images = images.astype(np.float32)
                images = np.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False, shuffle=True):
        """Return the next `batch_size` examples from this data set."""
        if fake_data:
            fake_image = [1] * 784
            fake_label = [1] + [0] * 9 if self.one_hot else 0
            return (
                [fake_image for _ in range(batch_size)],
                [fake_label for _ in range(batch_size)],
            )
        start = self._index_in_epoch
        # Shuffle for the first epoch
        if self._epochs_completed == 0 and start == 0 and shuffle:
            perm0 = np.arange(self._num_examples)
            self._rng.shuffle(perm0)
            self._images = self.images[perm0]
            self._labels = self.labels[perm0]
        # Go to the next epoch
        if start + batch_size > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Get the rest examples in this epoch
            rest_num_examples = self._num_examples - start
            images_rest_part = self._images[start : self._num_examples]
            labels_rest_part = self._labels[start : self._num_examples]
            # Shuffle the data
            if shuffle:
                perm = np.arange(self._num_examples)
                self._rng.shuffle(perm)
                self._images = self.images[perm]
                self._labels = self.labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size - rest_num_examples
            end = self._index_in_epoch
            images_new_part = self._images[start:end]
            labels_new_part = self._labels[start:end]
            return (
                np.concatenate((images_rest_part, images_new_part), axis=0),
                np.concatenate((labels_rest_part, labels_new_part), axis=0),
            )
        else:
            self._index_in_epoch += batch_size
            end = self._index_in_epoch
            return self._images[start:end], self._labels[start:end]
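

# A brief sketch of the epoch bookkeeping in next_batch, assuming a 5-example
# dataset and batch_size=3: the first call returns 3 examples from the
# shuffled epoch 0; the second call crosses the epoch boundary, reshuffles,
# and returns the 2 leftover examples concatenated with 1 from epoch 1.
#
#   ds = _DataSet(images5, labels5, reshape=False, seed=42)  # hypothetical arrays
#   xs, ys = ds.next_batch(3)  # 3 examples, epoch 0
#   xs, ys = ds.next_batch(3)  # 2 leftover + 1 from the reshuffled epoch 1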


@deprecated(None, "Please write your own downloading logic.")
def _maybe_download(filename, work_directory, source_url):
    """Download the data from source url, unless it's already here.

    Args:
        filename: string, name of the file in the directory.
        work_directory: string, path to working directory.
        source_url: url to download from if file doesn't exist.

    Returns:
        Path to resulting file.
    """
    if not gfile.Exists(work_directory):
        gfile.MakeDirs(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not gfile.Exists(filepath):
        urllib.request.urlretrieve(source_url, filepath)  # noqa: S310
        with gfile.GFile(filepath) as f:
            size = f.size()
        print("Successfully downloaded", filename, size, "bytes.")
    return filepath
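

# Usage sketch for _maybe_download (the /tmp/mnist_data path is hypothetical):
# the first call fetches the archive over the network; later calls find the
# cached file and return its path immediately.
#
#   path = _maybe_download(
#       "train-images-idx3-ubyte.gz",
#       "/tmp/mnist_data",
#       DEFAULT_SOURCE_URL + "train-images-idx3-ubyte.gz",
#   )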


@deprecated(None, "Please use alternatives such as: tensorflow_datasets.load('mnist')")
def read_data_sets(
    train_dir,
    fake_data=False,
    one_hot=False,
    dtype=dtypes.float32,
    reshape=True,
    validation_size=5000,
    seed=None,
    source_url=DEFAULT_SOURCE_URL,
):
    if fake_data:

        def fake():
            return _DataSet(
                [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed
            )

        train = fake()
        validation = fake()
        test = fake()
        return _Datasets(train=train, validation=validation, test=test)

    if not source_url:  # empty string check
        source_url = DEFAULT_SOURCE_URL

    train_images_file = "train-images-idx3-ubyte.gz"
    train_labels_file = "train-labels-idx1-ubyte.gz"
    test_images_file = "t10k-images-idx3-ubyte.gz"
    test_labels_file = "t10k-labels-idx1-ubyte.gz"

    local_file = _maybe_download(
        train_images_file, train_dir, source_url + train_images_file
    )
    with gfile.Open(local_file, "rb") as f:
        train_images = _extract_images(f)

    local_file = _maybe_download(
        train_labels_file, train_dir, source_url + train_labels_file
    )
    with gfile.Open(local_file, "rb") as f:
        train_labels = _extract_labels(f, one_hot=one_hot)

    local_file = _maybe_download(
        test_images_file, train_dir, source_url + test_images_file
    )
    with gfile.Open(local_file, "rb") as f:
        test_images = _extract_images(f)

    local_file = _maybe_download(
        test_labels_file, train_dir, source_url + test_labels_file
    )
    with gfile.Open(local_file, "rb") as f:
        test_labels = _extract_labels(f, one_hot=one_hot)

    if not 0 <= validation_size <= len(train_images):
        msg = (
            "Validation size should be between 0 and "
            f"{len(train_images)}. Received: {validation_size}."
        )
        raise ValueError(msg)

    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]

    options = {"dtype": dtype, "reshape": reshape, "seed": seed}

    train = _DataSet(train_images, train_labels, **options)
    validation = _DataSet(validation_images, validation_labels, **options)
    test = _DataSet(test_images, test_labels, **options)

    return _Datasets(train=train, validation=validation, test=test)
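

# A minimal end-to-end sketch, assuming network access to DEFAULT_SOURCE_URL
# and write permission on the (hypothetical) directory /tmp/mnist_data. It is
# guarded so that importing this deprecated module stays side-effect free.
if __name__ == "__main__":
    datasets = read_data_sets("/tmp/mnist_data", one_hot=True)
    print("train:", datasets.train.num_examples)  # 55000 with the default split
    print("validation:", datasets.validation.num_examples)  # 5000
    print("test:", datasets.test.num_examples)  # 10000
    batch_xs, batch_ys = datasets.train.next_batch(100)
    print("batch:", batch_xs.shape, batch_ys.shape)  # (100, 784) (100, 10)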