google-research
105 строк · 3.7 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Tests for colorizer."""
17from ml_collections import ConfigDict
18import numpy as np
19import tensorflow as tf
20from coltran.models import colorizer
21
22
class ColTranCoreTest(tf.test.TestCase):
  """Shape and batch-consistency tests for `colorizer.ColTranCore`."""

  def get_config(self, encoder_net='attention'):
    """Returns a small ColTranCore config used by the tests below.

    Args:
      encoder_net: value stored in `config.encoder_net`.

    Returns:
      A `ConfigDict` with top-level settings plus `encoder` and `decoder`
      sub-configs. `stage` defaults to 'decoder'; tests override it as needed.
    """
    config = ConfigDict()
    config.image_bit_depth = 3
    config.encoder_1x1 = True
    config.resolution = [64, 64]
    config.batch_size = 2
    config.encoder_net = encoder_net
    config.hidden_size = 128
    config.stage = 'decoder'

    config.encoder = ConfigDict()
    config.encoder.dropout = 0.0
    config.encoder.ff_size = 128
    config.encoder.hidden_size = 128
    config.encoder.num_heads = 1
    config.encoder.num_encoder_layers = 1

    config.decoder = ConfigDict()
    config.decoder.ff_size = 128
    config.decoder.hidden_size = 128
    config.decoder.num_heads = 1
    config.decoder.num_outer_layers = 1
    config.decoder.num_inner_layers = 1
    config.decoder.resolution = [64, 64]
    config.decoder.dropout = 0.1
    config.decoder.cond_ln = True
    # cond_q was previously assigned twice; a single assignment suffices.
    config.decoder.cond_q = True
    config.decoder.cond_k = True
    config.decoder.cond_v = True
    config.decoder.cond_scale = True
    config.decoder.cond_mlp = 'affine'
    return config

  def test_transformer_attention_encoder(self):
    """Checks output shapes and that sampling does not depend on batch size."""
    config = self.get_config(encoder_net='attention')
    config.stage = 'encoder_decoder'
    transformer = colorizer.ColTranCore(config=config)
    images = tf.random.uniform(shape=(2, 2, 2, 3), minval=0,
                               maxval=256, dtype=tf.int32)
    logits = transformer(inputs=images, training=True)[0]
    # NOTE(review): 512 presumably equals (2**image_bit_depth)**3 colour bins
    # for image_bit_depth=3 -- confirm against colorizer.
    self.assertEqual(logits.shape, (2, 2, 2, 1, 512))

    # Sample the whole batch (batch-size=2) in one call.
    gray = tf.image.rgb_to_grayscale(images)
    output = transformer.sample(gray, mode='argmax')
    output_np = output['auto_argmax'].numpy()
    proba_np = output['proba'].numpy()
    self.assertEqual(output_np.shape, (2, 2, 2, 3))
    self.assertEqual(proba_np.shape, (2, 2, 2, 512))

    # Sample each example on its own (batch-size=1), then stitch the results
    # back together; they must agree with the batched call above.
    output_np_bs_1, proba_np_bs_1 = [], []
    for batch_ind in range(gray.shape[0]):
      curr_gray = tf.expand_dims(gray[batch_ind], axis=0)
      curr_out = transformer.sample(curr_gray, mode='argmax')
      output_np_bs_1.append(curr_out['auto_argmax'].numpy())
      proba_np_bs_1.append(curr_out['proba'].numpy())
    output_np_bs_1 = np.concatenate(output_np_bs_1, axis=0)
    proba_np_bs_1 = np.concatenate(proba_np_bs_1, axis=0)
    # Same tolerances as numpy.allclose's defaults, but with a useful
    # failure message instead of a bare assertTrue.
    self.assertAllClose(output_np, output_np_bs_1, rtol=1e-5, atol=1e-8)
    self.assertAllClose(proba_np, proba_np_bs_1, rtol=1e-5, atol=1e-8)

  def test_transformer_encoder_decoder(self):
    """Checks decoder logits and auxiliary encoder logits at 64x64."""
    config = self.get_config()
    config.stage = 'encoder_decoder'

    transformer = colorizer.ColTranCore(config=config)
    images = tf.random.uniform(shape=(1, 64, 64, 3), minval=0,
                               maxval=256, dtype=tf.int32)
    # The second return value is indexed by key below, so it is a mapping.
    logits, aux_output = transformer(inputs=images, training=True)
    enc_logits = aux_output['encoder_logits']
    self.assertEqual(enc_logits.shape, (1, 64, 64, 1, 512))
    self.assertEqual(logits.shape, (1, 64, 64, 1, 512))
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()