google-research
146 lines · 4.9 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Code to generate hparam-metric pairs for the hyperparameter optimization experiments."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import random
23from absl import app
24from absl import flags
25import numpy as np
26import tensorflow as tf
27import tensorflow_datasets as tfds
28
# Which tfds image-classification dataset to train on. `main` recognizes
# exactly the five names listed in the help string.
_DATASET = flags.DEFINE_string(
    "dataset", "mnist",
    "cifar10 / cifar100 / mnist / fashion_mnist / svhn_cropped")
# Semicolon-separated hyperparameter list. When empty, `main` samples each
# hyperparameter uniformly at random instead (random-search mode).
_HPARAMS = flags.DEFINE_string(
    "hparams", "",
    "use the following format: conv_units1;conv_units2;conv_units3;dense_units1;dense_units2;kernel_width;pool_width;epochs;method"
)
36
37
def normalize_img(image, label):
  """Rescales an image from `uint8` [0, 255] to `float32` [0, 1]; label is passed through."""
  scaled = tf.cast(image, tf.float32) / 255.
  return scaled, label
41
42
def main(_):
  """Trains a small CNN on the selected dataset and prints hparam-metric pairs.

  Hyperparameters come either from --hparams (explicit evaluation against the
  real test split) or are sampled uniformly at random (random-search mode,
  validated on a 10% holdout of the training split). Results are printed as
  "[metric] name=value" lines for downstream log scraping.
  """
  tfds.disable_progress_bar()

  # BUG FIX: flags.DEFINE_string returns a FlagHolder, not a string. The
  # original code compared the holder itself against dataset names (always
  # False, leaving input_shape unbound) and called .split() on the holder
  # (AttributeError). The actual flag string lives in .value.
  dataset = _DATASET.value
  hparam_str = _HPARAMS.value

  if dataset == "mnist" or dataset == "fashion_mnist":
    input_shape = (28, 28, 1)
    output_size = 10
  elif dataset == "cifar10" or dataset == "svhn_cropped":
    input_shape = (32, 32, 3)
    output_size = 10
  elif dataset == "cifar100":
    input_shape = (32, 32, 3)
    output_size = 100
  else:
    # Previously an unknown dataset fell through and crashed later with an
    # unrelated NameError; fail fast with a clear message instead.
    raise ValueError("Unsupported dataset: %s" % dataset)

  if hparam_str:
    # Explicit hparams: train on the full train split, evaluate on test.
    ds_train, ds_test = tfds.load(
        dataset,
        split=["train", "test"],
        shuffle_files=True,
        as_supervised=True)
  else:
    # Random search: hold out the last 10% of train as a validation set so
    # the real test split is never touched during the search.
    ds_train, ds_test = tfds.load(
        dataset,
        split=["train[0%:90%]", "train[90%:100%]"],
        shuffle_files=True,
        as_supervised=True)
  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds_train = ds_train.cache()
  ds_train = ds_train.batch(128)
  ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
  ds_test = ds_test.map(
      normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds_test = ds_test.batch(128)
  ds_test = ds_test.cache()
  ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

  if hparam_str:
    # Format: conv_units1;conv_units2;conv_units3;dense_units1;dense_units2;
    # kernel_width;pool_width;epochs;method (trailing "method" is ignored).
    hparams = hparam_str.split(";")
    conv_units1 = int(hparams[0])
    conv_units2 = int(hparams[1])
    conv_units3 = int(hparams[2])
    dense_units1 = int(hparams[3])
    dense_units2 = int(hparams[4])
    kernel_width = int(hparams[5])
    pool_width = int(hparams[6])
    epochs = int(hparams[7])
  else:
    # Uniform random sampling of each hyperparameter, rounded to ints.
    conv_units1 = int(np.round(random.uniform(8, 512)))
    conv_units2 = int(np.round(random.uniform(8, 512)))
    conv_units3 = int(np.round(random.uniform(8, 512)))
    dense_units1 = int(np.round(random.uniform(8, 512)))
    dense_units2 = int(np.round(random.uniform(8, 512)))
    kernel_width = int(np.round(random.uniform(2, 6)))
    pool_width = int(np.round(random.uniform(2, 6)))
    epochs = int(np.round(random.uniform(1, 25)))

  # Three conv/pool stages followed by two dense layers and a softmax head.
  # "same" padding keeps spatial dims divisible through the pooling stages.
  model = tf.keras.models.Sequential()
  model.add(
      tf.keras.layers.Conv2D(
          conv_units1, (kernel_width, kernel_width),
          activation="relu",
          padding="same",
          input_shape=input_shape))
  model.add(
      tf.keras.layers.MaxPooling2D((pool_width, pool_width), padding="same"))
  model.add(
      tf.keras.layers.Conv2D(
          conv_units2, (kernel_width, kernel_width),
          padding="same",
          activation="relu"))
  model.add(
      tf.keras.layers.MaxPooling2D((pool_width, pool_width), padding="same"))
  model.add(
      tf.keras.layers.Conv2D(
          conv_units3, (kernel_width, kernel_width),
          padding="same",
          activation="relu"))
  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(dense_units1, activation="relu"))
  model.add(tf.keras.layers.Dense(dense_units2, activation="relu"))
  model.add(tf.keras.layers.Dense(output_size, activation="softmax"))

  # Sparse categorical loss: tfds labels are integer class ids, not one-hot.
  model.compile(
      loss="sparse_categorical_crossentropy",
      optimizer=tf.keras.optimizers.Adam(),
      metrics=["accuracy"],
  )
  history = model.fit(
      ds_train, epochs=epochs, validation_data=ds_test, verbose=0)

  # Emit "[metric] key=value" lines for the log-scraping harness.
  print("[metric] conv_units1=" + str(conv_units1))
  print("[metric] conv_units2=" + str(conv_units2))
  print("[metric] conv_units3=" + str(conv_units3))
  print("[metric] dense_units1=" + str(dense_units1))
  print("[metric] dense_units2=" + str(dense_units2))
  print("[metric] kernel_width=" + str(kernel_width))
  print("[metric] pool_width=" + str(pool_width))
  print("[metric] epochs=" + str(epochs))
  print("[metric] val_accuracy=" + str(history.history["val_accuracy"][-1]))
  print(history.history)
143
144
# absl's app.run parses the command-line flags before invoking main.
if __name__ == "__main__":
  app.run(main)
147