google-research

Форк
0
119 строк · 3.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""LeNET-like model architecture."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
from absl import flags
23

24
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
25
from tensorflow import estimator as tf_estimator
26

27
flags.DEFINE_string('source_dataset', 'mnist', 'Name of the source dataset.')
28
flags.DEFINE_string('target_dataset', 'svhn_cropped_small',
29
                    'Name of the target dataset.')
30
flags.DEFINE_integer('src_num_classes', 10,
31
                     'The number of classes in the source dataset.')
32
flags.DEFINE_integer('src_hw', 28, 'The height and width of source inputs.')
33
flags.DEFINE_integer('target_hw', 32, 'The height and width of source inputs.')
34
flags.DEFINE_integer('random_seed', 1, 'Random seed.')
35

36
FLAGS = flags.FLAGS
37

38

39
def conv_model(features,
40
               mode,
41
               target_dataset,
42
               src_hw=28,
43
               target_hw=32,
44
               dataset_name=None,
45
               reuse=None):
46
  """Architecture of the LeNet model for MNIST."""
47

48
  def build_network(features, is_training):
49
    """Returns the network output."""
50
    # Input reshape
51
    if dataset_name == 'mnist' or target_dataset == 'mnist':
52
      input_layer = tf.reshape(features, [-1, src_hw, src_hw, 1])
53
      input_layer = tf.pad(input_layer, [[0, 0], [2, 2], [2, 2], [0, 0]])
54
    else:
55
      input_layer = tf.reshape(features, [-1, target_hw, target_hw, 3])
56
      input_layer = tf.image.rgb_to_grayscale(input_layer)
57

58
    input_layer = tf.reshape(input_layer, [-1, target_hw, target_hw, 1])
59
    input_layer = tf.image.convert_image_dtype(input_layer, dtype=tf.float32)
60

61
    discard_rate = 0.2
62

63
    conv1 = tf.compat.v1.layers.conv2d(
64
        inputs=input_layer,
65
        filters=32,
66
        kernel_size=[5, 5],
67
        padding='same',
68
        activation=tf.nn.relu,
69
        name='conv1',
70
        reuse=reuse)
71

72
    pool1 = tf.compat.v1.layers.max_pooling2d(
73
        inputs=conv1, pool_size=[2, 2], strides=2)
74

75
    if is_training:
76
      pool1 = tf.compat.v1.layers.dropout(inputs=pool1, rate=discard_rate)
77

78
    conv2 = tf.compat.v1.layers.conv2d(
79
        inputs=pool1,
80
        filters=32,
81
        kernel_size=[5, 5],
82
        padding='same',
83
        activation=tf.nn.relu,
84
        name='conv2',
85
        reuse=reuse,
86
    )
87

88
    pool2 = tf.compat.v1.layers.max_pooling2d(
89
        inputs=conv2, pool_size=[2, 2], strides=2)
90

91
    if is_training:
92
      pool2 = tf.compat.v1.layers.dropout(inputs=pool2, rate=discard_rate)
93

94
    pool2_flat = tf.reshape(pool2, [-1, 2048])
95
    dense = tf.compat.v1.layers.dense(
96
        inputs=pool2_flat,
97
        units=512,
98
        activation=tf.nn.relu,
99
        name='dense1',
100
        reuse=reuse)
101

102
    if is_training:
103
      dense = tf.compat.v1.layers.dropout(inputs=dense, rate=discard_rate)
104

105
    dense = tf.compat.v1.layers.dense(
106
        inputs=dense,
107
        units=128,
108
        activation=tf.nn.relu,
109
        name='dense2',
110
        reuse=reuse)
111

112
    if is_training:
113
      dense = tf.compat.v1.layers.dropout(inputs=dense, rate=discard_rate)
114

115
    return dense
116

117
  is_training = mode == tf_estimator.ModeKeys.TRAIN
118
  network_output = build_network(features, is_training=is_training)
119
  return network_output
120

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.