# google-research (89 lines, 3.7 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests the `u_net` module."""
import tensorflow as tf

from flare_removal.python import u_net


class UNetTest(tf.test.TestCase):
  """Checks the layer graph and tensor shapes built by `u_net.get_model`."""

  def test_zero_scale(self):
    """With `scales=0` the model is input -> two bottleneck convs -> output.

    Also checks that the spatial size is preserved end to end, that both
    bottleneck convolutions emit `bottleneck_depth` channels, and that the
    output has the same shape as the input.
    """
    model = u_net.get_model(
        input_shape=(128, 128, 1), scales=0, bottleneck_depth=32)
    model.summary()

    input_layer = model.get_layer('input')
    bottleneck_conv1 = model.get_layer('bottleneck_conv1')
    bottleneck_conv2 = model.get_layer('bottleneck_conv2')
    output_layer = model.get_layer('output')

    # Connectivity: each layer consumes exactly the previous layer's tensor.
    self.assertIs(input_layer.output, bottleneck_conv1.input)
    self.assertIs(bottleneck_conv1.output, bottleneck_conv2.input)
    self.assertIs(bottleneck_conv2.output, output_layer.input)

    # Shapes: no downscaling happens when there are zero scales.
    self.assertAllEqual(model.input_shape, [None, 128, 128, 1])
    self.assertAllEqual(bottleneck_conv1.output_shape, [None, 128, 128, 32])
    self.assertAllEqual(bottleneck_conv2.output_shape, [None, 128, 128, 32])
    self.assertAllEqual(model.output_shape, [None, 128, 128, 1])

  def test_one_scale(self):
    """With `scales=1` the model has one down stage, a bottleneck, one up stage.

    The down and up stages at the 64x64 resolution are joined by a skip
    connection into the up stage's concatenation layer.
    """
    model = u_net.get_model(
        input_shape=(64, 64, 3), scales=1, bottleneck_depth=128)
    model.summary()

    # Downscaling arm: two convs at full resolution, then a pool that halves
    # the height and width before the bottleneck.
    input_layer = model.get_layer('input')
    down_conv1 = model.get_layer('down64_conv1')
    down_conv2 = model.get_layer('down64_conv2')
    down_pool = model.get_layer('down64_pool')
    bottleneck_conv1 = model.get_layer('bottleneck_conv1')
    self.assertIs(input_layer.output, down_conv1.input)
    self.assertIs(down_conv1.output, down_conv2.input)
    self.assertIs(down_conv2.output, down_pool.input)
    self.assertIs(down_pool.output, bottleneck_conv1.input)
    self.assertAllEqual(model.input_shape, [None, 64, 64, 3])
    self.assertAllEqual(down_conv1.output_shape, [None, 64, 64, 64])
    self.assertAllEqual(down_conv2.output_shape, [None, 64, 64, 64])
    self.assertAllEqual(down_pool.output_shape, [None, 32, 32, 64])
    self.assertAllEqual(bottleneck_conv1.output_shape, [None, 32, 32, 128])

    # Upscaling arm: 2x upsample, a conv back down to 64 channels, concat with
    # the skip tensor, then two convs before the output layer.
    bottleneck_conv2 = model.get_layer('bottleneck_conv2')
    up_2x = model.get_layer('up64_2x')
    up_2xconv = model.get_layer('up64_2xconv')
    up_concat = model.get_layer('up64_concat')
    up_conv1 = model.get_layer('up64_conv1')
    up_conv2 = model.get_layer('up64_conv2')
    output_layer = model.get_layer('output')
    self.assertIs(bottleneck_conv2.output, up_2x.input)
    self.assertIs(up_2x.output, up_2xconv.input)
    # The upsampled path is the concat's first input; the skip is the second.
    self.assertIs(up_2xconv.output, up_concat.input[0])
    self.assertIs(up_concat.output, up_conv1.input)
    self.assertIs(up_conv1.output, up_conv2.input)
    self.assertIs(up_conv2.output, output_layer.input)
    self.assertAllEqual(bottleneck_conv2.output_shape, [None, 32, 32, 128])
    self.assertAllEqual(up_2x.output_shape, [None, 64, 64, 128])
    self.assertAllEqual(up_2xconv.output_shape, [None, 64, 64, 64])
    # 64 upsampled channels + 64 skip channels.
    self.assertAllEqual(up_concat.output_shape, [None, 64, 64, 128])
    self.assertAllEqual(up_conv1.output_shape, [None, 64, 64, 64])
    self.assertAllEqual(up_conv2.output_shape, [None, 64, 64, 64])
    self.assertAllEqual(output_layer.output_shape, [None, 64, 64, 3])

    # Skip connection: the last full-resolution down conv feeds the concat.
    self.assertIs(down_conv2.output, up_concat.input[1])


if __name__ == '__main__':
  # Delegate to TensorFlow's test runner, which discovers `tf.test.TestCase`
  # subclasses in this module.
  tf.test.main()