# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Slot Attention model for object discovery and set prediction."""
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as layers


class SlotAttention(layers.Layer):
  """Slot Attention module."""

  def __init__(self, num_iterations, num_slots, slot_size, mlp_hidden_size,
               epsilon=1e-8):
    """Builds the Slot Attention module.

    Args:
      num_iterations: Number of iterations.
      num_slots: Number of slots.
      slot_size: Dimensionality of slot feature vectors.
      mlp_hidden_size: Hidden layer size of MLP.
      epsilon: Offset for attention coefficients before normalization.
    """
    super().__init__()
    self.num_iterations = num_iterations
    self.num_slots = num_slots
    self.slot_size = slot_size
    self.mlp_hidden_size = mlp_hidden_size
    self.epsilon = epsilon

    self.norm_inputs = layers.LayerNormalization()
    self.norm_slots = layers.LayerNormalization()
    self.norm_mlp = layers.LayerNormalization()

    # Parameters for Gaussian init (shared by all slots).
    self.slots_mu = self.add_weight(
        initializer="glorot_uniform",
        shape=[1, 1, self.slot_size],
        dtype=tf.float32,
        name="slots_mu")
    self.slots_log_sigma = self.add_weight(
        initializer="glorot_uniform",
        shape=[1, 1, self.slot_size],
        dtype=tf.float32,
        name="slots_log_sigma")

    # Linear maps for the attention module.
    self.project_q = layers.Dense(self.slot_size, use_bias=False, name="q")
    self.project_k = layers.Dense(self.slot_size, use_bias=False, name="k")
    self.project_v = layers.Dense(self.slot_size, use_bias=False, name="v")

    # Slot update functions.
    self.gru = layers.GRUCell(self.slot_size)
    self.mlp = tf.keras.Sequential([
        layers.Dense(self.mlp_hidden_size, activation="relu"),
        layers.Dense(self.slot_size)
    ], name="mlp")

  def call(self, inputs):
    # `inputs` has shape [batch_size, num_inputs, inputs_size].
    inputs = self.norm_inputs(inputs)  # Apply layer norm to the input.
    k = self.project_k(inputs)  # Shape: [batch_size, num_inputs, slot_size].
    v = self.project_v(inputs)  # Shape: [batch_size, num_inputs, slot_size].

    # Initialize the slots. Shape: [batch_size, num_slots, slot_size].
    slots = self.slots_mu + tf.exp(self.slots_log_sigma) * tf.random.normal(
        [tf.shape(inputs)[0], self.num_slots, self.slot_size])

    # Multiple rounds of attention.
    for _ in range(self.num_iterations):
      slots_prev = slots
      slots = self.norm_slots(slots)

      # Attention.
      q = self.project_q(slots)  # Shape: [batch_size, num_slots, slot_size].
      q *= self.slot_size ** -0.5  # Normalization.
      attn_logits = tf.keras.backend.batch_dot(k, q, axes=-1)
      attn = tf.nn.softmax(attn_logits, axis=-1)
      # `attn` has shape: [batch_size, num_inputs, num_slots].
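      # Note: unlike standard attention, the softmax above is taken over the
      # slot axis (the last axis of `attn_logits`), so the slots compete with
      # each other to explain each input feature.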

      # Weighted mean.
      attn += self.epsilon
      attn /= tf.reduce_sum(attn, axis=-2, keepdims=True)
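      # Dividing by the per-slot total (over the input axis) turns the update
      # below into a weighted mean of the values assigned to each slot.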
      updates = tf.keras.backend.batch_dot(attn, v, axes=-2)
      # `updates` has shape: [batch_size, num_slots, slot_size].

      # Slot update.
      slots, _ = self.gru(updates, [slots_prev])
      slots += self.mlp(self.norm_mlp(slots))

    return slots
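
# Illustrative usage of SlotAttention (a sketch; the shapes below are example
# values consistent with the comments above, not constants from this file):
#
#   slot_attention = SlotAttention(
#       num_iterations=3, num_slots=7, slot_size=64, mlp_hidden_size=128)
#   features = tf.random.normal([4, 32 * 32, 64])  # [batch, num_inputs, dim]
#   slots = slot_attention(features)               # -> [4, 7, 64]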


def spatial_broadcast(slots, resolution):
  """Broadcast slot features to a 2D grid and collapse slot dimension."""
  # `slots` has shape: [batch_size, num_slots, slot_size].
  slots = tf.reshape(slots, [-1, slots.shape[-1]])[:, None, None, :]
  grid = tf.tile(slots, [1, resolution[0], resolution[1], 1])
  # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
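  # For example, slots of shape [2, 7, 64] broadcast to resolution (8, 8)
  # become a grid of shape [14, 8, 8, 64] (batch and slot axes merged).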
  return grid


def spatial_flatten(x):
  return tf.reshape(x, [-1, x.shape[1] * x.shape[2], x.shape[-1]])


def unstack_and_split(x, batch_size, num_channels=3):
  """Unstack batch dimension and split into channels and alpha mask."""
  unstacked = tf.reshape(x, [batch_size, -1] + x.shape.as_list()[1:])
  channels, masks = tf.split(unstacked, [num_channels, 1], axis=-1)
  return channels, masks


class SlotAttentionAutoEncoder(layers.Layer):
  """Slot Attention-based auto-encoder for object discovery."""

  def __init__(self, resolution, num_slots, num_iterations):
    """Builds the Slot Attention-based auto-encoder.

    Args:
      resolution: Tuple of integers specifying width and height of input image.
      num_slots: Number of slots in Slot Attention.
      num_iterations: Number of iterations in Slot Attention.
    """
    super().__init__()
    self.resolution = resolution
    self.num_slots = num_slots
    self.num_iterations = num_iterations

    self.encoder_cnn = tf.keras.Sequential([
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu")
    ], name="encoder_cnn")

    self.decoder_initial_size = (8, 8)
    self.decoder_cnn = tf.keras.Sequential([
        layers.Conv2DTranspose(
            64, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            64, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            64, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            64, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            64, 5, strides=(1, 1), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            4, 3, strides=(1, 1), padding="SAME", activation=None)
    ], name="decoder_cnn")
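    # Note (explanatory): starting from the 8x8 broadcast grid, the four
    # stride-2 transposed convolutions upsample by a factor of 16, so the
    # decoder output is 128x128 (with 3 RGB channels plus 1 alpha channel),
    # matching `resolution` in the CLEVR setting this architecture assumes.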

    self.encoder_pos = SoftPositionEmbed(64, self.resolution)
    self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)

    self.layer_norm = layers.LayerNormalization()
    self.mlp = tf.keras.Sequential([
        layers.Dense(64, activation="relu"),
        layers.Dense(64)
    ], name="feedforward")

    self.slot_attention = SlotAttention(
        num_iterations=self.num_iterations,
        num_slots=self.num_slots,
        slot_size=64,
        mlp_hidden_size=128)

  def call(self, image):
    # `image` has shape: [batch_size, width, height, num_channels].

    # Convolutional encoder with position embedding.
    x = self.encoder_cnn(image)  # CNN Backbone.
    x = self.encoder_pos(x)  # Position embedding.
    x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
    x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
    # `x` has shape: [batch_size, width*height, input_size].

    # Slot Attention module.
    slots = self.slot_attention(x)
    # `slots` has shape: [batch_size, num_slots, slot_size].

    # Spatial broadcast decoder.
    x = spatial_broadcast(slots, self.decoder_initial_size)
    # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
    x = self.decoder_pos(x)
    x = self.decoder_cnn(x)
    # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].

    # Undo combination of slot and batch dimension; split alpha masks.
    recons, masks = unstack_and_split(x, batch_size=image.shape[0])
    # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
    # `masks` has shape: [batch_size, num_slots, width, height, 1].

    # Normalize alpha masks over slots.
    masks = tf.nn.softmax(masks, axis=1)
    recon_combined = tf.reduce_sum(recons * masks, axis=1)  # Recombine image.
    # `recon_combined` has shape: [batch_size, width, height, num_channels].

    return recon_combined, recons, masks, slots


class SlotAttentionClassifier(layers.Layer):
  """Slot Attention-based classifier for property prediction."""

  def __init__(self, resolution, num_slots, num_iterations):
    """Builds the Slot Attention-based classifier.

    Args:
      resolution: Tuple of integers specifying width and height of input image.
      num_slots: Number of slots in Slot Attention.
      num_iterations: Number of iterations in Slot Attention.
    """
    super().__init__()
    self.resolution = resolution
    self.num_slots = num_slots
    self.num_iterations = num_iterations

    self.encoder_cnn = tf.keras.Sequential([
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, strides=(2, 2),
                      padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, strides=(2, 2),
                      padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu")
    ], name="encoder_cnn")

    self.encoder_pos = SoftPositionEmbed(64, (32, 32))
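    # Note (explanatory): the two stride-2 convolutions above downsample by 4x,
    # so the hard-coded (32, 32) positional grid assumes 128x128 inputs as in
    # the CLEVR experiments.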

    self.layer_norm = layers.LayerNormalization()
    self.mlp = tf.keras.Sequential([
        layers.Dense(64, activation="relu"),
        layers.Dense(64)
    ], name="feedforward")

    self.slot_attention = SlotAttention(
        num_iterations=self.num_iterations,
        num_slots=self.num_slots,
        slot_size=64,
        mlp_hidden_size=128)

    self.mlp_classifier = tf.keras.Sequential(
        [layers.Dense(64, activation="relu"),
         layers.Dense(19, activation="sigmoid")],  # Number of targets in CLEVR.
        name="mlp_classifier")

  def call(self, image):
    # `image` has shape: [batch_size, width, height, num_channels].

    # Convolutional encoder with position embedding.
    x = self.encoder_cnn(image)  # CNN Backbone.
    x = self.encoder_pos(x)  # Position embedding.
    x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
    x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
    # `x` has shape: [batch_size, width*height, input_size].

    # Slot Attention module.
    slots = self.slot_attention(x)
    # `slots` has shape: [batch_size, num_slots, slot_size].

    # Apply classifier per slot. The predictions have shape
    # [batch_size, num_slots, set_dimension].

    predictions = self.mlp_classifier(slots)

    return predictions


def build_grid(resolution):
  ranges = [np.linspace(0., 1., num=res) for res in resolution]
  grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
  grid = np.stack(grid, axis=-1)
  grid = np.reshape(grid, [resolution[0], resolution[1], -1])
  grid = np.expand_dims(grid, axis=0)
  grid = grid.astype(np.float32)
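  # The returned grid has shape [1, resolution[0], resolution[1], 4]: the
  # normalized coordinate along each axis plus its complement (1 - coordinate).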
  return np.concatenate([grid, 1.0 - grid], axis=-1)


class SoftPositionEmbed(layers.Layer):
  """Adds soft positional embedding with learnable projection."""

  def __init__(self, hidden_size, resolution):
    """Builds the soft position embedding layer.

    Args:
      hidden_size: Size of input feature dimension.
      resolution: Tuple of integers specifying width and height of grid.
    """
    super().__init__()
    self.dense = layers.Dense(hidden_size, use_bias=True)
    self.grid = build_grid(resolution)

  def call(self, inputs):
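    # The constant coordinate grid is projected to `hidden_size` channels by
    # the dense layer and added to the features (broadcast over the batch).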
    return inputs + self.dense(self.grid)


def build_model(resolution, batch_size, num_slots, num_iterations,
                num_channels=3, model_type="object_discovery"):
  """Build keras model."""
  if model_type == "object_discovery":
    model_def = SlotAttentionAutoEncoder
  elif model_type == "set_prediction":
    model_def = SlotAttentionClassifier
  else:
    raise ValueError("Invalid name for model type.")

  image = tf.keras.Input(list(resolution) + [num_channels], batch_size)
  outputs = model_def(resolution, num_slots, num_iterations)(image)
  model = tf.keras.Model(inputs=image, outputs=outputs)
  return model
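

# Illustrative usage (a minimal sketch, not part of the original training
# pipeline): build the object-discovery model at the 128x128 resolution implied
# by the decoder architecture above (the CLEVR setting) and run one random
# batch through it to check output shapes. The constants below (batch size 2,
# 7 slots, 3 iterations) are example values, not requirements of this file.
if __name__ == "__main__":
  example_model = build_model(
      resolution=(128, 128), batch_size=2, num_slots=7, num_iterations=3,
      model_type="object_discovery")
  example_images = tf.random.uniform([2, 128, 128, 3])
  recon_combined, recons, masks, slots = example_model(example_images)
  print("recon_combined:", recon_combined.shape)  # (2, 128, 128, 3)
  print("recons:", recons.shape)  # (2, 7, 128, 128, 3)
  print("masks:", masks.shape)  # (2, 7, 128, 128, 1)
  print("slots:", slots.shape)  # (2, 7, 64)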