google-research

Форк
0
66 строк · 2.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
r"""Model definition."""
17

18
import tensorflow as tf
19

20

21
def build_model(image_size, bias_last=True, num_classes=1, squeeze=True):
22
  """Builds model."""
23

24
  input_shape = (image_size, image_size, 3)
25
  image = tf.keras.Input(shape=input_shape, name='input_image')
26
  training = tf.keras.Input(shape=[], name='training')
27

28
  x = tf.keras.layers.Conv2D(
29
      128, (3, 3), strides=(1, 1), padding='valid', activation=None)(
30
          image)
31
  x = tf.keras.layers.BatchNormalization()(x, training)
32
  x = tf.keras.layers.ReLU()(x)
33
  x = tf.keras.layers.Conv2D(
34
      128, (3, 3), strides=(2, 2), padding='valid', activation=None)(
35
          x)
36
  x = tf.keras.layers.BatchNormalization()(x, training)
37
  x = tf.keras.layers.ReLU()(x)
38
  x = tf.keras.layers.Conv2D(
39
      256, (3, 3), strides=(2, 2), padding='valid', activation=None)(
40
          x)
41
  x = tf.keras.layers.BatchNormalization()(x, training)
42
  x = tf.keras.layers.ReLU()(x)
43
  x = tf.keras.layers.Conv2D(
44
      256, (3, 3), strides=(2, 2), padding='valid', activation=None)(
45
          x)
46
  x = tf.keras.layers.BatchNormalization()(x, training)
47
  x = tf.keras.layers.ReLU()(x)
48
  x = tf.keras.layers.Conv2D(
49
      512, (1, 1), strides=(1, 1), padding='valid', activation=None)(
50
          x)
51
  x = tf.keras.layers.BatchNormalization()(x, training)
52
  x = tf.keras.layers.ReLU()(x)
53
  # x = tf.keras.layers.Conv2D(64, (2, 2), padding='valid')(x)
54
  x = tf.keras.layers.Flatten()(x)
55

56
  last_layer_fc = tf.keras.layers.Dense(num_classes, use_bias=bias_last)
57

58
  if squeeze:
59
    x = tf.squeeze(last_layer_fc(x))
60
  else:
61
    x = last_layer_fc(x)
62

63
  model = tf.keras.models.Model(
64
      inputs=[image, training], outputs=x, name='model')
65
  model.summary()
66
  return model
67

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.