# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""LeNET-like model architecture."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags

import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tensorflow import estimator as tf_estimator

flags.DEFINE_string('source_dataset', 'mnist', 'Name of the source dataset.')
flags.DEFINE_string('target_dataset', 'svhn_cropped_small',
                    'Name of the target dataset.')
flags.DEFINE_integer('src_num_classes', 10,
                     'The number of classes in the source dataset.')
flags.DEFINE_integer('src_hw', 28, 'The height and width of source inputs.')
# Fixed copy-paste error: help text previously said "source inputs".
flags.DEFINE_integer('target_hw', 32, 'The height and width of target inputs.')
flags.DEFINE_integer('random_seed', 1, 'Random seed.')

FLAGS = flags.FLAGS

def conv_model(features,
               mode,
               target_dataset,
               src_hw=28,
               target_hw=32,
               dataset_name=None,
               reuse=None):
  """Architecture of the LeNet model for MNIST.

  Builds a LeNet-style CNN trunk (two conv/pool stages followed by two
  dense layers) and returns its 128-dim feature output. Note: this
  returns the penultimate representation, not class logits — the caller
  is expected to attach a classification head.

  Args:
    features: Input image tensor. Flattened or image-shaped; it is
      reshaped internally according to the dataset.
    mode: A `tf_estimator.ModeKeys` value; dropout is applied only in
      TRAIN mode.
    target_dataset: Name of the target dataset; 'mnist' selects the
      single-channel input path.
    src_hw: Height/width of source-dataset images (MNIST is 28).
    target_hw: Height/width of target-dataset images (e.g. 32 for SVHN).
    dataset_name: Name of the dataset being fed; 'mnist' selects the
      single-channel input path.
    reuse: Passed to the `tf.compat.v1.layers` calls to reuse variables
      of an already-built copy of this network.

  Returns:
    The final dense-layer activations, shape [batch, 128].
  """

  def build_network(features, is_training):
    """Returns the network output (128-dim features)."""
    # Normalize every input to a [batch, target_hw, target_hw, 1] image.
    if dataset_name == 'mnist' or target_dataset == 'mnist':
      # MNIST: single channel; zero-pad 28x28 up to 32x32 (2 px per side).
      input_layer = tf.reshape(features, [-1, src_hw, src_hw, 1])
      input_layer = tf.pad(input_layer, [[0, 0], [2, 2], [2, 2], [0, 0]])
    else:
      # RGB datasets (e.g. SVHN): collapse to grayscale.
      input_layer = tf.reshape(features, [-1, target_hw, target_hw, 3])
      input_layer = tf.image.rgb_to_grayscale(input_layer)

    input_layer = tf.reshape(input_layer, [-1, target_hw, target_hw, 1])
    input_layer = tf.image.convert_image_dtype(input_layer, dtype=tf.float32)

    # Dropout rate shared by all dropout layers below (train-time only).
    discard_rate = 0.2

    conv1 = tf.compat.v1.layers.conv2d(
        inputs=input_layer,
        filters=32,
        kernel_size=[5, 5],
        padding='same',
        activation=tf.nn.relu,
        name='conv1',
        reuse=reuse)

    pool1 = tf.compat.v1.layers.max_pooling2d(
        inputs=conv1, pool_size=[2, 2], strides=2)

    if is_training:
      pool1 = tf.compat.v1.layers.dropout(inputs=pool1, rate=discard_rate)

    conv2 = tf.compat.v1.layers.conv2d(
        inputs=pool1,
        filters=32,
        kernel_size=[5, 5],
        padding='same',
        activation=tf.nn.relu,
        name='conv2',
        reuse=reuse,
    )

    pool2 = tf.compat.v1.layers.max_pooling2d(
        inputs=conv2, pool_size=[2, 2], strides=2)

    if is_training:
      pool2 = tf.compat.v1.layers.dropout(inputs=pool2, rate=discard_rate)

    # 2048 = 8 * 8 * 32: a 32x32 input halved by two 2x2 pools, times 32
    # filters. This hard-codes target_hw == 32 — other sizes would break
    # here. NOTE(review): verify if target_hw is ever != 32 in callers.
    pool2_flat = tf.reshape(pool2, [-1, 2048])
    dense = tf.compat.v1.layers.dense(
        inputs=pool2_flat,
        units=512,
        activation=tf.nn.relu,
        name='dense1',
        reuse=reuse)

    if is_training:
      dense = tf.compat.v1.layers.dropout(inputs=dense, rate=discard_rate)

    dense = tf.compat.v1.layers.dense(
        inputs=dense,
        units=128,
        activation=tf.nn.relu,
        name='dense2',
        reuse=reuse)

    if is_training:
      dense = tf.compat.v1.layers.dropout(inputs=dense, rate=discard_rate)

    return dense

  is_training = mode == tf_estimator.ModeKeys.TRAIN
  network_output = build_network(features, is_training=is_training)
  return network_output