google-research
128 строк · 4.2 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Tests for coltran.models.core."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import itertools
23from absl import logging
24from absl.testing import parameterized
25from ml_collections import ConfigDict
26import numpy as np
27import tensorflow as tf
28from coltran.models import core
29
30
def get_num_variables(model):
  """Return the total number of scalar parameters held by `model`.

  Args:
    model: any object exposing a `variables` sequence whose elements have a
      `shape` (e.g. a Keras layer after it has created its weights).

  Returns:
    A Python int: the sum, over all variables, of the product of their
    dimensions. Returns 0 when `model.variables` is empty.
  """
  # The built-in sum starts from int 0, so an empty variable list yields 0
  # rather than np.sum's float 0.0. The int() call also turns np.prod's float
  # result for scalar shapes (np.prod(()) == 1.0) into an int.
  return sum(int(np.prod(variable.shape)) for variable in model.variables)
34
35
# Cartesian product of the conditioning options swept by the decoder tests,
# in the positional order the test methods receive them:
# (cond_mlp, cond_ln, cond_att_q, cond_att_scale).
# Materialized as a list so it can be iterated more than once; a bare
# itertools.product iterator would be empty after its first traversal.
cond_hparams = list(itertools.product(["shift", "affine"],
                                      [True, False],
                                      [True, False],
                                      [True, False]))

# One (name_suffix, *hparams) tuple per combination, the form consumed by
# parameterized.named_parameters. The suffix joins the values with "_",
# e.g. "shift_True_False_True".
new_hparams = [("_".join(str(hparam) for hparam in hparams), *hparams)
               for hparams in cond_hparams]
45
46
class ColTranComponentsTest(tf.test.TestCase, parameterized.TestCase):
  """Shape and masking tests for the ColTran core encoder/decoder layers."""

  def get_config(self):
    """Return the small ConfigDict shared by every test in this class.

    Individual tests override the conditioning options they sweep (see
    `_get_cond_config`).
    """
    config = ConfigDict()
    config.hidden_size = 256
    config.ff_size = 256
    config.image_bit_depth = 5
    config.num_symbols = 32
    config.num_heads = 4
    config.resolution = [8, 8]
    config.num_outer_layers = 1
    config.num_inner_layers = 3
    config.num_encoder_layers = 1
    config.batch_size = 2
    config.skip = True

    config.cond_mlp = "affine_dense"
    config.cond_mlp_act = "identity"

    config.cond_ln = True
    config.cond_ln_act = "tanh"
    config.cond_ln_seq = "cs"
    config.cond_ln_sp_ave = "learnable"
    config.cond_ln_init = "glorot_uniform"

    config.cond_att_act = "identity"
    config.cond_att_scale = True
    config.cond_att_k = True
    config.cond_att_q = True
    config.cond_att_v = True
    return config

  def _get_cond_config(self, cond_mlp, cond_ln, cond_att_q, cond_att_scale):
    """Return `get_config()` with the four swept conditioning options set.

    Args:
      cond_mlp: value for `config.cond_mlp` ("shift" or "affine" in the sweep).
      cond_ln: value for `config.cond_ln`.
      cond_att_q: value for `config.cond_att_q`.
      cond_att_scale: value for `config.cond_att_scale`.
    """
    config = self.get_config()
    config.cond_mlp = cond_mlp
    config.cond_ln = cond_ln
    config.cond_att_q = cond_att_q
    config.cond_att_scale = cond_att_scale
    return config

  def test_grayscale_encoder(self):
    """GrayScaleEncoder on a grayscale batch yields a (2, 32, 32, 256) output."""
    config = self.get_config()
    inputs = tf.random.uniform(shape=(2, 32, 32, 3), minval=0, maxval=256,
                               dtype=tf.int32)
    gray = tf.image.rgb_to_grayscale(inputs)
    encoder = core.GrayScaleEncoder(config)
    output = encoder(gray)
    self.assertEqual(output.shape, (2, 32, 32, 256))

  @parameterized.named_parameters(*new_hparams)
  def test_inner_decoder(self, cond_mlp, cond_ln, cond_att_q, cond_att_scale):
    """InnerDecoder preserves the (2, 8, 8, 256) shape for every cond setting."""
    embeddings = tf.random.uniform(shape=(2, 8, 8, 256))
    channel_context = tf.random.uniform(shape=(2, 8, 8, 256))
    upper_context = tf.random.uniform(shape=(2, 8, 8, 256))
    config = self._get_cond_config(
        cond_mlp, cond_ln, cond_att_q, cond_att_scale)

    model = core.InnerDecoder(config=config)
    output = model(inputs=(embeddings, upper_context, channel_context))
    num_vars = get_num_variables(model)
    logging.info(num_vars)
    self.assertEqual(output.shape, (2, 8, 8, 256))

  @parameterized.named_parameters(*new_hparams)
  def test_outer_decoder(self, cond_mlp, cond_ln, cond_att_q, cond_att_scale):
    """OuterDecoder keeps the input shape and zeroes the first-row context."""
    embeddings = tf.random.uniform(shape=(2, 8, 8, 256))
    channel_context = tf.random.uniform(shape=(2, 8, 8, 256))
    config = self._get_cond_config(
        cond_mlp, cond_ln, cond_att_q, cond_att_scale)

    model = core.OuterDecoder(config=config)
    upper_context = model(inputs=(embeddings, channel_context))
    # Count variables only after the forward pass, as test_inner_decoder does.
    # Keras layers typically create their weights lazily on the first call, so
    # counting beforehand would presumably report 0.
    num_vars = get_num_variables(model)
    logging.info(num_vars)
    upper_context_np = upper_context.numpy()

    self.assertEqual(upper_context_np.shape, (2, 8, 8, 256))
    # the first row slice should have zero context since both the present
    # and future are effectively masked.
    self.assertTrue(np.allclose(upper_context_np[:, 0], 0.0))
125
126
if __name__ == "__main__":
  # Only when executed as a script: hand control to TensorFlow's test entry
  # point so the TestCase classes in this module are run.
  tf.test.main()
129