google-research

Форк
0
/
colorizer_test.py 
105 строк · 3.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for colorizer."""
17
from ml_collections import ConfigDict
18
import numpy as np
19
import tensorflow as tf
20
from coltran.models import colorizer
21

22

23
class ColTranCoreTest(tf.test.TestCase):
24

25
  def get_config(self, encoder_net='attention'):
26
    config = ConfigDict()
27
    config.image_bit_depth = 3
28
    config.encoder_1x1 = True
29
    config.resolution = [64, 64]
30
    config.batch_size = 2
31
    config.encoder_net = encoder_net
32
    config.hidden_size = 128
33
    config.stage = 'decoder'
34

35
    config.encoder = ConfigDict()
36
    config.encoder.dropout = 0.0
37
    config.encoder.ff_size = 128
38
    config.encoder.hidden_size = 128
39
    config.encoder.num_heads = 1
40
    config.encoder.num_encoder_layers = 1
41

42
    config.decoder = ConfigDict()
43
    config.decoder.ff_size = 128
44
    config.decoder.hidden_size = 128
45
    config.decoder.num_heads = 1
46
    config.decoder.num_outer_layers = 1
47
    config.decoder.num_inner_layers = 1
48
    config.decoder.resolution = [64, 64]
49
    config.decoder.dropout = 0.1
50
    config.decoder.cond_ln = True
51
    config.decoder.cond_q = True
52
    config.decoder.cond_k = True
53
    config.decoder.cond_v = True
54
    config.decoder.cond_q = True
55
    config.decoder.cond_scale = True
56
    config.decoder.cond_mlp = 'affine'
57
    return config
58

59
  def test_transformer_attention_encoder(self):
60
    config = self.get_config(encoder_net='attention')
61
    config.stage = 'encoder_decoder'
62
    transformer = colorizer.ColTranCore(config=config)
63
    images = tf.random.uniform(shape=(2, 2, 2, 3), minval=0,
64
                               maxval=256, dtype=tf.int32)
65
    logits = transformer(inputs=images, training=True)[0]
66
    self.assertEqual(logits.shape, (2, 2, 2, 1, 512))
67

68
    # batch-size=2
69
    gray = tf.image.rgb_to_grayscale(images)
70
    output = transformer.sample(gray, mode='argmax')
71
    output_np = output['auto_argmax'].numpy()
72
    proba_np = output['proba'].numpy()
73
    self.assertEqual(output_np.shape, (2, 2, 2, 3))
74
    self.assertEqual(proba_np.shape, (2, 2, 2, 512))
75
    # logging.info(output_np[0, ..., 0])
76

77
    # batch-size=1
78
    output_np_bs_1, proba_np_bs_1 = [], []
79
    for batch_ind in [0, 1]:
80
      curr_gray = tf.expand_dims(gray[batch_ind], axis=0)
81
      curr_out = transformer.sample(curr_gray, mode='argmax')
82
      curr_out_np = curr_out['auto_argmax'].numpy()
83
      curr_proba_np = curr_out['proba'].numpy()
84
      output_np_bs_1.append(curr_out_np)
85
      proba_np_bs_1.append(curr_proba_np)
86
    output_np_bs_1 = np.concatenate(output_np_bs_1, axis=0)
87
    proba_np_bs_1 = np.concatenate(proba_np_bs_1, axis=0)
88
    self.assertTrue(np.allclose(output_np, output_np_bs_1))
89
    self.assertTrue(np.allclose(proba_np, proba_np_bs_1))
90

91
  def test_transformer_encoder_decoder(self):
92
    config = self.get_config()
93
    config.stage = 'encoder_decoder'
94

95
    transformer = colorizer.ColTranCore(config=config)
96
    images = tf.random.uniform(shape=(1, 64, 64, 3), minval=0,
97
                               maxval=256, dtype=tf.int32)
98
    logits, enc_logits = transformer(inputs=images, training=True)
99
    enc_logits = enc_logits['encoder_logits']
100
    self.assertEqual(enc_logits.shape, (1, 64, 64, 1, 512))
101
    self.assertEqual(logits.shape, (1, 64, 64, 1, 512))
102

103

104
if __name__ == '__main__':
105
  tf.test.main()
106

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.