google-research

Форк
0
128 строк · 4.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for core."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import itertools
23
from absl import logging
24
from absl.testing import parameterized
25
from ml_collections import ConfigDict
26
import numpy as np
27
import tensorflow as tf
28
from coltran.models import core
29

30

31
def get_num_variables(model):
  """Returns the total number of scalar elements across `model.variables`.

  Args:
    model: Any object exposing a `variables` iterable whose items each have a
      `shape` attribute.

  Returns:
    The summed element count as a NumPy int64; 0 when there are no variables.
  """
  # Accumulate in int64 explicitly: np.prod over an empty (scalar) shape and
  # np.sum over an empty list both default to float, which would turn the
  # count into a float for scalar variables or variable-free models.
  var_sizes = [np.prod(variable.shape, dtype=np.int64)
               for variable in model.variables]
  return np.sum(var_sizes, dtype=np.int64)
34

35

36
# Every combination of the conditioning options swept by the decoder tests,
# in the order (cond_mlp, cond_ln, cond_att_q, cond_att_scale). Materialized
# as a list so it can be iterated more than once.
cond_hparams = list(itertools.product(["shift", "affine"],
                                      [True, False],
                                      [True, False],
                                      [True, False]))

# Prefix each combination with a readable, unique name such as
# "shift_True_False_True"; the tuples are later unpacked into
# `parameterized.named_parameters`, which takes the name as the first element.
new_hparams = [
    ("_".join(str(hparam) for hparam in hparams), *hparams)
    for hparams in cond_hparams
]
45

46

47
class ColTranComponentsTest(tf.test.TestCase, parameterized.TestCase):
  """Shape and masking checks for the ColTran components in `core`."""

  def get_config(self):
    """Returns the baseline `ConfigDict` shared by all tests in this class."""
    config = ConfigDict()
    config.hidden_size = 256
    config.ff_size = 256
    config.image_bit_depth = 5
    config.num_symbols = 32
    config.num_heads = 4
    config.resolution = [8, 8]
    config.num_outer_layers = 1
    config.num_inner_layers = 3
    config.num_encoder_layers = 1
    config.batch_size = 2
    config.skip = True

    config.cond_mlp = "affine_dense"
    config.cond_mlp_act = "identity"

    config.cond_ln = True
    config.cond_ln_act = "tanh"
    config.cond_ln_seq = "cs"
    config.cond_ln_sp_ave = "learnable"
    config.cond_ln_init = "glorot_uniform"

    config.cond_att_act = "identity"
    config.cond_att_scale = True
    config.cond_att_k = True
    config.cond_att_q = True
    config.cond_att_v = True
    return config

  def _get_cond_config(self, cond_mlp, cond_ln, cond_att_q, cond_att_scale):
    """Returns the baseline config with the swept conditioning options set."""
    config = self.get_config()
    config.cond_mlp = cond_mlp
    config.cond_ln = cond_ln
    config.cond_att_q = cond_att_q
    config.cond_att_scale = cond_att_scale
    return config

  def test_grayscale_encoder(self):
    """Checks the output shape of `GrayScaleEncoder` on a grayscale batch."""
    config = self.get_config()
    inputs = tf.random.uniform(shape=(2, 32, 32, 3), minval=0, maxval=256,
                               dtype=tf.int32)
    gray = tf.image.rgb_to_grayscale(inputs)
    encoder = core.GrayScaleEncoder(config)
    output = encoder(gray)
    self.assertEqual(output.shape, (2, 32, 32, 256))

  @parameterized.named_parameters(*new_hparams)
  def test_inner_decoder(self, cond_mlp, cond_ln, cond_att_q, cond_att_scale):
    """Checks the `InnerDecoder` output shape for each conditioning combo."""
    embeddings = tf.random.uniform(shape=(2, 8, 8, 256))
    channel_context = tf.random.uniform(shape=(2, 8, 8, 256))
    upper_context = tf.random.uniform(shape=(2, 8, 8, 256))
    config = self._get_cond_config(
        cond_mlp, cond_ln, cond_att_q, cond_att_scale)

    model = core.InnerDecoder(config=config)
    output = model(inputs=(embeddings, upper_context, channel_context))
    num_vars = get_num_variables(model)
    logging.info(num_vars)
    self.assertEqual(output.shape, (2, 8, 8, 256))

  @parameterized.named_parameters(*new_hparams)
  def test_outer_decoder(self, cond_mlp, cond_ln, cond_att_q, cond_att_scale):
    """Checks `OuterDecoder` output shape and that its first row is zero."""
    embeddings = tf.random.uniform(shape=(2, 8, 8, 256))
    channel_context = tf.random.uniform(shape=(2, 8, 8, 256))
    config = self._get_cond_config(
        cond_mlp, cond_ln, cond_att_q, cond_att_scale)

    model = core.OuterDecoder(config=config)
    upper_context = model(inputs=(embeddings, channel_context))
    # Count variables only after the first call, as test_inner_decoder does.
    # NOTE(review): Keras layers usually create their weights lazily on the
    # first call, so counting earlier would presumably log 0 -- confirm
    # against core.OuterDecoder.
    num_vars = get_num_variables(model)
    logging.info(num_vars)
    upper_context_np = upper_context.numpy()

    # the first row slice should have zero context since both the present
    # and future are effectively masked.
    # Tolerances mirror the np.allclose defaults; assertAllClose reports the
    # mismatching values on failure instead of a bare "False is not true".
    first_row = upper_context_np[:, 0]
    self.assertAllClose(first_row, np.zeros_like(first_row),
                        rtol=1e-5, atol=1e-8)
    self.assertEqual(upper_context_np.shape, (2, 8, 8, 256))
125

126

127
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  tf.test.main()
129

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.