google-research

Форк
0
59 строк · 2.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""tf.data.Dataset interface to the CIFAR10 dataset."""
17

18
from absl import logging
19
import tensorflow as tf
20

21
IMG_DIM = 32
22
NUM_CHANNELS = 3
23
NUM_LABELS = 10
24

25

26
def dataset_randomized(pattern):
27
  """tf.data.Dataset object for CIFAR-10 training data."""
28
  filenames = tf.io.gfile.glob(pattern)
29
  logging.info('*** Input Files ***')
30
  for input_file in filenames:
31
    logging.info('  %s', input_file)
32

33
  ds_filenames = tf.data.Dataset.from_tensor_slices(tf.constant(filenames))
34
  ds_filenames = ds_filenames.shuffle(buffer_size=len(filenames))
35

36
  dataset = tf.data.TFRecordDataset(ds_filenames)
37

38
  # Create a description of the features.
39
  feature_description = {
40
      'image/class/label': tf.io.FixedLenFeature([], tf.int64),
41
      'image/class/shuffled_label': tf.io.FixedLenFeature([], tf.int64),
42
      'image/encoded': tf.io.FixedLenFeature([], tf.string),
43
  }
44

45
  def decode_image(image):
46
    image = tf.io.decode_png(image)
47
    image = tf.cast(image, tf.float32)
48
    image = tf.image.per_image_standardization(image)
49
    image = tf.reshape(image, [IMG_DIM * IMG_DIM * NUM_CHANNELS])
50
    return image
51

52
  def parse_function(example_proto):
53
    # Parse the input tf.Example proto using the dictionary above.
54
    features = tf.io.parse_single_example(example_proto, feature_description)
55
    features['image/encoded'] = decode_image(features['image/encoded'])
56
    return features
57

58
  dataset = dataset.map(parse_function)
59
  return dataset
60

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.