# google-research / coltran
# (file listing metadata: 385 lines, 14.0 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Core components of the colorization transfomer.
17
18Consists of:
19
201. Grayscale Encoder.
212. Outer Decoder.
223. Inner Decoder.
23"""
24
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
from tensorflow.compat.v2.keras import layers

from coltran.models import layers as coltran_layers
from coltran.utils import base_utils
34
def cond_with_context(inputs, cond_layer, context, cond_type, cond_act):
  """Conditions `inputs` on `context` via a shift or affine transform.

  Args:
    inputs: activations to be conditioned.
    cond_layer: layer mapping `context` to the conditioning parameters.
    context: conditioning input.
    cond_type: 'shift' adds the projected context; 'affine' splits the
      projection into (shift, scale) halves along the last axis.
    cond_act: name of the activation applied to shift/scale.

  Returns:
    The conditioned activations.
  """
  activation = base_utils.act_to_func(cond_act)
  projected = cond_layer(context)
  if cond_type == 'shift':
    inputs = inputs + projected
  elif cond_type == 'affine':
    # Last-axis split: first half is the shift, second half is the scale.
    shift, scale = tf.split(projected, num_or_size_splits=2, axis=-1)
    inputs = inputs * activation(scale) + activation(shift)
  return inputs
46
def get_pos_embeddings(pos_embed, inputs_shape):
  """Returns the positional embeddings for a tensor of the given shape.

  Feeds an all-zero tensor through `pos_embed` so only the (additive)
  position component of the layer's output remains.
  """
  return pos_embed(tf.zeros(shape=inputs_shape))
51
class GrayScaleEncoder(layers.Layer):
  """Encodes grayscale version of the image into a 2-D spatial context.

  Consists of a stack of row/column attention layers.
  """

  def __init__(self, config, **kwargs):
    super(GrayScaleEncoder, self).__init__(**kwargs)
    self.config = config
    self.dropout = config.get('dropout', 0.0)

  def build(self, input_shapes):
    # Projects 256-way one-hot pixel intensities into the hidden space.
    self.embedding = layers.Dense(units=self.config.hidden_size)
    self.encoder = coltran_layers.FactorizedAttention(self.config)

  def call(self, inputs):
    """Encodes a grayscale image into a spatial context.

    Args:
      inputs: integer intensities, either (B, H, W) or (B, H, W, 1).

    Returns:
      Encoded context of shape (B, H, W, hidden_size).

    Raises:
      ValueError: if a 4-D input has more than one channel.
    """
    if len(inputs.shape) == 4:
      if inputs.shape[-1] != 1:
        raise ValueError('Expected inputs is a grayscale image')
      grayscale = tf.squeeze(inputs, axis=-1)
    else:
      # Bug fix: a 3-D (already channel-squeezed) input previously left
      # `grayscale` unbound and raised a NameError; use it directly.
      grayscale = inputs
    grayscale = tf.one_hot(grayscale, depth=256)
    h_gray = self.embedding(grayscale)
    return self.encoder(h_gray)
76
class OuterDecoder(layers.Layer):
  """Outer Decoder with optional conditioning.

  Contains the following sequence of operations:

  1. Positional Embeddings.
  2. (Unmasked Row + Masked Column) self attention * num_layers.
  3. Shift Down (to preserve causal ordering)

  The input is a tuple of 2 arguments (X, h) where h is the conditioning
  input. Transforms the input X into 2-D spatial context C (H, W, D)
  conditioned on h. Each location C[i, j] is a vector of size D that
  summarizes information from X[:i] and h.

  The conditional components can be activated by setting the corresponding
  conditional arguments to True.

  1. Conditional Layer Norm: config.cond_ln
  2. Conditional Self Attention: config.cond_att_k, config.cond_att_q,
     config.cond_att_v, config.cond_att_scale.
  3. Conditional MLP: config.cond_mlp
  """

  def __init__(self, config, **kwargs):
    super(OuterDecoder, self).__init__(**kwargs)
    self.config = config
    self.dropout = self.config.get('dropout', 0.0)
    # If True, the conditioning context is also added to the embeddings
    # before the attention stack (see call()).
    self.skip = self.config.get('skip', True)

    # Conditional MLP
    self.cond_mlp = self.config.get('cond_mlp', 'affine')
    self.cond_mlp_act = self.config.get('cond_mlp_act', 'identity')

    # Conditional Layer Norm.
    self.cond_ln = self.config.get('cond_ln', True)
    self.cond_ln_act = self.config.get('cond_ln_act', 'identity')
    self.cond_ln_seq = self.config.get('cond_ln_seq', 'sc')
    self.cond_ln_sp_ave = self.config.get('cond_ln_sp_ave', 'learnable')
    self.cond_ln_init = self.config.get('cond_ln_init', 'glorot_uniform')

    # Conditional Self Attention.
    self.cond_att_act = self.config.get('cond_att_act', 'identity')
    self.cond_att_k = self.config.get('cond_att_k', True)
    self.cond_att_q = self.config.get('cond_att_q', True)
    self.cond_att_v = self.config.get('cond_att_v', True)
    self.cond_att_scale = self.config.get('cond_att_scale', True)
    self.cond_att_init = self.config.get('cond_att_init', 'glorot_uniform')
    # Conditional attention is enabled if any of q/k/v is conditioned.
    self.cond_att = self.cond_att_v or self.cond_att_q or self.cond_att_k

  def build(self, input_shapes):
    # input_shapes is (embedding shape, context shape); only the embedding
    # shape is needed to size the sublayers.
    embed_shape = input_shapes[0]
    height, width, num_filters = embed_shape[1:]
    hidden_size = self.config.hidden_size
    num_heads = self.config.num_heads
    ff_size = self.config.ff_size
    res = [height, width]

    self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2], max_lengths=res)

    self.residual_layers, self.layer_norms, self.cmlp_layers = [], [], []
    # Four residual sublayers per outer layer (row att, row FF, col att,
    # col FF), each paired with its own normalization.
    num_norms = self.config.num_outer_layers * 4
    if self.cond_ln:
      for _ in range(num_norms):
        curr_norm = coltran_layers.ConditionalLayerNorm(
            spatial_average=self.cond_ln_sp_ave,
            sequence=self.cond_ln_seq,
            out_init=self.cond_ln_init,
            out_act=self.cond_ln_act)
        self.layer_norms.append(curr_norm)
    else:
      self.layer_norms = [layers.LayerNormalization() for _ in range(num_norms)]

    for layer_ind in range(self.config.num_outer_layers):
      # unmasked row
      unmask_row = coltran_layers.SelfAttentionND(
          hidden_size=hidden_size, num_heads=num_heads,
          nd_block_size=[1, width], resolution=[height, width],
          cond_q=self.cond_att_q,
          cond_k=self.cond_att_k,
          cond_v=self.cond_att_v,
          cond_init=self.cond_att_init,
          cond_scale=self.cond_att_scale,
          cond_act=self.cond_att_act,
          name='unmask_row_att_%d' % layer_ind)

      ff_row = tf.keras.Sequential([
          layers.Dense(units=ff_size, activation='relu'),
          layers.Dense(units=num_filters)
      ], name='row_dense_%d' % layer_ind)

      # masked column,
      mask_col = coltran_layers.SelfAttentionND(
          hidden_size=hidden_size, num_heads=num_heads, mask='future',
          nd_block_size=[height, 1], resolution=[height, width],
          cond_q=self.cond_att_q,
          cond_k=self.cond_att_k,
          cond_v=self.cond_att_v,
          cond_act=self.cond_att_act,
          cond_init=self.cond_att_init,
          cond_scale=self.cond_att_scale,
          name='mask_col_att_%d' % layer_ind)

      ff_col = tf.keras.Sequential([
          layers.Dense(units=ff_size, activation='relu'),
          layers.Dense(units=num_filters)
      ], name='col_dense_%d' % layer_ind)

      # NOTE: call() dispatches on the substrings 'att' and 'dense' in the
      # layer names above, so the naming scheme is load-bearing.
      self.residual_layers.append(unmask_row)
      self.residual_layers.append(ff_row)
      self.residual_layers.append(mask_col)
      self.residual_layers.append(ff_col)

      # Conditional MLP layers.
      if self.cond_mlp == 'shift':
        shift_r = layers.Dense(units=hidden_size, name='shift_r_%d' % layer_ind)
        shift_c = layers.Dense(units=hidden_size, name='shift_c_%d' % layer_ind)
        self.cmlp_layers.append(shift_r)
        self.cmlp_layers.append(shift_c)
      elif self.cond_mlp == 'affine':
        # Affine conditioning predicts both shift and scale, hence 2x units.
        aff_r = layers.Dense(
            units=2*hidden_size, name='affine_r_%d' % layer_ind)
        aff_c = layers.Dense(
            units=2*hidden_size, name='affine_c_%d' % layer_ind)
        self.cmlp_layers.append(aff_r)
        self.cmlp_layers.append(aff_c)

    self.shift_down = coltran_layers.Shift(dimension=0, resolution=res)

  def call(self, inputs, training=True):
    # inputs: (embeddings, channel_context) — see class docstring.
    embeddings, channel_context = inputs
    cond_layer_ind = 0

    output = self.pos_embed(embeddings)
    if self.skip:
      output += channel_context
    inputs = output

    for layer, norm in zip(self.residual_layers, self.layer_norms):
      # Conditional self-attention layers take the context as a second input.
      if 'att' in layer.name and self.cond_att:
        output = layer((inputs, channel_context))
      else:
        output = layer(inputs)

      # Condition the feed-forward output on the context (shift or affine);
      # cmlp_layers is consumed in order, one per 'dense' sublayer.
      if 'dense' in layer.name and self.cond_mlp:
        curr_cond_layer = self.cmlp_layers[cond_layer_ind]
        output = cond_with_context(output, curr_cond_layer, channel_context,
                                   self.cond_mlp, self.cond_mlp_act)
        cond_layer_ind += 1

      output = coltran_layers.residual_dropout(
          inputs, output, self.dropout, training)

      if self.cond_ln:
        inputs = norm((output, channel_context))
      else:
        inputs = norm(output)

    # Shift down so each output row only summarizes strictly earlier rows.
    output = self.shift_down(inputs)
    return output
235
class InnerDecoder(layers.Layer):

  """Inner Decoder with optional conditioning.

  Contains the following sequence of operations:

  1. Adds positional Embeddings + context to the pixel embeddings.
  2. Shift right (to preserve causal order).
  3. (Masked Row) self attention * num_layers.

  The input is a tuple of 3 arguments (X, h_out, h) where h_out and h are the
  conditioning inputs from the grayscale image and the outer decoder
  respectively. Transforms the input X into 2-D spatial context C (H, W, D)
  conditioned on h. Each location C[i, j] is a vector of size D that
  summarizes information from X[:i], X[i, :j] and h.

  The conditional components can be activated by setting the corresponding
  conditional arguments to True.

  1. Conditional Layer Norm: config.cond_ln
  2. Conditional Self Attention: config.cond_att_k, config.cond_att_q,
     config.cond_att_v, config.cond_att_scale.
  3. Conditional MLP: config.cond_mlp
  """

  def __init__(self,
               config,
               **kwargs):
    super(InnerDecoder, self).__init__(**kwargs)
    self.config = config
    # If True, the conditioning contexts are added to the embeddings before
    # the attention stack (see call()).
    self.skip = self.config.get('skip', True)
    self.dropout = self.config.get('dropout', 0.0)

    # Conditional MLP.
    self.cond_mlp = self.config.get('cond_mlp', 'affine')
    self.cond_mlp_act = self.config.get('cond_mlp_act', 'identity')

    # Conditional Layer Norm.
    self.cond_ln = self.config.get('cond_ln', True)
    self.cond_ln_act = self.config.get('cond_ln_act', 'identity')
    self.cond_ln_seq = self.config.get('cond_ln_seq', 'sc')
    self.cond_ln_sp_ave = self.config.get('cond_ln_sp_ave', 'learnable')
    self.cond_ln_init = self.config.get('cond_ln_init', 'glorot_uniform')

    # Conditional Self Attention. Unlike OuterDecoder, these default to
    # False here.
    self.cond_att_act = self.config.get('cond_att_act', 'identity')
    self.cond_att_k = self.config.get('cond_att_k', False)
    self.cond_att_q = self.config.get('cond_att_q', False)
    self.cond_att_v = self.config.get('cond_att_v', False)
    self.cond_att_scale = self.config.get('cond_att_scale', False)
    self.cond_att_init = self.config.get('cond_att_init', 'glorot_uniform')
    # Conditional attention is enabled if any of q/k/v is conditioned.
    self.cond_att = self.cond_att_v or self.cond_att_q or self.cond_att_k

  def build(self, input_shapes):
    # input_shapes holds the shapes of (embeddings, upper_context,
    # channel_context); the spatial resolution is read from the context.
    context_shape = input_shapes[1]
    height, width = context_shape[1:3]
    ff_size = self.config.ff_size
    hidden_size = self.config.hidden_size
    num_heads = self.config.num_heads
    res = [height, width]

    self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2], max_lengths=res)
    self.shift_right = coltran_layers.Shift(dimension=1, resolution=res)

    self.residual_layers, self.layer_norms, self.cmlp_layers = [], [], []
    # Two residual sublayers per inner layer (masked row attention + FF),
    # each paired with its own normalization.
    num_norms = 2 * self.config.num_inner_layers
    if self.cond_ln:
      for _ in range(num_norms):
        curr_norm = coltran_layers.ConditionalLayerNorm(
            spatial_average=self.cond_ln_sp_ave,
            sequence=self.cond_ln_seq,
            out_init=self.cond_ln_init,
            out_act=self.cond_ln_act)
        self.layer_norms.append(curr_norm)
    else:
      self.layer_norms = [layers.LayerNormalization() for _ in range(num_norms)]

    for layer_ind in range(self.config.num_inner_layers):

      mask_row = coltran_layers.SelfAttentionND(
          hidden_size=hidden_size, num_heads=num_heads, mask='future',
          nd_block_size=[1, width], resolution=[height, width],
          cond_q=self.cond_att_q,
          cond_k=self.cond_att_k,
          cond_v=self.cond_att_v,
          cond_init=self.cond_att_init,
          cond_scale=self.cond_att_scale,
          cond_act=self.cond_att_act,
          name='mask_row_att_%d' % layer_ind)

      ff_block = tf.keras.Sequential([
          layers.Dense(units=ff_size, activation='relu'),
          layers.Dense(units=hidden_size)
      ], name='dense_%d' % layer_ind)

      # NOTE: call() dispatches on the substrings 'att' and 'dense' in the
      # layer names above, so the naming scheme is load-bearing.
      self.residual_layers.append(mask_row)
      self.residual_layers.append(ff_block)

      # Conditional MLP layers (one per inner layer).
      if self.cond_mlp == 'shift':
        shift_c = layers.Dense(units=hidden_size, name='shift_c_%d' % layer_ind)
        self.cmlp_layers.append(shift_c)
      elif self.cond_mlp == 'affine':
        # Affine conditioning predicts both shift and scale, hence 2x units.
        aff_c = layers.Dense(
            units=2*hidden_size, name='affine_c_%d' % layer_ind)
        self.cmlp_layers.append(aff_c)

  def call(self, inputs, row_ind=None, training=True):
    # inputs: (embeddings, upper_context, channel_context); row_ind selects
    # a single row's positional embedding during per-row sampling.
    embeddings, upper_context, channel_context = inputs

    embeddings = self.shift_right(embeddings)
    if row_ind is None:
      embeddings = self.pos_embed(embeddings)
    # special case during sampling.
    else:
      input_shape = embeddings.shape.as_list()
      pos_embed = get_pos_embeddings(self.pos_embed, input_shape)
      pos_embed = pos_embed[:, row_ind: row_ind + 1]
      embeddings += pos_embed

    inputs = embeddings
    if self.skip:
      inputs += channel_context
      inputs += upper_context

    layer_zip = zip(self.residual_layers, self.layer_norms)
    # Attention/MLP conditioning sees both contexts concatenated on the
    # channel axis.
    all_context = tf.concat((channel_context, upper_context), -1)

    cond_layer_ind = 0
    for layer, norm in layer_zip:

      # Conditional Self-Attention.
      if 'att' in layer.name and self.cond_att:
        output = layer((inputs, all_context))
      else:
        output = layer(inputs)

      # Conditional MLP.
      if 'dense' in layer.name and self.cond_mlp:
        curr_cond_layer = self.cmlp_layers[cond_layer_ind]
        output = cond_with_context(output, curr_cond_layer, all_context,
                                   self.cond_mlp, self.cond_mlp_act)
        cond_layer_ind += 1

      output = coltran_layers.residual_dropout(
          inputs, output, self.dropout, training)

      # providing all context here violates causal masking due to the spatial
      # averaging.
      # Conditional Layer norm.
      if self.cond_ln:
        inputs = norm((output, channel_context))
      else:
        inputs = norm(output)

    return inputs