google-research
66 строк · 2.3 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16r"""Model definition."""
17
18import tensorflow as tf
19
20
def build_model(image_size, bias_last=True, num_classes=1, squeeze=True):
  """Builds a small convolutional network as a functional Keras model.

  The network is five Conv2D -> BatchNormalization -> ReLU stages (all with
  'valid' padding and no Conv2D activation), followed by Flatten and a single
  Dense output layer.

  Args:
    image_size: Height and width of the square input image; the image input
      has shape (image_size, image_size, 3).
    bias_last: Whether the final Dense layer uses a bias term.
    num_classes: Number of units in the final Dense layer.
    squeeze: If True and `num_classes == 1`, remove the trailing size-1
      output axis so the output is (batch,) instead of (batch, 1). The batch
      axis is never squeezed.

  Returns:
    A `tf.keras.Model` named 'model' whose inputs are `[image, training]`.
    The model summary is printed as a side effect.
  """
  input_shape = (image_size, image_size, 3)
  image = tf.keras.Input(shape=input_shape, name='input_image')
  # NOTE(review): this tensor is passed as the `training` argument of every
  # BatchNormalization layer. It is declared with shape=[] (one value per
  # example) and Keras' default float dtype; confirm the Keras version in use
  # accepts a tensor of this form as the training flag.
  training = tf.keras.Input(shape=[], name='training')

  # (filters, kernel_size, strides) for each Conv2D -> BN -> ReLU stage,
  # applied in this order.
  conv_specs = (
      (128, (3, 3), (1, 1)),
      (128, (3, 3), (2, 2)),
      (256, (3, 3), (2, 2)),
      (256, (3, 3), (2, 2)),
      (512, (1, 1), (1, 1)),
  )
  x = image
  for filters, kernel_size, strides in conv_specs:
    x = tf.keras.layers.Conv2D(
        filters, kernel_size, strides=strides, padding='valid',
        activation=None)(x)
    x = tf.keras.layers.BatchNormalization()(x, training)
    x = tf.keras.layers.ReLU()(x)
  x = tf.keras.layers.Flatten()(x)

  last_layer_fc = tf.keras.layers.Dense(num_classes, use_bias=bias_last)
  x = last_layer_fc(x)

  # Squeeze only the class axis. A bare tf.squeeze() removes *every* size-1
  # axis, so a batch of one example would collapse to a scalar, and with an
  # unknown (None) batch dimension TF cannot infer a static output shape.
  if squeeze and num_classes == 1:
    x = tf.squeeze(x, axis=-1)

  model = tf.keras.models.Model(
      inputs=[image, training], outputs=x, name='model')
  model.summary()
  return model
67