# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ColTran core.

Core autoregressive component of the colorization transformer based on
the AxialTransformer with conditional self-attention layers.

See Section 3 and Section 4.1 of https://openreview.net/pdf?id=5NA1PinlGFu
for more details.
"""
import tensorflow.compat.v2 as tf
from tensorflow.compat.v2.keras import layers
from coltran.models import core
from coltran.models import layers as coltran_layers
from coltran.utils import base_utils


class ColTranCore(tf.keras.Model):
  """Colorization Transformer."""

  def __init__(self, config, **kwargs):
    super(ColTranCore, self).__init__(**kwargs)
    self.config = config

    # 3 bits per channel, 8 colors per channel, a total of 512 colors.
    self.num_symbols_per_channel = 2**3
    self.num_symbols = self.num_symbols_per_channel**3
    self.gray_symbols, self.num_channels = 256, 1
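    # For example: the 8-bit RGB triplet (255, 128, 0) quantizes (>> 5) to
    # the 3-bit channels (7, 4, 0); assuming base_utils.labels_to_bins hashes
    # channels as r*8**2 + g*8 + b, the triplet becomes the single coarse
    # symbol 7*64 + 4*8 + 0 = 480 out of 512.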

    self.enc_cfg = config.encoder
    self.dec_cfg = config.decoder
    self.hidden_size = self.config.get('hidden_size',
                                       self.dec_cfg.hidden_size)

    # stage can be 'decoder' or 'encoder_decoder'.
    # 1. decoder -> loss only from the autoregressive model.
    # 2. encoder_decoder -> loss from both the autoregressive and the
    #    parallel model.
    self.stage = config.get('stage', 'decoder')
    self.is_parallel_loss = 'encoder' in self.stage
    stages = ['decoder', 'encoder_decoder']
    if self.stage not in stages:
      raise ValueError('Expected stage to be in %s, got %s' %
                       (str(stages), self.stage))
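
    # With stage='encoder_decoder', the parallel head's cross-entropy is
    # reported under the 'encoder' key; see metric_keys and loss() below.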

  @property
  def metric_keys(self):
    if self.stage == 'encoder_decoder':
      return ['encoder']
    return []

  def build(self, input_shape):
    # encoder graph
    self.encoder = core.GrayScaleEncoder(self.enc_cfg)
    if self.is_parallel_loss:
      self.parallel_dense = layers.Dense(
          units=self.num_symbols, name='parallel_logits', use_bias=False)

    # decoder graph: outer decoder -> inner decoder -> logits.
    self.pixel_embed_layer = layers.Dense(
        units=self.hidden_size, use_bias=False)
    self.outer_decoder = core.OuterDecoder(self.dec_cfg)
    self.inner_decoder = core.InnerDecoder(self.dec_cfg)
    self.final_dense = layers.Dense(
        units=self.num_symbols, name='auto_logits')
    self.final_norm = layers.LayerNormalization()

  def call(self, inputs, training=True):
    # encodes the grayscale image, shape (B, H, W, 1), into activations of
    # shape (B, H, W, hidden_size).
    gray = tf.image.rgb_to_grayscale(inputs)
    z = self.encoder(gray)

    if self.is_parallel_loss:
      enc_logits = self.parallel_dense(z)
      enc_logits = tf.expand_dims(enc_logits, axis=-2)

    dec_logits = self.decoder(inputs, z, training=training)
    if self.is_parallel_loss:
      return dec_logits, {'encoder_logits': enc_logits}
    return dec_logits, {}

  def decoder(self, inputs, z, training):
    """Decodes grayscale representation and masked colors into logits."""
    # quantize the 8-bit RGB input, shape (H, W, 3), to 3 bits per channel.
    labels = base_utils.convert_bits(inputs, n_bits_in=8, n_bits_out=3)

    # bin each channel triplet: (H, W, 3) with 8 symbols per channel
    # -> (H, W) with 512 symbols.
    labels = base_utils.labels_to_bins(labels, self.num_symbols_per_channel)

    # one-hot: (H, W) with 512 symbols -> (H, W, 512).
    labels = tf.one_hot(labels, depth=self.num_symbols)

    h_dec = self.pixel_embed_layer(labels)
    h_upper = self.outer_decoder((h_dec, z), training=training)
    h_inner = self.inner_decoder((h_dec, h_upper, z), training=training)

    activations = self.final_norm(h_inner)
    logits = self.final_dense(activations)
    return tf.expand_dims(logits, axis=-2)
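
  # Shape walkthrough for decoder() above, assuming batch B and hidden
  # size D: inputs (B, H, W, 3) -> 3-bit labels (B, H, W, 3) -> bins
  # (B, H, W) in [0, 512) -> one-hot (B, H, W, 512) -> h_dec (B, H, W, D)
  # -> logits (B, H, W, 1, 512).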

  def image_loss(self, logits, labels):
    """Cross-entropy between the logits and labels."""
    height, width = labels.shape[1:3]
    logits = tf.squeeze(logits, axis=-2)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss = tf.reduce_mean(loss, axis=0)
    loss = base_utils.nats_to_bits(tf.reduce_sum(loss))
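    # the batch-mean cross-entropy is summed over all H*W positions,
    # converted from nats to bits (division by ln 2) and normalized below,
    # yielding bits per pixel.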
    return loss / (height * width)

  def loss(self, targets, logits, train_config, training, aux_output=None):
    """Converts targets to coarse colors and computes log-likelihood."""
    downsample = train_config.get('downsample', False)
    downsample_res = train_config.get('downsample_res', 64)
    if downsample:
      labels = targets['targets_%d' % downsample_res]
    else:
      labels = targets['targets']

    if aux_output is None:
      aux_output = {}

    # quantize labels.
    labels = base_utils.convert_bits(labels, n_bits_in=8, n_bits_out=3)

    # bin each channel triplet.
    labels = base_utils.labels_to_bins(labels, self.num_symbols_per_channel)

    loss = self.image_loss(logits, labels)
    enc_logits = aux_output.get('encoder_logits')
    if enc_logits is None:
      return loss, {}

    enc_loss = self.image_loss(enc_logits, labels)
    return loss, {'encoder': enc_loss}

  def get_logits(self, inputs_dict, train_config, training):
    is_downsample = train_config.get('downsample', False)
    downsample_res = train_config.get('downsample_res', 64)
    if is_downsample:
      inputs = inputs_dict['targets_%d' % downsample_res]
    else:
      inputs = inputs_dict['targets']
    return self(inputs=inputs, training=training)

  def sample(self, gray_cond, mode='argmax'):
    output = {}

    z_gray = self.encoder(gray_cond, training=False)
    if self.is_parallel_loss:
      z_logits = self.parallel_dense(z_gray)
      parallel_image = tf.argmax(z_logits, axis=-1, output_type=tf.int32)
      parallel_image = self.post_process_image(parallel_image)

      output['parallel'] = parallel_image

    image, proba = self.autoregressive_sample(z_gray=z_gray, mode=mode)
    output['auto_%s' % mode] = image
    output['proba'] = proba
    return output

  def autoregressive_sample(self, z_gray, mode='sample'):
    """Generates an image pixel-by-pixel.

    1. The encoder is run once per image.
    2. The outer decoder is run once per row.
    3. The inner decoder is run once per pixel.

    The context from the encoder and outer decoder conditions the
    inner decoder. The inner decoder then generates a row, one pixel at a
    time.

    After all pixels in a row are generated, the outer decoder is run to
    recompute the context. This conditions the inner decoder, which then
    generates the next row, pixel-by-pixel.

    Args:
      z_gray: encoded grayscale image, shape (B, H, W, hidden_size).
      mode: 'sample' or 'argmax'.

    Returns:
      image: coarse image of shape (B, H, W)
      image_proba: probabilities, shape (B, H, W, 512)
    """
    num_filters = self.hidden_size
    batch_size, height, width = z_gray.shape[:3]

    # channel_cache[i, j] stores the pixel embedding for row i and col j.
    canvas_shape = (batch_size, height, width, num_filters)
    channel_cache = coltran_layers.Cache(canvas_shape=(height, width))
    init_channel = tf.zeros(shape=canvas_shape)
    init_ind = tf.stack([0, 0])
    channel_cache(inputs=(init_channel, init_ind))

    # upper_context[row_ind] stores context from all previously generated rows.
    upper_context = tf.zeros(shape=canvas_shape)

    # row_cache[0, j] stores the pixel embedding for column j of the row
    # under generation. After every row is generated, this is rewritten.
    row_cache = coltran_layers.Cache(canvas_shape=(1, width))
    init_row = tf.zeros(shape=(batch_size, 1, width, num_filters))
    row_cache(inputs=(init_row, init_ind))

    pixel_samples, pixel_probas = [], []

    for row in range(height):
      row_cond_channel = tf.expand_dims(z_gray[:, row], axis=1)
      row_cond_upper = tf.expand_dims(upper_context[:, row], axis=1)
      row_cache.reset()

      gen_row, proba_row = [], []
      for col in range(width):
        inner_input = (row_cache.cache, row_cond_upper, row_cond_channel)
        # computes output activations at col.
        activations = self.inner_decoder(inner_input, row_ind=row,
                                         training=False)

        pixel_sample, pixel_embed, pixel_proba = self.act_logit_sample_embed(
            activations, col, mode=mode)
        proba_row.append(pixel_proba)
        gen_row.append(pixel_sample)

        # row_cache[:, col] = pixel_embed
        row_cache(inputs=(pixel_embed, tf.stack([0, col])))

        # channel_cache[row, col] = pixel_embed
        channel_cache(inputs=(pixel_embed, tf.stack([row, col])))

      gen_row = tf.stack(gen_row, axis=-1)
      pixel_samples.append(gen_row)
      pixel_probas.append(tf.stack(proba_row, axis=1))

      # after a row is generated, recomputes the context for the next row.
      # upper_context[row] = self_attention(channel_cache[:row_index])
      upper_context = self.outer_decoder(
          inputs=(channel_cache.cache, z_gray), training=False)

    image = tf.stack(pixel_samples, axis=1)
    image = self.post_process_image(image)

    image_proba = tf.stack(pixel_probas, axis=1)
    return image, image_proba

  def act_logit_sample_embed(self, activations, col_ind, mode='sample'):
    """Converts activations[col_ind] to the output pixel.

    Activation -> Logit -> Sample -> Embedding.

    Args:
      activations: 4-D Tensor, shape=(batch_size, 1, width, hidden_size)
      col_ind: integer.
      mode: 'sample' or 'argmax'

    Returns:
      pixel_sample: 1-D Tensor, shape=(batch_size,)
      pixel_embed: 4-D Tensor, shape=(batch_size, 1, 1, hidden_size)
      pixel_proba: 2-D Tensor, shape=(batch_size, 512)
    """
    batch_size = activations.shape[0]
    pixel_activation = tf.expand_dims(activations[:, :, col_ind], axis=-2)
    pixel_logits = self.final_dense(self.final_norm(pixel_activation))
    pixel_logits = tf.squeeze(pixel_logits, axis=[1, 2])
    pixel_proba = tf.nn.softmax(pixel_logits, axis=-1)

    if mode == 'sample':
      pixel_sample = tf.random.categorical(
          pixel_logits, num_samples=1, dtype=tf.int32)
      pixel_sample = tf.squeeze(pixel_sample, axis=-1)
    elif mode == 'argmax':
      pixel_sample = tf.argmax(pixel_logits, axis=-1, output_type=tf.int32)

    pixel_sample_expand = tf.reshape(pixel_sample, [batch_size, 1, 1])
    pixel_one_hot = tf.one_hot(pixel_sample_expand, depth=self.num_symbols)
    pixel_embed = self.pixel_embed_layer(pixel_one_hot)
    return pixel_sample, pixel_embed, pixel_proba

  def post_process_image(self, image):
    """Maps a (B, H, W) grid of 512 coarse symbols to a coarse RGB image."""
    image = base_utils.bins_to_labels(
        image, num_symbols_per_channel=self.num_symbols_per_channel)
    image = base_utils.convert_bits(image, n_bits_in=3, n_bits_out=8)
    image = tf.cast(image, dtype=tf.uint8)
    return image
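

# A minimal usage sketch, assuming `config` is built like the coltran
# configs (e.g. under coltran/configs) and supplies the encoder/decoder
# sub-configs read in __init__:
#
#   model = ColTranCore(config)
#   logits, _ = model(rgb, training=False)   # rgb: (B, H, W, 3) in [0, 255]
#   out = model.sample(gray, mode='argmax')  # gray: grayscale encoder input,
#                                            # as produced in call() above
#   # out['auto_argmax'] : coarse (B, H, W, 3) uint8 image
#   # out['proba']       : (B, H, W, 512) per-pixel distribution
#   # out['parallel']    : parallel prediction, present only when
#   #                      stage == 'encoder_decoder'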