google-research

Форк
0
89 строк · 3.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests the `u_net` module."""
17
import tensorflow as tf
18

19
from flare_removal.python import u_net
20

21

22
class UNetTest(tf.test.TestCase):
23

24
  def test_zero_scale(self):
25
    model = u_net.get_model(
26
        input_shape=(128, 128, 1), scales=0, bottleneck_depth=32)
27
    model.summary()
28

29
    input_layer = model.get_layer('input')
30
    bottleneck_conv1 = model.get_layer('bottleneck_conv1')
31
    bottleneck_conv2 = model.get_layer('bottleneck_conv2')
32
    output_layer = model.get_layer('output')
33
    self.assertIs(input_layer.output, bottleneck_conv1.input)
34
    self.assertIs(bottleneck_conv1.output, bottleneck_conv2.input)
35
    self.assertIs(bottleneck_conv2.output, output_layer.input)
36
    self.assertAllEqual(model.input_shape, [None, 128, 128, 1])
37
    self.assertAllEqual(bottleneck_conv1.output_shape, [None, 128, 128, 32])
38
    self.assertAllEqual(bottleneck_conv2.output_shape, [None, 128, 128, 32])
39
    self.assertAllEqual(model.output_shape, [None, 128, 128, 1])
40

41
  def test_one_scale(self):
42
    model = u_net.get_model(
43
        input_shape=(64, 64, 3), scales=1, bottleneck_depth=128)
44
    model.summary()
45

46
    # Downscaling arm.
47
    input_layer = model.get_layer('input')
48
    down_conv1 = model.get_layer('down64_conv1')
49
    down_conv2 = model.get_layer('down64_conv2')
50
    down_pool = model.get_layer('down64_pool')
51
    bottleneck_conv1 = model.get_layer('bottleneck_conv1')
52
    self.assertIs(input_layer.output, down_conv1.input)
53
    self.assertIs(down_conv1.output, down_conv2.input)
54
    self.assertIs(down_conv2.output, down_pool.input)
55
    self.assertIs(down_pool.output, bottleneck_conv1.input)
56
    self.assertAllEqual(model.input_shape, [None, 64, 64, 3])
57
    self.assertAllEqual(down_conv1.output_shape, [None, 64, 64, 64])
58
    self.assertAllEqual(down_conv2.output_shape, [None, 64, 64, 64])
59
    self.assertAllEqual(down_pool.output_shape, [None, 32, 32, 64])
60
    self.assertAllEqual(bottleneck_conv1.output_shape, [None, 32, 32, 128])
61

62
    # Upscaling arm.
63
    bottleneck_conv2 = model.get_layer('bottleneck_conv2')
64
    up_2x = model.get_layer('up64_2x')
65
    up_2xconv = model.get_layer('up64_2xconv')
66
    up_concat = model.get_layer('up64_concat')
67
    up_conv1 = model.get_layer('up64_conv1')
68
    up_conv2 = model.get_layer('up64_conv2')
69
    output_layer = model.get_layer('output')
70
    self.assertIs(bottleneck_conv2.output, up_2x.input)
71
    self.assertIs(up_2x.output, up_2xconv.input)
72
    self.assertIs(up_2xconv.output, up_concat.input[0])
73
    self.assertIs(up_concat.output, up_conv1.input)
74
    self.assertIs(up_conv1.output, up_conv2.input)
75
    self.assertIs(up_conv2.output, output_layer.input)
76
    self.assertAllEqual(bottleneck_conv2.output_shape, [None, 32, 32, 128])
77
    self.assertAllEqual(up_2x.output_shape, [None, 64, 64, 128])
78
    self.assertAllEqual(up_2xconv.output_shape, [None, 64, 64, 64])
79
    self.assertAllEqual(up_concat.output_shape, [None, 64, 64, 128])
80
    self.assertAllEqual(up_conv1.output_shape, [None, 64, 64, 64])
81
    self.assertAllEqual(up_conv2.output_shape, [None, 64, 64, 64])
82
    self.assertAllEqual(output_layer.output_shape, [None, 64, 64, 3])
83

84
    # Skip connection.
85
    self.assertIs(down_conv2.output, up_concat.input[1])
86

87

88
if __name__ == '__main__':
89
  tf.test.main()
90

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.