google-research
298 строк · 11.0 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""ColTran core.

Core autoregressive component of the colorization transformer based on
the AxialTransformer with conditional self-attention layers.

See Section 3 and Section 4.1 of https://openreview.net/pdf?id=5NA1PinlGFu
for more details.
"""
24import tensorflow.compat.v2 as tf
25from tensorflow.compat.v2.keras import layers
26from coltran.models import core
27from coltran.models import layers as coltran_layers
28from coltran.utils import base_utils
29
30
class ColTranCore(tf.keras.Model):
  """Colorization Transformer.

  Wraps a grayscale encoder and an outer/inner decoder pair. During training
  the model maps an RGB image to logits over 512 coarse colors (3 bits per
  channel); during sampling it generates the coarse colors pixel-by-pixel
  conditioned on the encoded grayscale image.
  """

  # Pixel-decoding strategies understood by the sampling methods.
  _SAMPLING_MODES = ('sample', 'argmax')

  def __init__(self, config, **kwargs):
    """Stores the configuration and validates the training stage.

    Args:
      config: configuration object exposing `encoder`, `decoder` and a
        dict-like `get`. Optional keys read here: 'hidden_size' (defaults to
        `config.decoder.hidden_size`) and 'stage' (defaults to 'decoder').
      **kwargs: forwarded to `tf.keras.Model`.

    Raises:
      ValueError: if the stage is neither 'decoder' nor 'encoder_decoder'.
    """
    super(ColTranCore, self).__init__(**kwargs)
    self.config = config

    # 3 bits per channel, 8 colors per channel, a total of 512 colors.
    self.num_symbols_per_channel = 2**3
    self.num_symbols = self.num_symbols_per_channel**3
    self.gray_symbols, self.num_channels = 256, 1

    self.enc_cfg = config.encoder
    self.dec_cfg = config.decoder
    self.hidden_size = self.config.get('hidden_size',
                                       self.dec_cfg.hidden_size)

    # stage can be 'encoder_decoder' or 'decoder'
    # 1. decoder -> loss only due to autoregressive model.
    # 2. encoder_decoder -> loss due to both the autoregressive and parallel
    #    model.
    self.stage = config.get('stage', 'decoder')
    self.is_parallel_loss = 'encoder' in self.stage
    stages = ['decoder', 'encoder_decoder']
    if self.stage not in stages:
      raise ValueError('Expected stage to be in %s, got %s' %
                       (str(stages), self.stage))

  @property
  def metric_keys(self):
    """Names of the auxiliary losses reported by `loss` for this stage."""
    if self.stage == 'encoder_decoder':
      return ['encoder']
    return []

  def build(self, input_shape):
    """Creates the encoder, decoder and projection layers."""
    # encoder graph
    self.encoder = core.GrayScaleEncoder(self.enc_cfg)
    if self.is_parallel_loss:
      self.parallel_dense = layers.Dense(
          units=self.num_symbols, name='parallel_logits', use_bias=False)

    # decoder graph: outer decoder -> inner decoder -> logits.
    self.pixel_embed_layer = layers.Dense(
        units=self.hidden_size, use_bias=False)
    self.outer_decoder = core.OuterDecoder(self.dec_cfg)
    self.inner_decoder = core.InnerDecoder(self.dec_cfg)
    self.final_dense = layers.Dense(
        units=self.num_symbols, name='auto_logits')
    self.final_norm = layers.LayerNormalization()

  def call(self, inputs, training=True):
    """Computes autoregressive (and optionally parallel) color logits.

    Args:
      inputs: RGB image batch; converted to grayscale for the encoder and
        quantized to coarse colors for the decoder.
      training: forwarded to the decoder.

    Returns:
      A tuple `(dec_logits, aux_output)`. `aux_output` holds
      'encoder_logits' when the parallel loss is enabled and is empty
      otherwise.
    """
    # encodes grayscale (H, W) into activations of shape (H, W, 512).
    gray = tf.image.rgb_to_grayscale(inputs)
    z = self.encoder(gray)

    aux_output = {}
    if self.is_parallel_loss:
      enc_logits = self.parallel_dense(z)
      aux_output['encoder_logits'] = tf.expand_dims(enc_logits, axis=-2)

    dec_logits = self.decoder(inputs, z, training=training)
    return dec_logits, aux_output

  def decoder(self, inputs, z, training):
    """Decodes grayscale representation and masked colors into logits.

    Args:
      inputs: RGB image batch with 8 bits per channel.
      z: output of the grayscale encoder.
      training: forwarded to the outer and inner decoders.

    Returns:
      Logits over `num_symbols` coarse colors with an extra axis inserted at
      position -2.
    """
    # (H, W, 512) preprocessing.
    # quantize to 3 bits.
    labels = base_utils.convert_bits(inputs, n_bits_in=8, n_bits_out=3)

    # bin each channel triplet -> (H, W, 3) with 8 possible symbols
    # (H, W, 512)
    labels = base_utils.labels_to_bins(labels, self.num_symbols_per_channel)

    # (H, W) with 512 symbols to (H, W, 512)
    labels = tf.one_hot(labels, depth=self.num_symbols)

    h_dec = self.pixel_embed_layer(labels)
    h_upper = self.outer_decoder((h_dec, z), training=training)
    h_inner = self.inner_decoder((h_dec, h_upper, z), training=training)

    activations = self.final_norm(h_inner)
    logits = self.final_dense(activations)
    return tf.expand_dims(logits, axis=-2)

  def image_loss(self, logits, labels):
    """Cross-entropy between the logits and labels.

    The per-position cross-entropy is averaged over the batch axis, summed
    over the remaining positions, converted with `base_utils.nats_to_bits`
    and divided by `height * width` taken from `labels.shape[1:3]`.

    Args:
      logits: logits with a singleton axis at position -2 (squeezed here).
      labels: integer symbol indices.

    Returns:
      Scalar loss.
    """
    height, width = labels.shape[1:3]
    logits = tf.squeeze(logits, axis=-2)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss = tf.reduce_mean(loss, axis=0)
    loss = base_utils.nats_to_bits(tf.reduce_sum(loss))
    return loss / (height * width)

  def _select_targets(self, tensors, train_config):
    """Picks the full-resolution or the downsampled image from `tensors`."""
    if train_config.get('downsample', False):
      downsample_res = train_config.get('downsample_res', 64)
      return tensors['targets_%d' % downsample_res]
    return tensors['targets']

  def loss(self, targets, logits, train_config, training, aux_output=None):
    """Converts targets to coarse colors and computes log-likelihood.

    Args:
      targets: dict holding 'targets' and, when downsampling is enabled,
        'targets_<downsample_res>'.
      logits: decoder logits as returned by `call`.
      train_config: dict-like with optional 'downsample' and
        'downsample_res' entries.
      training: unused; kept for interface compatibility.
      aux_output: optional dict; when it contains 'encoder_logits' the
        parallel (encoder) loss is computed as well.

    Returns:
      A tuple `(loss, aux_losses)` where `aux_losses` is `{'encoder': ...}`
      if encoder logits were supplied and `{}` otherwise.
    """
    labels = self._select_targets(targets, train_config)

    if aux_output is None:
      aux_output = {}

    # quantize labels.
    labels = base_utils.convert_bits(labels, n_bits_in=8, n_bits_out=3)

    # bin each channel triplet.
    labels = base_utils.labels_to_bins(labels, self.num_symbols_per_channel)

    loss = self.image_loss(logits, labels)
    enc_logits = aux_output.get('encoder_logits')
    if enc_logits is None:
      return loss, {}

    enc_loss = self.image_loss(enc_logits, labels)
    return loss, {'encoder': enc_loss}

  def get_logits(self, inputs_dict, train_config, training):
    """Runs the model on the image selected by `train_config`."""
    inputs = self._select_targets(inputs_dict, train_config)
    return self(inputs=inputs, training=training)

  def _validate_mode(self, mode):
    """Raises ValueError unless `mode` is a supported sampling mode."""
    if mode not in self._SAMPLING_MODES:
      raise ValueError('Expected mode to be in %s, got %s' %
                       (str(self._SAMPLING_MODES), mode))

  def sample(self, gray_cond, mode='argmax'):
    """Colorizes a grayscale image.

    Args:
      gray_cond: grayscale conditioning image fed to the encoder.
      mode: 'sample' or 'argmax'.

    Returns:
      Dict with 'auto_<mode>' (autoregressive image), 'proba' (per-pixel
      probabilities) and, when the parallel loss is enabled, 'parallel'
      (image decoded from the encoder logits alone).

    Raises:
      ValueError: if `mode` is not 'sample' or 'argmax'.
    """
    # Fail before running the encoder rather than at the first pixel.
    self._validate_mode(mode)
    output = {}

    z_gray = self.encoder(gray_cond, training=False)
    if self.is_parallel_loss:
      z_logits = self.parallel_dense(z_gray)
      parallel_image = tf.argmax(z_logits, axis=-1, output_type=tf.int32)
      parallel_image = self.post_process_image(parallel_image)

      output['parallel'] = parallel_image

    image, proba = self.autoregressive_sample(z_gray=z_gray, mode=mode)
    output['auto_%s' % mode] = image
    output['proba'] = proba
    return output

  def autoregressive_sample(self, z_gray, mode='sample'):
    """Generates pixel-by-pixel.

    1. The encoder is run once per-channel.
    2. The outer decoder is run once per-row.
    3. the inner decoder is run once per-pixel.

    The context from the encoder and outer decoder conditions the
    inner decoder. The inner decoder then generates a row, one pixel at a time.

    After generating all pixels in a row, the outer decoder is run to recompute
    context. This conditions the inner decoder, which then generates the next
    row, pixel-by-pixel.

    Args:
      z_gray: output of the grayscale encoder (see `sample`); its first three
        dimensions are read as (batch_size, height, width).
      mode: sample or argmax.

    Returns:
      image: the (B, H, W) grid of generated symbols after
        `post_process_image`.
      image_proba: probabilities, shape (B, H, W, 512)

    Raises:
      ValueError: if `mode` is not 'sample' or 'argmax'.
    """
    self._validate_mode(mode)

    # Use the resolved hidden size (with its decoder fallback) so the caches
    # always match the width of `pixel_embed_layer`.
    num_filters = self.hidden_size
    batch_size, height, width = z_gray.shape[:3]

    # channel_cache[i, j] stores the pixel embedding for row i and col j.
    canvas_shape = (batch_size, height, width, num_filters)
    channel_cache = coltran_layers.Cache(canvas_shape=(height, width))
    init_channel = tf.zeros(shape=canvas_shape)
    init_ind = tf.stack([0, 0])
    channel_cache(inputs=(init_channel, init_ind))

    # upper_context[row_ind] stores context from all previously generated rows.
    upper_context = tf.zeros(shape=canvas_shape)

    # row_cache[0, j] stores the pixel embedding for the column j of the row
    # under generation. After every row is generated, this is rewritten.
    row_cache = coltran_layers.Cache(canvas_shape=(1, width))
    init_row = tf.zeros(shape=(batch_size, 1, width, num_filters))
    row_cache(inputs=(init_row, init_ind))

    pixel_samples, pixel_probas = [], []

    for row in range(height):
      row_cond_channel = tf.expand_dims(z_gray[:, row], axis=1)
      row_cond_upper = tf.expand_dims(upper_context[:, row], axis=1)
      row_cache.reset()

      gen_row, proba_row = [], []
      for col in range(width):

        inner_input = (row_cache.cache, row_cond_upper, row_cond_channel)
        # computes output activations at col.
        activations = self.inner_decoder(inner_input, row_ind=row,
                                         training=False)

        pixel_sample, pixel_embed, pixel_proba = self.act_logit_sample_embed(
            activations, col, mode=mode)
        proba_row.append(pixel_proba)
        gen_row.append(pixel_sample)

        # row_cache[:, col] = pixel_embed
        row_cache(inputs=(pixel_embed, tf.stack([0, col])))

        # channel_cache[row, col] = pixel_embed
        channel_cache(inputs=(pixel_embed, tf.stack([row, col])))

      gen_row = tf.stack(gen_row, axis=-1)
      pixel_samples.append(gen_row)
      pixel_probas.append(tf.stack(proba_row, axis=1))

      # after a row is generated, recomputes the context for the next row.
      # upper_context[row] = self_attention(channel_cache[:row_index])
      upper_context = self.outer_decoder(
          inputs=(channel_cache.cache, z_gray), training=False)

    image = tf.stack(pixel_samples, axis=1)
    image = self.post_process_image(image)

    image_proba = tf.stack(pixel_probas, axis=1)
    return image, image_proba

  def act_logit_sample_embed(self, activations, col_ind, mode='sample'):
    """Converts activations[col_ind] to the output pixel.

    Activation -> Logit -> Sample -> Embedding.

    Args:
      activations: 4-D Tensor, shape=(batch_size, 1, width, hidden_size)
      col_ind: integer.
      mode: 'sample' or 'argmax'

    Returns:
      pixel_sample: 1-D int32 Tensor, shape=(batch_size,)
      pixel_embed: 4-D Tensor, shape=(batch_size, 1, 1, hidden_size)
      pixel_proba: 2-D Tensor, shape=(batch_size, 512)

    Raises:
      ValueError: if `mode` is not 'sample' or 'argmax'.
    """
    # Without this check an unknown mode would leave `pixel_sample` unbound.
    self._validate_mode(mode)

    batch_size = activations.shape[0]
    pixel_activation = tf.expand_dims(activations[:, :, col_ind], axis=-2)
    pixel_logits = self.final_dense(self.final_norm(pixel_activation))
    pixel_logits = tf.squeeze(pixel_logits, axis=[1, 2])
    pixel_proba = tf.nn.softmax(pixel_logits, axis=-1)

    if mode == 'sample':
      pixel_sample = tf.random.categorical(
          pixel_logits, num_samples=1, dtype=tf.int32)
      pixel_sample = tf.squeeze(pixel_sample, axis=-1)
    else:  # mode == 'argmax'
      pixel_sample = tf.argmax(pixel_logits, axis=-1, output_type=tf.int32)

    pixel_sample_expand = tf.reshape(pixel_sample, [batch_size, 1, 1])
    pixel_one_hot = tf.one_hot(pixel_sample_expand, depth=self.num_symbols)
    pixel_embed = self.pixel_embed_layer(pixel_one_hot)
    return pixel_sample, pixel_embed, pixel_proba

  def post_process_image(self, image):
    """Converts coarse color symbols to a uint8 coarse RGB image.

    Args:
      image: integer symbol indices in [0, 512), one per pixel.

    Returns:
      uint8 Tensor obtained by unbinning the symbols into per-channel 3-bit
      labels and rescaling them to 8 bits.
    """
    image = base_utils.bins_to_labels(
        image, num_symbols_per_channel=self.num_symbols_per_channel)
    image = base_utils.convert_bits(image, n_bits_in=3, n_bits_out=8)
    image = tf.cast(image, dtype=tf.uint8)
    return image
299