google-research / hparams_sample.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code to generate hparam-metric pairs for the hyperparameter optimization experiments."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

_DATASET = flags.DEFINE_string(
    "dataset", "mnist",
    "cifar10 / cifar100 / mnist / fashion_mnist / svhn_cropped")
_HPARAMS = flags.DEFINE_string(
    "hparams", "",
    "use the following format: conv_units1;conv_units2;conv_units3;dense_units1;dense_units2;kernel_width;pool_width;epochs;method"
)
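# Example (hypothetical values, matching the format described above):
#   --hparams="32;64;128;256;128;3;2;10;manual"
# i.e. three conv filter counts, two dense-layer widths, kernel width, pool
# width, number of training epochs, and a free-form method label; only the
# first eight fields are parsed in main() below.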


def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


def main(_):
  tfds.disable_progress_bar()

  # Input image shape and number of classes for the selected dataset.
  if _DATASET.value == "mnist" or _DATASET.value == "fashion_mnist":
    input_shape = (28, 28, 1)
    output_size = 10
  elif _DATASET.value == "cifar10" or _DATASET.value == "svhn_cropped":
    input_shape = (32, 32, 3)
    output_size = 10
  elif _DATASET.value == "cifar100":
    input_shape = (32, 32, 3)
    output_size = 100

  # With explicit hparams, train on the full training split and validate on
  # the test split; otherwise hold out the last 10% of the training split.
  if _HPARAMS.value:
    ds_train, ds_test = tfds.load(
        _DATASET.value,
        split=["train", "test"],
        shuffle_files=True,
        as_supervised=True)
  else:
    ds_train, ds_test = tfds.load(
        _DATASET.value,
        split=["train[0%:90%]", "train[90%:100%]"],
        shuffle_files=True,
        as_supervised=True)
  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds_train = ds_train.cache()
  ds_train = ds_train.batch(128)
  ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
  ds_test = ds_test.map(
      normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds_test = ds_test.batch(128)
  ds_test = ds_test.cache()
  ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

  # Parse hyperparameters from the flag, or sample them uniformly at random
  # when no flag value was given.
  if _HPARAMS.value:
    hparams = _HPARAMS.value.split(";")
    conv_units1 = int(hparams[0])
    conv_units2 = int(hparams[1])
    conv_units3 = int(hparams[2])
    dense_units1 = int(hparams[3])
    dense_units2 = int(hparams[4])
    kernel_width = int(hparams[5])
    pool_width = int(hparams[6])
    epochs = int(hparams[7])
  else:
    conv_units1 = int(np.round(random.uniform(8, 512)))
    conv_units2 = int(np.round(random.uniform(8, 512)))
    conv_units3 = int(np.round(random.uniform(8, 512)))
    dense_units1 = int(np.round(random.uniform(8, 512)))
    dense_units2 = int(np.round(random.uniform(8, 512)))
    kernel_width = int(np.round(random.uniform(2, 6)))
    pool_width = int(np.round(random.uniform(2, 6)))
    epochs = int(np.round(random.uniform(1, 25)))

  # A small CNN: three Conv2D blocks (the first two followed by max pooling),
  # then two fully connected layers and a softmax classification layer.
  model = tf.keras.models.Sequential()
  model.add(
      tf.keras.layers.Conv2D(
          conv_units1, (kernel_width, kernel_width),
          activation="relu",
          padding="same",
          input_shape=input_shape))
  model.add(
      tf.keras.layers.MaxPooling2D((pool_width, pool_width), padding="same"))
  model.add(
      tf.keras.layers.Conv2D(
          conv_units2, (kernel_width, kernel_width),
          padding="same",
          activation="relu"))
  model.add(
      tf.keras.layers.MaxPooling2D((pool_width, pool_width), padding="same"))
  model.add(
      tf.keras.layers.Conv2D(
          conv_units3, (kernel_width, kernel_width),
          padding="same",
          activation="relu"))
  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(dense_units1, activation="relu"))
  model.add(tf.keras.layers.Dense(dense_units2, activation="relu"))
  model.add(tf.keras.layers.Dense(output_size, activation="softmax"))

  model.compile(
      loss="sparse_categorical_crossentropy",
      optimizer=tf.keras.optimizers.Adam(),
      metrics=["accuracy"],
  )
  history = model.fit(
      ds_train, epochs=epochs, validation_data=ds_test, verbose=0)

  # Print the hyperparameter values together with the final validation
  # accuracy, producing one hparam-metric pair per run.
  print("[metric] conv_units1=" + str(conv_units1))
  print("[metric] conv_units2=" + str(conv_units2))
  print("[metric] conv_units3=" + str(conv_units3))
  print("[metric] dense_units1=" + str(dense_units1))
  print("[metric] dense_units2=" + str(dense_units2))
  print("[metric] kernel_width=" + str(kernel_width))
  print("[metric] pool_width=" + str(pool_width))
  print("[metric] epochs=" + str(epochs))
  print("[metric] val_accuracy=" + str(history.history["val_accuracy"][-1]))
  print(history.history)


if __name__ == "__main__":
  app.run(main)
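# Example invocations (illustrative; not part of the original file):
#   python hparams_sample.py --dataset=cifar10 --hparams="32;64;128;256;128;3;2;10;manual"
#   python hparams_sample.py --dataset=mnist   # samples a random configuration
# Each run prints "[metric] name=value" lines pairing the hyperparameters with
# the final validation accuracy.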