# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core components of the colorization transformer.

Consists of:

1. Grayscale Encoder.
2. Outer Decoder.
3. Inner Decoder.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
from tensorflow.compat.v2.keras import layers
from coltran.models import layers as coltran_layers
from coltran.utils import base_utils


def cond_with_context(inputs, cond_layer, context, cond_type, cond_act):
  """Conditions `inputs` on `context` via a shift or an affine transform.

  Args:
    inputs: activations to be conditioned.
    cond_layer: layer projecting `context` to the conditioning parameters.
    context: conditioning input, e.g. grayscale or outer-decoder embeddings.
    cond_type: 'shift' or 'affine'.
    cond_act: activation applied to the conditioning parameters.

  Returns:
    The conditioned activations.
  """
  cond_act_func = base_utils.act_to_func(cond_act)
  cond_out = cond_layer(context)
  if cond_type == 'shift':
    inputs += cond_out
  elif cond_type == 'affine':
    shift, scale = tf.split(cond_out, num_or_size_splits=2, axis=-1)
    inputs *= cond_act_func(scale)
    inputs += cond_act_func(shift)
  return inputs
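

# A minimal usage sketch, not part of the original module: it shows how
# `cond_with_context` applies 'affine' conditioning. The Dense layer doubles
# the channel dimension so tf.split can recover a (shift, scale) pair; the
# function name and all shapes below are illustrative assumptions.
def _example_cond_with_context():
  hidden_size = 8
  # Activations of shape (batch, height, width, hidden_size).
  inputs = tf.random.normal((2, 4, 4, hidden_size))
  # Context with the same spatial layout, e.g. grayscale embeddings.
  context = tf.random.normal((2, 4, 4, hidden_size))
  # 'affine' conditioning needs 2 * hidden_size outputs: shift and scale.
  cond_layer = layers.Dense(units=2 * hidden_size)
  return cond_with_context(inputs, cond_layer, context,
                           cond_type='affine', cond_act='identity')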


def get_pos_embeddings(pos_embed, inputs_shape):
  """Returns positional embeddings for an all-zeros input of the given shape."""
  embeddings = tf.zeros(shape=inputs_shape)
  return pos_embed(embeddings)


class GrayScaleEncoder(layers.Layer):
  """Encodes the grayscale version of an image into a 2-D spatial context.

  Consists of a stack of row/column attention layers.
  """

  def __init__(self, config, **kwargs):
    super(GrayScaleEncoder, self).__init__(**kwargs)
    self.config = config
    self.dropout = config.get('dropout', 0.0)

  def build(self, input_shapes):
    self.embedding = layers.Dense(units=self.config.hidden_size)
    self.encoder = coltran_layers.FactorizedAttention(self.config)

  def call(self, inputs):
    if len(inputs.shape) == 4:
      if inputs.shape[-1] != 1:
        raise ValueError('Expected a single-channel grayscale image.')
      grayscale = tf.squeeze(inputs, axis=-1)
    else:
      # Rank-3 inputs are assumed to already be (batch, height, width)
      # integer grayscale intensities.
      grayscale = inputs
    grayscale = tf.one_hot(grayscale, depth=256)
    h_gray = self.embedding(grayscale)
    return self.encoder(h_gray)
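

# A hypothetical usage sketch, not part of the original module. It assumes
# the usual ml_collections.ConfigDict config used by coltran and that
# `coltran_layers.FactorizedAttention` reads hidden_size, ff_size, num_heads
# and num_encoder_layers from the same config; field names and sizes below
# are illustrative, not the published training configuration.
def _example_grayscale_encoder():
  import ml_collections
  config = ml_collections.ConfigDict()
  config.hidden_size = 32
  config.ff_size = 32
  config.num_heads = 2
  config.num_encoder_layers = 1
  encoder = GrayScaleEncoder(config)
  # A batch of two 8x8 single-channel images with integer intensities.
  gray = tf.random.uniform((2, 8, 8, 1), maxval=256, dtype=tf.int32)
  return encoder(gray)  # expected shape: (2, 8, 8, hidden_size).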


class OuterDecoder(layers.Layer):
  """Outer Decoder with optional conditioning.

  Contains the following sequence of operations:
    1. Positional Embeddings.
    2. (Unmasked Row + Masked Column) self attention * num_layers.
    3. Shift Down (to preserve causal ordering).

  The input is a tuple of 2 arguments (X, h) where h is the conditioning
  input. Transforms the input X into 2-D spatial context C (H, W, D)
  conditioned on h. Each location C[i, j] is a vector of size D that
  summarizes information from X[:i] and h.

  The conditional components can be activated by setting the corresponding
  conditional arguments to True.
    1. Conditional Layer Norm: config.cond_ln
    2. Conditional Self Attention: config.cond_att_k, config.cond_att_q,
                                   config.cond_att_v, config.cond_att_scale.
    3. Conditional MLP: config.cond_mlp
  """

  def __init__(self, config, **kwargs):
    super(OuterDecoder, self).__init__(**kwargs)
    self.config = config
    self.dropout = self.config.get('dropout', 0.0)
    self.skip = self.config.get('skip', True)

    # Conditional MLP.
    self.cond_mlp = self.config.get('cond_mlp', 'affine')
    self.cond_mlp_act = self.config.get('cond_mlp_act', 'identity')

    # Conditional Layer Norm.
    self.cond_ln = self.config.get('cond_ln', True)
    self.cond_ln_act = self.config.get('cond_ln_act', 'identity')
    self.cond_ln_seq = self.config.get('cond_ln_seq', 'sc')
    self.cond_ln_sp_ave = self.config.get('cond_ln_sp_ave', 'learnable')
    self.cond_ln_init = self.config.get('cond_ln_init', 'glorot_uniform')

    # Conditional Self Attention.
    self.cond_att_act = self.config.get('cond_att_act', 'identity')
    self.cond_att_k = self.config.get('cond_att_k', True)
    self.cond_att_q = self.config.get('cond_att_q', True)
    self.cond_att_v = self.config.get('cond_att_v', True)
    self.cond_att_scale = self.config.get('cond_att_scale', True)
    self.cond_att_init = self.config.get('cond_att_init', 'glorot_uniform')
    self.cond_att = self.cond_att_v or self.cond_att_q or self.cond_att_k

  def build(self, input_shapes):
    embed_shape = input_shapes[0]
    height, width, num_filters = embed_shape[1:]
    hidden_size = self.config.hidden_size
    num_heads = self.config.num_heads
    ff_size = self.config.ff_size
    res = [height, width]

    self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2], max_lengths=res)

    self.residual_layers, self.layer_norms, self.cmlp_layers = [], [], []
    num_norms = self.config.num_outer_layers * 4
    if self.cond_ln:
      for _ in range(num_norms):
        curr_norm = coltran_layers.ConditionalLayerNorm(
            spatial_average=self.cond_ln_sp_ave,
            sequence=self.cond_ln_seq,
            out_init=self.cond_ln_init,
            out_act=self.cond_ln_act)
        self.layer_norms.append(curr_norm)
    else:
      self.layer_norms = [layers.LayerNormalization() for _ in range(num_norms)]

    for layer_ind in range(self.config.num_outer_layers):
      # Unmasked row attention.
      unmask_row = coltran_layers.SelfAttentionND(
          hidden_size=hidden_size, num_heads=num_heads,
          nd_block_size=[1, width], resolution=[height, width],
          cond_q=self.cond_att_q,
          cond_k=self.cond_att_k,
          cond_v=self.cond_att_v,
          cond_init=self.cond_att_init,
          cond_scale=self.cond_att_scale,
          cond_act=self.cond_att_act,
          name='unmask_row_att_%d' % layer_ind)

      ff_row = tf.keras.Sequential([
          layers.Dense(units=ff_size, activation='relu'),
          layers.Dense(units=num_filters)
      ], name='row_dense_%d' % layer_ind)

      # Masked column attention.
      mask_col = coltran_layers.SelfAttentionND(
          hidden_size=hidden_size, num_heads=num_heads, mask='future',
          nd_block_size=[height, 1], resolution=[height, width],
          cond_q=self.cond_att_q,
          cond_k=self.cond_att_k,
          cond_v=self.cond_att_v,
          cond_act=self.cond_att_act,
          cond_init=self.cond_att_init,
          cond_scale=self.cond_att_scale,
          name='mask_col_att_%d' % layer_ind)

      ff_col = tf.keras.Sequential([
          layers.Dense(units=ff_size, activation='relu'),
          layers.Dense(units=num_filters)
      ], name='col_dense_%d' % layer_ind)

      self.residual_layers.append(unmask_row)
      self.residual_layers.append(ff_row)
      self.residual_layers.append(mask_col)
      self.residual_layers.append(ff_col)

      # Conditional MLP layers.
      if self.cond_mlp == 'shift':
        shift_r = layers.Dense(units=hidden_size, name='shift_r_%d' % layer_ind)
        shift_c = layers.Dense(units=hidden_size, name='shift_c_%d' % layer_ind)
        self.cmlp_layers.append(shift_r)
        self.cmlp_layers.append(shift_c)
      elif self.cond_mlp == 'affine':
        aff_r = layers.Dense(
            units=2*hidden_size, name='affine_r_%d' % layer_ind)
        aff_c = layers.Dense(
            units=2*hidden_size, name='affine_c_%d' % layer_ind)
        self.cmlp_layers.append(aff_r)
        self.cmlp_layers.append(aff_c)

    self.shift_down = coltran_layers.Shift(dimension=0, resolution=res)

  def call(self, inputs, training=True):
    embeddings, channel_context = inputs
    cond_layer_ind = 0

    output = self.pos_embed(embeddings)
    if self.skip:
      output += channel_context
    inputs = output

    for layer, norm in zip(self.residual_layers, self.layer_norms):
      if 'att' in layer.name and self.cond_att:
        output = layer((inputs, channel_context))
      else:
        output = layer(inputs)

      if 'dense' in layer.name and self.cond_mlp:
        curr_cond_layer = self.cmlp_layers[cond_layer_ind]
        output = cond_with_context(output, curr_cond_layer, channel_context,
                                   self.cond_mlp, self.cond_mlp_act)
        cond_layer_ind += 1

      output = coltran_layers.residual_dropout(
          inputs, output, self.dropout, training)

      if self.cond_ln:
        inputs = norm((output, channel_context))
      else:
        inputs = norm(output)

    output = self.shift_down(inputs)
    return output
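

# A hypothetical usage sketch, not part of the original module. The outer
# decoder consumes the pixel embeddings together with the grayscale context
# from GrayScaleEncoder; the config fields and sizes below are illustrative
# assumptions, not the published training configuration.
def _example_outer_decoder():
  import ml_collections
  config = ml_collections.ConfigDict()
  config.hidden_size = 32
  config.ff_size = 32
  config.num_heads = 2
  config.num_outer_layers = 1
  outer = OuterDecoder(config)
  embeddings = tf.random.normal((2, 8, 8, 32))       # pixel embeddings X.
  channel_context = tf.random.normal((2, 8, 8, 32))  # h from the encoder.
  # Row i of the output summarizes rows < i of X (shifted down), so it can
  # safely condition the inner decoder at row i.
  return outer((embeddings, channel_context), training=False)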


class InnerDecoder(layers.Layer):
  """Inner Decoder with optional conditioning.

  Contains the following sequence of operations:
    1. Adds positional embeddings + context to the pixel embeddings.
    2. Shift right (to preserve causal ordering).
    3. (Masked Row) self attention * num_layers.

  The input is a tuple of 3 arguments (X, h_out, h) where h_out and h are
  the conditioning inputs from the outer decoder and the grayscale encoder
  respectively. Transforms the input X into 2-D spatial context C (H, W, D)
  conditioned on h_out and h. Each location C[i, j] is a vector of size D
  that summarizes information from X[:i], X[i, :j], h_out and h.

  The conditional components can be activated by setting the corresponding
  conditional arguments to True.
    1. Conditional Layer Norm: config.cond_ln
    2. Conditional Self Attention: config.cond_att_k, config.cond_att_q,
                                   config.cond_att_v, config.cond_att_scale.
    3. Conditional MLP: config.cond_mlp
  """

  def __init__(self, config, **kwargs):
    super(InnerDecoder, self).__init__(**kwargs)
    self.config = config
    self.skip = self.config.get('skip', True)
    self.dropout = self.config.get('dropout', 0.0)

    # Conditional MLP.
    self.cond_mlp = self.config.get('cond_mlp', 'affine')
    self.cond_mlp_act = self.config.get('cond_mlp_act', 'identity')

    # Conditional Layer Norm.
    self.cond_ln = self.config.get('cond_ln', True)
    self.cond_ln_act = self.config.get('cond_ln_act', 'identity')
    self.cond_ln_seq = self.config.get('cond_ln_seq', 'sc')
    self.cond_ln_sp_ave = self.config.get('cond_ln_sp_ave', 'learnable')
    self.cond_ln_init = self.config.get('cond_ln_init', 'glorot_uniform')

    # Conditional Self Attention.
    self.cond_att_act = self.config.get('cond_att_act', 'identity')
    self.cond_att_k = self.config.get('cond_att_k', False)
    self.cond_att_q = self.config.get('cond_att_q', False)
    self.cond_att_v = self.config.get('cond_att_v', False)
    self.cond_att_scale = self.config.get('cond_att_scale', False)
    self.cond_att_init = self.config.get('cond_att_init', 'glorot_uniform')
    self.cond_att = self.cond_att_v or self.cond_att_q or self.cond_att_k

  def build(self, input_shapes):
    context_shape = input_shapes[1]
    height, width = context_shape[1:3]
    ff_size = self.config.ff_size
    hidden_size = self.config.hidden_size
    num_heads = self.config.num_heads
    res = [height, width]

    self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2], max_lengths=res)
    self.shift_right = coltran_layers.Shift(dimension=1, resolution=res)

    self.residual_layers, self.layer_norms, self.cmlp_layers = [], [], []
    num_norms = 2 * self.config.num_inner_layers
    if self.cond_ln:
      for _ in range(num_norms):
        curr_norm = coltran_layers.ConditionalLayerNorm(
            spatial_average=self.cond_ln_sp_ave,
            sequence=self.cond_ln_seq,
            out_init=self.cond_ln_init,
            out_act=self.cond_ln_act)
        self.layer_norms.append(curr_norm)
    else:
      self.layer_norms = [layers.LayerNormalization() for _ in range(num_norms)]

    for layer_ind in range(self.config.num_inner_layers):
      # Masked row attention.
      mask_row = coltran_layers.SelfAttentionND(
          hidden_size=hidden_size, num_heads=num_heads, mask='future',
          nd_block_size=[1, width], resolution=[height, width],
          cond_q=self.cond_att_q,
          cond_k=self.cond_att_k,
          cond_v=self.cond_att_v,
          cond_init=self.cond_att_init,
          cond_scale=self.cond_att_scale,
          cond_act=self.cond_att_act,
          name='mask_row_att_%d' % layer_ind)

      ff_block = tf.keras.Sequential([
          layers.Dense(units=ff_size, activation='relu'),
          layers.Dense(units=hidden_size)
      ], name='dense_%d' % layer_ind)

      self.residual_layers.append(mask_row)
      self.residual_layers.append(ff_block)

      # Conditional MLP layers.
      if self.cond_mlp == 'shift':
        shift_c = layers.Dense(units=hidden_size, name='shift_c_%d' % layer_ind)
        self.cmlp_layers.append(shift_c)
      elif self.cond_mlp == 'affine':
        aff_c = layers.Dense(
            units=2*hidden_size, name='affine_c_%d' % layer_ind)
        self.cmlp_layers.append(aff_c)

  def call(self, inputs, row_ind=None, training=True):
    embeddings, upper_context, channel_context = inputs

    embeddings = self.shift_right(embeddings)
    if row_ind is None:
      embeddings = self.pos_embed(embeddings)
    else:
      # Special case during sampling: add only the positional embedding of
      # the row currently being decoded.
      input_shape = embeddings.shape.as_list()
      pos_embed = get_pos_embeddings(self.pos_embed, input_shape)
      pos_embed = pos_embed[:, row_ind: row_ind + 1]
      embeddings += pos_embed

    inputs = embeddings
    if self.skip:
      inputs += channel_context
      inputs += upper_context

    layer_zip = zip(self.residual_layers, self.layer_norms)
    all_context = tf.concat((channel_context, upper_context), -1)

    cond_layer_ind = 0
    for layer, norm in layer_zip:
      # Conditional Self-Attention.
      if 'att' in layer.name and self.cond_att:
        output = layer((inputs, all_context))
      else:
        output = layer(inputs)

      # Conditional MLP.
      if 'dense' in layer.name and self.cond_mlp:
        curr_cond_layer = self.cmlp_layers[cond_layer_ind]
        output = cond_with_context(output, curr_cond_layer, all_context,
                                   self.cond_mlp, self.cond_mlp_act)
        cond_layer_ind += 1

      output = coltran_layers.residual_dropout(
          inputs, output, self.dropout, training)

      # Conditional Layer Norm. Conditioning on `all_context` here would
      # violate causal masking due to the spatial averaging, so only the
      # channel context is used.
      if self.cond_ln:
        inputs = norm((output, channel_context))
      else:
        inputs = norm(output)

    return inputs
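

# A hypothetical end-to-end sketch, not part of the original module, wiring
# the three components together as described in the module docstring:
# grayscale encoder -> outer decoder -> inner decoder. Config fields and
# sizes are illustrative assumptions, not the published configuration.
def _example_pipeline():
  import ml_collections
  config = ml_collections.ConfigDict()
  config.hidden_size = 32
  config.ff_size = 32
  config.num_heads = 2
  config.num_encoder_layers = 1
  config.num_outer_layers = 1
  config.num_inner_layers = 1

  gray = tf.random.uniform((2, 8, 8, 1), maxval=256, dtype=tf.int32)
  embeddings = tf.random.normal((2, 8, 8, 32))  # color pixel embeddings X.

  channel_context = GrayScaleEncoder(config)(gray)   # h
  upper_context = OuterDecoder(config)(
      (embeddings, channel_context), training=False)  # h_out
  inner = InnerDecoder(config)
  # Training decodes all rows at once; during sampling a single row is
  # decoded instead, by passing row_ind alongside row-sliced inputs.
  return inner((embeddings, upper_context, channel_context), training=False)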